200200
201201if get_cl_header_version () >= (2 , 0 ):
202202 from pyopencl ._cl import ( # noqa
203- SVMAllocation ,
203+ SVMPointer ,
204204 SVM ,
205-
206- # FIXME
207- #enqueue_svm_migratemem,
205+ SVMAllocation ,
208206 )
209207
210208if _cl .have_gl ():
@@ -1145,25 +1143,19 @@ def memory_map_exit(self, exc_type, exc_val, exc_tb):
11451143 """
11461144
11471145 if get_cl_header_version () >= (2 , 0 ):
1148- svmallocation_old_init = SVMAllocation .__init__
1149-
1150- def svmallocation_init (self , ctx , size , alignment , flags , _interface = None ,
1151- queue = None ):
1152- """
1153- :arg ctx: a :class:`Context`
1154- :arg flags: some of :class:`svm_mem_flags`.
1155- """
1156- svmallocation_old_init (self , ctx , size , alignment , flags , queue )
1157-
1158- # mem_flags.READ_ONLY applies to kernels, not the host
1159- read_write = True
1160- _interface ["data" ] = (
1161- int (self ._ptr_as_int ()), not read_write )
1162-
1163- self .__array_interface__ = _interface
1164-
1165- if get_cl_header_version () >= (2 , 0 ):
1166- SVMAllocation .__init__ = svmallocation_init
1146+ class _ArrayInterfaceSVMAllocation (SVMAllocation ):
1147+ def __init__ (self , ctx , size , alignment , flags , _interface = None ,
1148+ queue = None ):
1149+ """
1150+ :arg ctx: a :class:`Context`
1151+ :arg flags: some of :class:`svm_mem_flags`.
1152+ """
1153+ super ().__init__ (ctx , size , alignment , flags , queue )
1154+
1155+ # mem_flags.READ_ONLY applies to kernels, not the host
1156+ read_write = True
1157+ _interface ["data" ] = (
1158+ int (self ._ptr_as_int ()), not read_write )
11671159
11681160 # }}}
11691161
@@ -1774,15 +1766,14 @@ def enqueue_copy(queue, dest, src, **kwargs):
17741766 else :
17751767 raise ValueError ("invalid dest mem object type" )
17761768
1777- elif get_cl_header_version () >= (2 , 0 ) and isinstance (dest , SVM ):
1769+ elif get_cl_header_version () >= (2 , 0 ) and isinstance (dest , SVMPointer ):
17781770 # to SVM
1779- if not isinstance (src , SVM ):
1771+ if not isinstance (src , SVMPointer ):
17801772 src = SVM (src )
17811773
17821774 is_blocking = kwargs .pop ("is_blocking" , True )
17831775 assert kwargs .pop ("src_offset" , 0 ) == 0
17841776 assert kwargs .pop ("dest_offset" , 0 ) == 0
1785- assert "byte_count" not in kwargs or kwargs .pop ("byte_count" ) == src ._size ()
17861777 return _cl ._enqueue_svm_memcpy (queue , is_blocking , dest , src , ** kwargs )
17871778
17881779 else :
@@ -1808,7 +1799,7 @@ def enqueue_copy(queue, dest, src, **kwargs):
18081799 queue , src , origin , region , dest , ** kwargs )
18091800 else :
18101801 raise ValueError ("invalid src mem object type" )
1811- elif isinstance (src , SVM ):
1802+ elif isinstance (src , SVMPointer ):
18121803 # from svm
18131804 # dest is not a SVM instance, otherwise we'd be in the branch above
18141805 is_blocking = kwargs .pop ("is_blocking" , True )
@@ -1823,14 +1814,14 @@ def enqueue_copy(queue, dest, src, **kwargs):
18231814
18241815# {{{ enqueue_fill
18251816
1826- def enqueue_fill (queue : CommandQueue , dest : Union [MemoryObjectHolder , SVM ],
1817+ def enqueue_fill (queue : CommandQueue , dest : Union [MemoryObjectHolder , SVMPointer ],
18271818 pattern : Any , size : int , * , offset : int = 0 , wait_for = None ) -> Event :
18281819 """
18291820 .. versionadded:: 2022.2
18301821 """
18311822 if isinstance (dest , MemoryObjectHolder ):
18321823 return enqueue_fill_buffer (queue , dest , pattern , offset , size , wait_for )
1833- elif isinstance (dest , SVM ):
1824+ elif isinstance (dest , SVMPointer ):
18341825 if offset :
18351826 raise NotImplementedError ("enqueue_fill with SVM does not yet support "
18361827 "offsets" )
@@ -1962,7 +1953,7 @@ def enqueue_fill_buffer(queue, mem, pattern, offset, size, wait_for=None):
19621953def enqueue_svm_memfill (queue , dest , pattern , byte_count = None , wait_for = None ):
19631954 """Fill shared virtual memory with a pattern.
19641955
1965- :arg dest: a Python buffer object, optionally wrapped in an :class:`SVM` object
1956+ :arg dest: a Python buffer object, or any implementation of :class:`SVMPointer`.
19661957 :arg pattern: a Python buffer object (e.g. a :class:`numpy.ndarray` with the
19671958 fill pattern to be used.
19681959 :arg byte_count: The size of the memory to be fill. Defaults to the
@@ -1973,8 +1964,8 @@ def enqueue_svm_memfill(queue, dest, pattern, byte_count=None, wait_for=None):
19731964 .. versionadded:: 2016.2
19741965 """
19751966
1976- if not isinstance (dest , SVM ):
1977- dest = SVM (dest )
1967+ if not isinstance (dest , SVMPointer ):
1968+ dest = SVMPointer (dest )
19781969
19791970 return _cl ._enqueue_svm_memfill (
19801971 queue , dest , pattern , byte_count = None , wait_for = None )
@@ -1983,7 +1974,7 @@ def enqueue_svm_memfill(queue, dest, pattern, byte_count=None, wait_for=None):
19831974def enqueue_svm_migratemem (queue , svms , flags , wait_for = None ):
19841975 """
19851976 :arg svms: a collection of Python buffer objects (e.g. :mod:`numpy`
1986- arrays), optionally wrapped in :class:`SVM` objects .
1977+ arrays), or any implementation of :class:`SVMPointer` .
19871978 :arg flags: a combination of :class:`mem_migration_flags`
19881979
19891980 |std-enqueue-blurb|
@@ -2069,7 +2060,8 @@ def svm_empty(ctx, flags, shape, dtype, order="C", alignment=None, queue=None):
20692060 if alignment is None :
20702061 alignment = itemsize
20712062
2072- svm_alloc = SVMAllocation (ctx , nbytes , alignment , flags , _interface = interface ,
2063+ svm_alloc = _ArrayInterfaceSVMAllocation (
2064+ ctx , nbytes , alignment , flags , _interface = interface ,
20732065 queue = queue )
20742066 return np .asarray (svm_alloc )
20752067
0 commit comments