4646
4747
4848def _as_usm_ndarray (a , usm_type , sycl_queue ):
49+ """Converts input object to `dpctl.tensor.usm_ndarray`"""
50+
4951 if isinstance (a , dpnp_array ):
50- return a .get_array ()
52+ a = a .get_array ()
5153 return dpt .asarray (a , usm_type = usm_type , sycl_queue = sycl_queue )
5254
5355
56+ def _check_has_zero_val (a ):
57+ """Check if any element in input object is equal to zero"""
58+
59+ if dpnp .isscalar (a ):
60+ if a == 0 :
61+ return True
62+ elif hasattr (a , "any" ):
63+ if (a == 0 ).any ():
64+ return True
65+ elif any (val == 0 for val in a ):
66+ return True
67+ return False
68+
69+
70+ def _get_usm_allocations (objs , device = None , usm_type = None , sycl_queue = None ):
71+ """
72+ Get common USM allocations based on a list of input objects and an explicit
73+ device, a SYCL queue, or a USM type if specified.
74+
75+ """
76+
77+ alloc_usm_type , alloc_sycl_queue = get_usm_allocations (objs )
78+
79+ if sycl_queue is None and device is None :
80+ sycl_queue = alloc_sycl_queue
81+
82+ if usm_type is None :
83+ usm_type = alloc_usm_type or "device"
84+ return usm_type , dpnp .get_normalized_queue_device (
85+ sycl_queue = sycl_queue , device = device
86+ )
87+
88+
5489def dpnp_geomspace (
5590 start ,
5691 stop ,
@@ -62,76 +97,57 @@ def dpnp_geomspace(
6297 endpoint = True ,
6398 axis = 0 ,
6499):
65- usm_type_alloc , sycl_queue_alloc = get_usm_allocations ([start , stop ])
66-
67- if sycl_queue is None and device is None :
68- sycl_queue = sycl_queue_alloc
69- sycl_queue_normalized = dpnp .get_normalized_queue_device (
70- sycl_queue = sycl_queue , device = device
100+ usm_type , sycl_queue = _get_usm_allocations (
101+ [start , stop ], device = device , usm_type = usm_type , sycl_queue = sycl_queue
71102 )
72103
73- if usm_type is None :
74- _usm_type = "device" if usm_type_alloc is None else usm_type_alloc
75- else :
76- _usm_type = usm_type
104+ if _check_has_zero_val (start ) or _check_has_zero_val (stop ):
105+ raise ValueError ("Geometric sequence cannot include zero" )
77106
78- start = _as_usm_ndarray (start , _usm_type , sycl_queue_normalized )
79- stop = _as_usm_ndarray (stop , _usm_type , sycl_queue_normalized )
107+ start = dpnp . array (start , usm_type = usm_type , sycl_queue = sycl_queue )
108+ stop = dpnp . array (stop , usm_type = usm_type , sycl_queue = sycl_queue )
80109
81110 dt = numpy .result_type (start , stop , float (num ))
82- dt = map_dtype_to_device (dt , sycl_queue_normalized .sycl_device )
111+ dt = map_dtype_to_device (dt , sycl_queue .sycl_device )
83112 if dtype is None :
84113 dtype = dt
85114
86- if dpnp .any (start == 0 ) or dpnp .any (stop == 0 ):
87- raise ValueError ("Geometric sequence cannot include zero" )
115+ # promote both arguments to the same dtype
116+ start = start .astype (dt , copy = False )
117+ stop = stop .astype (dt , copy = False )
88118
89- out_sign = dpt .ones (
90- dpt .broadcast_arrays (start , stop )[0 ].shape ,
91- dtype = dt ,
92- usm_type = _usm_type ,
93- sycl_queue = sycl_queue_normalized ,
94- )
95- # Avoid negligible real or imaginary parts in output by rotating to
96- # positive real, calculating, then undoing rotation
97- if dpnp .issubdtype (dt , dpnp .complexfloating ):
98- all_imag = (start .real == 0.0 ) & (stop .real == 0.0 )
99- if dpnp .any (all_imag ):
100- start [all_imag ] = start [all_imag ].imag
101- stop [all_imag ] = stop [all_imag ].imag
102- out_sign [all_imag ] = 1j
103-
104- both_negative = (dpt .sign (start ) == - 1 ) & (dpt .sign (stop ) == - 1 )
105- if dpnp .any (both_negative ):
106- dpt .negative (start [both_negative ], out = start [both_negative ])
107- dpt .negative (stop [both_negative ], out = stop [both_negative ])
108- dpt .negative (out_sign [both_negative ], out = out_sign [both_negative ])
109-
110- log_start = dpt .log10 (start )
111- log_stop = dpt .log10 (stop )
119+ # Allow negative real values and ensure a consistent result for complex
120+ # (including avoiding negligible real or imaginary parts in output) by
121+ # rotating start to positive real, calculating, then undoing rotation.
122+ out_sign = dpnp .sign (start )
123+ start = start / out_sign
124+ stop = stop / out_sign
125+
126+ log_start = dpnp .log10 (start )
127+ log_stop = dpnp .log10 (stop )
112128 res = dpnp_logspace (
113129 log_start ,
114130 log_stop ,
115131 num = num ,
116132 endpoint = endpoint ,
117133 base = 10.0 ,
118- dtype = dtype ,
119- usm_type = _usm_type ,
120- sycl_queue = sycl_queue_normalized ,
121- ). get_array ()
134+ dtype = dt ,
135+ usm_type = usm_type ,
136+ sycl_queue = sycl_queue ,
137+ )
122138
139+ # Make sure the endpoints match the start and stop arguments. This is
140+ # necessary because np.exp(np.log(x)) is not necessarily equal to x.
123141 if num > 0 :
124142 res [0 ] = start
125143 if num > 1 and endpoint :
126144 res [- 1 ] = stop
127145
128- res = out_sign * res
146+ res * = out_sign
129147
130148 if axis != 0 :
131- res = dpt .moveaxis (res , 0 , axis )
132-
133- res = dpt .astype (res , dtype , copy = False )
134- return dpnp_array ._create_from_usm_ndarray (res )
149+ res = dpnp .moveaxis (res , 0 , axis )
150+ return res .astype (dtype , copy = False )
135151
136152
137153def dpnp_linspace (
@@ -264,45 +280,36 @@ def dpnp_logspace(
264280 dtype = None ,
265281 axis = 0 ,
266282):
267- if not dpnp .isscalar (base ):
268- usm_type_alloc , sycl_queue_alloc = get_usm_allocations (
269- [start , stop , base ]
270- )
271-
272- if sycl_queue is None and device is None :
273- sycl_queue = sycl_queue_alloc
274- sycl_queue = dpnp .get_normalized_queue_device (
275- sycl_queue = sycl_queue , device = device
276- )
277-
278- if usm_type is None :
279- usm_type = "device" if usm_type_alloc is None else usm_type_alloc
280- else :
281- usm_type = usm_type
283+ usm_type , sycl_queue = _get_usm_allocations (
284+ [start , stop , base ],
285+ device = device ,
286+ usm_type = usm_type ,
287+ sycl_queue = sycl_queue ,
288+ )
282289
283- start = _as_usm_ndarray (start , usm_type , sycl_queue )
284- stop = _as_usm_ndarray (stop , usm_type , sycl_queue )
285- base = _as_usm_ndarray (base , usm_type , sycl_queue )
290+ if not dpnp .isscalar (base ):
291+ base = dpnp .array (base , usm_type = usm_type , sycl_queue = sycl_queue )
292+ start = dpnp .array (start , usm_type = usm_type , sycl_queue = sycl_queue )
293+ stop = dpnp .array (stop , usm_type = usm_type , sycl_queue = sycl_queue )
286294
287- [ start , stop , base ] = dpt .broadcast_arrays (start , stop , base )
288- base = dpt .expand_dims (base , axis = axis )
295+ start , stop , base = dpnp .broadcast_arrays (start , stop , base )
296+ base = dpnp .expand_dims (base , axis = axis )
289297
290- # assume res as not a tuple, because retstep is False
298+ # assume ` res` as not a tuple, because retstep is False
291299 res = dpnp_linspace (
292300 start ,
293301 stop ,
294302 num = num ,
295- device = device ,
296303 usm_type = usm_type ,
297304 sycl_queue = sycl_queue ,
298305 endpoint = endpoint ,
299306 axis = axis ,
300- ). get_array ()
307+ )
301308
302- dpt .pow (base , res , out = res )
309+ dpnp .pow (base , res , out = res )
303310 if dtype is not None :
304- res = dpt .astype (res , dtype , copy = False )
305- return dpnp_array . _create_from_usm_ndarray ( res )
311+ res = res .astype (dtype , copy = False )
312+ return res
306313
307314
308315class dpnp_nd_grid :
0 commit comments