200200
201201if get_cl_header_version () >= (2 , 0 ):
202202 from pyopencl ._cl import ( # noqa
203- SVMAllocation ,
203+ SVMPointer ,
204204 SVM ,
205-
206- # FIXME
207- #enqueue_svm_migratemem,
205+ SVMAllocation ,
208206 )
209207
210208if _cl .have_gl ():
@@ -1144,25 +1142,19 @@ def memory_map_exit(self, exc_type, exc_val, exc_tb):
11441142 """
11451143
11461144 if get_cl_header_version () >= (2 , 0 ):
1147- svmallocation_old_init = SVMAllocation .__init__
1148-
1149- def svmallocation_init (self , ctx , size , alignment , flags , _interface = None ,
1150- queue = None ):
1151- """
1152- :arg ctx: a :class:`Context`
1153- :arg flags: some of :class:`svm_mem_flags`.
1154- """
1155- svmallocation_old_init (self , ctx , size , alignment , flags , queue )
1156-
1157- # mem_flags.READ_ONLY applies to kernels, not the host
1158- read_write = True
1159- _interface ["data" ] = (
1160- int (self ._ptr_as_int ()), not read_write )
1161-
1162- self .__array_interface__ = _interface
1163-
1164- if get_cl_header_version () >= (2 , 0 ):
1165- SVMAllocation .__init__ = svmallocation_init
1145+ class _ArrayInterfaceSVMAllocation (SVMAllocation ):
1146+ def __init__ (self , ctx , size , alignment , flags , _interface = None ,
1147+ queue = None ):
1148+ """
1149+ :arg ctx: a :class:`Context`
1150+ :arg flags: some of :class:`svm_mem_flags`.
1151+ """
1152+ super ().__init__ (ctx , size , alignment , flags , queue )
1153+
1154+ # mem_flags.READ_ONLY applies to kernels, not the host
1155+ read_write = True
1156+ _interface ["data" ] = (
1157+ int (self ._ptr_as_int ()), not read_write )
11661158
11671159 # }}}
11681160
@@ -1773,15 +1765,14 @@ def enqueue_copy(queue, dest, src, **kwargs):
17731765 else :
17741766 raise ValueError ("invalid dest mem object type" )
17751767
1776- elif get_cl_header_version () >= (2 , 0 ) and isinstance (dest , SVM ):
1768+ elif get_cl_header_version () >= (2 , 0 ) and isinstance (dest , SVMPointer ):
17771769 # to SVM
1778- if not isinstance (src , SVM ):
1770+ if not isinstance (src , SVMPointer ):
17791771 src = SVM (src )
17801772
17811773 is_blocking = kwargs .pop ("is_blocking" , True )
17821774 assert kwargs .pop ("src_offset" , 0 ) == 0
17831775 assert kwargs .pop ("dest_offset" , 0 ) == 0
1784- assert "byte_count" not in kwargs or kwargs .pop ("byte_count" ) == src ._size ()
17851776 return _cl ._enqueue_svm_memcpy (queue , is_blocking , dest , src , ** kwargs )
17861777
17871778 else :
@@ -1807,7 +1798,7 @@ def enqueue_copy(queue, dest, src, **kwargs):
18071798 queue , src , origin , region , dest , ** kwargs )
18081799 else :
18091800 raise ValueError ("invalid src mem object type" )
1810- elif isinstance (src , SVM ):
1801+ elif isinstance (src , SVMPointer ):
18111802 # from svm
18121803 # dest is not a SVM instance, otherwise we'd be in the branch above
18131804 is_blocking = kwargs .pop ("is_blocking" , True )
@@ -1822,14 +1813,14 @@ def enqueue_copy(queue, dest, src, **kwargs):
18221813
18231814# {{{ enqueue_fill
18241815
1825- def enqueue_fill (queue : CommandQueue , dest : Union [MemoryObjectHolder , SVM ],
1816+ def enqueue_fill (queue : CommandQueue , dest : Union [MemoryObjectHolder , SVMPointer ],
18261817 pattern : Any , size : int , * , offset : int = 0 , wait_for = None ) -> Event :
18271818 """
18281819 .. versionadded:: 2022.2
18291820 """
18301821 if isinstance (dest , MemoryObjectHolder ):
18311822 return enqueue_fill_buffer (queue , dest , pattern , offset , size , wait_for )
1832- elif isinstance (dest , SVM ):
1823+ elif isinstance (dest , SVMPointer ):
18331824 if offset :
18341825 raise NotImplementedError ("enqueue_fill with SVM does not yet support "
18351826 "offsets" )
@@ -1961,7 +1952,7 @@ def enqueue_fill_buffer(queue, mem, pattern, offset, size, wait_for=None):
19611952def enqueue_svm_memfill (queue , dest , pattern , byte_count = None , wait_for = None ):
19621953 """Fill shared virtual memory with a pattern.
19631954
1964- :arg dest: a Python buffer object, optionally wrapped in an :class:`SVM` object
1955+ :arg dest: a Python buffer object, or any implementation of :class:`SVMPointer`.
19651956 :arg pattern: a Python buffer object (e.g. a :class:`numpy.ndarray` with the
19661957 fill pattern to be used.
19671958 :arg byte_count: The size of the memory to be fill. Defaults to the
@@ -1972,8 +1963,8 @@ def enqueue_svm_memfill(queue, dest, pattern, byte_count=None, wait_for=None):
19721963 .. versionadded:: 2016.2
19731964 """
19741965
1975- if not isinstance (dest , SVM ):
1976- dest = SVM (dest )
1966+ if not isinstance (dest , SVMPointer ):
1967+ dest = SVMPointer (dest )
19771968
19781969 return _cl ._enqueue_svm_memfill (
19791970 queue , dest , pattern , byte_count = None , wait_for = None )
@@ -1982,7 +1973,7 @@ def enqueue_svm_memfill(queue, dest, pattern, byte_count=None, wait_for=None):
19821973def enqueue_svm_migratemem (queue , svms , flags , wait_for = None ):
19831974 """
19841975 :arg svms: a collection of Python buffer objects (e.g. :mod:`numpy`
1985- arrays), optionally wrapped in :class:`SVM` objects .
1976+ arrays), or any implementation of :class:`SVMPointer` .
19861977 :arg flags: a combination of :class:`mem_migration_flags`
19871978
19881979 |std-enqueue-blurb|
@@ -2068,7 +2059,8 @@ def svm_empty(ctx, flags, shape, dtype, order="C", alignment=None, queue=None):
20682059 if alignment is None :
20692060 alignment = itemsize
20702061
2071- svm_alloc = SVMAllocation (ctx , nbytes , alignment , flags , _interface = interface ,
2062+ svm_alloc = _ArrayInterfaceSVMAllocation (
2063+ ctx , nbytes , alignment , flags , _interface = interface ,
20722064 queue = queue )
20732065 return np .asarray (svm_alloc )
20742066
0 commit comments