@@ -31,6 +31,10 @@ from libc.stdint cimport uint32_t
3131import warnings
3232
3333from dpctl._backend cimport ( # noqa: E211, E402;
34+ DPCTLBuildOptionList_Append,
35+ DPCTLBuildOptionList_Create,
36+ DPCTLBuildOptionList_Delete,
37+ DPCTLBuildOptionListRef,
3438 DPCTLKernel_Copy,
3539 DPCTLKernel_Delete,
3640 DPCTLKernel_GetCompileNumSubGroups,
@@ -41,16 +45,31 @@ from dpctl._backend cimport ( # noqa: E211, E402;
4145 DPCTLKernel_GetPreferredWorkGroupSizeMultiple,
4246 DPCTLKernel_GetPrivateMemSize,
4347 DPCTLKernel_GetWorkGroupSize,
48+ DPCTLKernelBuildLog_Create,
49+ DPCTLKernelBuildLog_Delete,
50+ DPCTLKernelBuildLog_Get,
51+ DPCTLKernelBuildLogRef,
4452 DPCTLKernelBundle_Copy,
4553 DPCTLKernelBundle_CreateFromOCLSource,
4654 DPCTLKernelBundle_CreateFromSpirv,
55+ DPCTLKernelBundle_CreateFromSYCLSource,
4756 DPCTLKernelBundle_Delete,
4857 DPCTLKernelBundle_GetKernel,
58+ DPCTLKernelBundle_GetSyclKernel,
4959 DPCTLKernelBundle_HasKernel,
60+ DPCTLKernelBundle_HasSyclKernel,
61+ DPCTLKernelNameList_Append,
62+ DPCTLKernelNameList_Create,
63+ DPCTLKernelNameList_Delete,
64+ DPCTLKernelNameListRef,
5065 DPCTLSyclContextRef,
5166 DPCTLSyclDeviceRef,
5267 DPCTLSyclKernelBundleRef,
5368 DPCTLSyclKernelRef,
69+ DPCTLVirtualHeaderList_Append,
70+ DPCTLVirtualHeaderList_Create,
71+ DPCTLVirtualHeaderList_Delete,
72+ DPCTLVirtualHeaderListRef,
5473)
5574
5675__all__ = [
@@ -199,9 +218,11 @@ cdef class SyclKernelBundle:
199218 """
200219
201220 @staticmethod
202- cdef SyclKernelBundle _create(DPCTLSyclKernelBundleRef KBRef):
221+ cdef SyclKernelBundle _create(DPCTLSyclKernelBundleRef KBRef,
222+ bint is_sycl_source):
203223 cdef SyclKernelBundle ret = SyclKernelBundle.__new__ (SyclKernelBundle)
204224 ret._kernel_bundle_ref = KBRef
225+ ret._is_sycl_source = is_sycl_source
205226 return ret
206227
207228 def __dealloc__ (self ):
@@ -212,13 +233,24 @@ cdef class SyclKernelBundle:
212233
213234 cpdef SyclKernel get_sycl_kernel(self , str kernel_name):
214235 name = kernel_name.encode(" utf8" )
236+ if self ._is_sycl_source:
237+ return SyclKernel._create(
238+ DPCTLKernelBundle_GetSyclKernel(
239+ self ._kernel_bundle_ref, name
240+ ),
241+ kernel_name
242+ )
215243 return SyclKernel._create(
216244 DPCTLKernelBundle_GetKernel(self ._kernel_bundle_ref, name),
217245 kernel_name
218246 )
219247
220248 def has_sycl_kernel (self , str kernel_name ):
221249 name = kernel_name.encode(" utf8" )
250+ if self ._is_sycl_source:
251+ return DPCTLKernelBundle_HasSyclKernel(
252+ self ._kernel_bundle_ref, name
253+ )
222254 return DPCTLKernelBundle_HasKernel(self ._kernel_bundle_ref, name)
223255
224256 def addressof_ref (self ):
@@ -249,7 +281,7 @@ cdef api SyclKernelBundle SyclKernelBundle_Make(DPCTLSyclKernelBundleRef KBRef):
249281 reference.
250282 """
251283 cdef DPCTLSyclKernelBundleRef copied_KBRef = DPCTLKernelBundle_Copy(KBRef)
252- return SyclKernelBundle._create(copied_KBRef)
284+ return SyclKernelBundle._create(copied_KBRef, False )
253285
254286
255287cpdef create_kernel_bundle_from_source(SyclQueue q, str src, str copts = " " ):
@@ -295,7 +327,7 @@ cpdef create_kernel_bundle_from_source(SyclQueue q, str src, str copts=""):
295327 if KBref is NULL :
296328 raise SyclKernelBundleCompilationError()
297329
298- return SyclKernelBundle._create(KBref)
330+ return SyclKernelBundle._create(KBref, False )
299331
300332
301333cpdef create_kernel_bundle_from_spirv(
@@ -342,7 +374,121 @@ cpdef create_kernel_bundle_from_spirv(
342374 if KBref is NULL :
343375 raise SyclKernelBundleCompilationError()
344376
345- return SyclKernelBundle._create(KBref)
377+ return SyclKernelBundle._create(KBref, False )
378+
379+
380+ cpdef create_kernel_bundle_from_sycl_source(SyclQueue q,
381+ unicode source,
382+ list headers = None ,
383+ list registered_names = None ,
384+ list copts = None ):
385+ """
386+ Creates an executable SYCL kernel_bundle from SYCL source code.
387+
388+ This uses the DPC++ ``kernel_compiler`` extension to create a
389+ ``sycl::kernel_bundle<sycl::bundle_state::executable>`` object from
390+ SYCL source code.
391+
392+ Parameters:
393+ q (:class:`dpctl.SyclQueue`)
394+ The :class:`dpctl.SyclQueue` for which the
395+ :class:`.SyclKernelBundle` is going to be built.
396+ source (unicode)
397+ SYCL source code string.
398+ headers (list)
399+ Optional list of virtual headers, where each entry in the list
400+ needs to be a tuple of header name and header content. See the
401+ documentation of the ``include_files`` property in the DPC++
402+ ``kernel_compiler`` extension for more information.
403+ Default: []
404+ registered_names (list, optional)
405+ Optional list of kernel names to register. See the
406+ documentation of the ``registered_names`` property in the DPC++
407+ ``kernel_compiler`` extension for more information.
408+ Default: []
409+ copts (list)
410+ Optional list of compilation flags that will be used
411+ when compiling the program. Default: ``""``.
412+
413+ Returns:
414+ kernel_bundle (:class:`.SyclKernelBundle`)
415+ A :class:`.SyclKernelBundle` object wrapping the
416+ ``sycl::kernel_bundle<sycl::bundle_state::executable>``
417+ returned by the C API.
418+
419+ Raises:
420+ SyclKernelBundleCompilationError
421+ If a SYCL kernel bundle could not be created. The exception
422+ message contains the build log for more details.
423+ """
424+ cdef DPCTLSyclKernelBundleRef KBref
425+ cdef DPCTLSyclContextRef CRef = q.get_sycl_context().get_context_ref()
426+ cdef DPCTLSyclDeviceRef DRef = q.get_sycl_device().get_device_ref()
427+ cdef bytes bSrc = source.encode(" utf8" )
428+ cdef const char * Src = < const char * > bSrc
429+ cdef DPCTLBuildOptionListRef BuildOpts = DPCTLBuildOptionList_Create()
430+ cdef bytes bOpt
431+ cdef const char * sOpt
432+ cdef bytes bName
433+ cdef const char * sName
434+ cdef bytes bContent
435+ cdef const char * sContent
436+ cdef const char * buildLogContent
437+ for opt in copts:
438+ if not isinstance (opt, unicode ):
439+ DPCTLBuildOptionList_Delete(BuildOpts)
440+ raise SyclKernelBundleCompilationError()
441+ bOpt = opt.encode(" utf8" )
442+ sOpt = < const char * > bOpt
443+ DPCTLBuildOptionList_Append(BuildOpts, sOpt)
444+
445+ cdef DPCTLKernelNameListRef KernelNames = DPCTLKernelNameList_Create()
446+ for name in registered_names:
447+ if not isinstance (name, unicode ):
448+ DPCTLBuildOptionList_Delete(BuildOpts)
449+ DPCTLKernelNameList_Delete(KernelNames)
450+ raise SyclKernelBundleCompilationError()
451+ bName = name.encode(" utf8" )
452+ sName = < const char * > bName
453+ DPCTLKernelNameList_Append(KernelNames, sName)
454+
455+ cdef DPCTLVirtualHeaderListRef VirtualHeaders
456+ VirtualHeaders = DPCTLVirtualHeaderList_Create()
457+
458+ for name, content in headers:
459+ if not isinstance (name, unicode ) or not isinstance (content, unicode ):
460+ DPCTLBuildOptionList_Delete(BuildOpts)
461+ DPCTLKernelNameList_Delete(KernelNames)
462+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
463+ raise SyclKernelBundleCompilationError()
464+ bName = name.encode(" utf8" )
465+ sName = < const char * > bName
466+ bContent = content.encode(" utf8" )
467+ sContent = < const char * > bContent
468+ DPCTLVirtualHeaderList_Append(VirtualHeaders, sName, sContent)
469+
470+ cdef DPCTLKernelBuildLogRef BuildLog
471+ BuildLog = DPCTLKernelBuildLog_Create()
472+
473+ KBref = DPCTLKernelBundle_CreateFromSYCLSource(CRef, DRef, Src,
474+ VirtualHeaders, KernelNames,
475+ BuildOpts, BuildLog)
476+
477+ if KBref is NULL :
478+ buildLogContent = DPCTLKernelBuildLog_Get(BuildLog)
479+ buildLogStr = str (buildLogContent, " utf-8" )
480+ DPCTLBuildOptionList_Delete(BuildOpts)
481+ DPCTLKernelNameList_Delete(KernelNames)
482+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
483+ DPCTLKernelBuildLog_Delete(BuildLog)
484+ raise SyclKernelBundleCompilationError(buildLogStr)
485+
486+ DPCTLBuildOptionList_Delete(BuildOpts)
487+ DPCTLKernelNameList_Delete(KernelNames)
488+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
489+ DPCTLKernelBuildLog_Delete(BuildLog)
490+
491+ return SyclKernelBundle._create(KBref, True )
346492
347493
348494cpdef create_program_from_source(SyclQueue q, str src, str copts = " " ):
0 commit comments