Skip to content

Commit 7230913

Browse files
committed
refactor: refactor main.js logic and edge cases handling
1 parent 3735e43 commit 7230913

8 files changed

Lines changed: 92 additions & 12 deletions

File tree

lib/node_modules/@stdlib/ndarray/base/where/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ Each provided ndarray should be an object with the following properties:
129129

130130
## Notes
131131

132+
- `condition` ndarray must be a `boolean` or `uint8` ndarray.
133+
- `condition`, `x`, `y`, and `out` ndarrays must have the same shape.
134+
- `x` and `y` must have the same data type.
135+
- The function **mutates** the input ndarrays shapes and strides if necessary.
132136
- For very high-dimensional ndarrays which are non-contiguous, one should consider copying the underlying data to contiguous memory before conditionally assigning elements in order to achieve better performance.
133137

134138
</section>

lib/node_modules/@stdlib/ndarray/base/where/docs/repl.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949

5050
// Using minimal ndarray-like objects...
5151
> condition = {
52-
... 'dtype': dtype,
52+
... 'dtype': 'uint8',
5353
... 'data': cbuf,
5454
... 'shape': shape,
5555
... 'strides': sc,

lib/node_modules/@stdlib/ndarray/base/where/docs/types/index.d.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ import { ndarray } from '@stdlib/types/ndarray';
2929
* @param arrays - array-like object containing three input ndarrays and one output ndarray
3030
* @throws arrays must have the same number of dimensions
3131
* @throws arrays must have the same shape
32+
* @throws {Error} condition array must be a boolean or uint8 ndarray
33+
* @throws {Error} x and y ndarrays must have the same dtype
3234
*
3335
* @example
3436
* var Float64Array = require( '@stdlib/array/float64' );
@@ -57,7 +59,7 @@ import { ndarray } from '@stdlib/types/ndarray';
5759
* var oo = 0;
5860
*
5961
* // Create the input and output ndarrays:
60-
* var condition = ndarray( 'float64', cbuf, shape, sc, oc, 'row-major' );
62+
* var condition = ndarray( 'uint8', cbuf, shape, sc, oc, 'row-major' );
6163
* var x = ndarray( 'float64', xbuf, shape, sx, ox, 'row-major' );
6264
* var y = ndarray( 'float64', ybuf, shape, sy, oy, 'row-major' );
6365
* var out = ndarray( 'float64', obuf, shape, so, oo, 'row-major' );

lib/node_modules/@stdlib/ndarray/base/where/docs/types/test.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ import where = require( './index' );
2424

2525
// The function returns `undefined`...
2626
{
27-
const condition = zeros( [ 2, 2 ] );
27+
const condition = zeros( [ 2, 2 ], {
28+
'dtype': 'uint8'
29+
});
2830
const x = zeros( [ 2, 2 ] );
2931
const y = zeros( [ 2, 2 ] );
3032
const out = zeros( [ 2, 2 ] );

lib/node_modules/@stdlib/ndarray/base/where/lib/0d_accessors.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
*
9393
* // Create the ndarray-like objects:
9494
* var condition = {
95-
* 'dtype': 'uint8',
95+
* 'dtype': 'bool',
9696
* 'data': cbuf,
9797
* 'shape': shape,
9898
* 'strides': sc,

lib/node_modules/@stdlib/ndarray/base/where/lib/main.js

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ function complex2real( x ) {
223223
* @param {ArrayLikeObject<Object>} arrays - array-like object containing one condition array, two input arrays, and one output array
224224
* @throws {Error} arrays must have the same number of dimensions
225225
* @throws {Error} arrays must have the same shape
226+
* @throws {Error} condition array must be a boolean or uint8 ndarray
227+
* @throws {Error} x and y ndarrays must have the same dtype
226228
* @returns {void}
227229
*
228230
* @example
@@ -328,6 +330,14 @@ function where( arrays ) {
328330
y = ndarray2object( arrays[ 2 ] );
329331
o = ndarray2object( arrays[ 3 ] );
330332

333+
if ( !isBooleanArray( c.data ) && c.dtype !== 'bool' && c.dtype !== 'uint8') {
334+
throw new Error( format( 'invalid arguments. Condition array must be a boolean or uint8 ndarray. c.dtype=%s', c.dtype ) );
335+
}
336+
337+
if ( x.dtype !== y.dtype ) {
338+
throw new Error( format( 'invalid arguments. Input arrays must have the same data type. x.dtype=%s, y.dtype=%s', x.dtype, y.dtype ) );
339+
}
340+
331341
// Always reinterpret condition array to uint8
332342
if ( isBooleanArray( c.data ) ) {
333343
c = boolean2uint8( c );
@@ -342,6 +352,7 @@ function where( arrays ) {
342352
x = complex2real( x );
343353
y = complex2real( y );
344354
o = complex2real( o );
355+
345356
c.shape.push( 2 ); // real and imaginary components
346357
c.strides.push( 0 ); // broadcast
347358
}
@@ -363,7 +374,7 @@ function where( arrays ) {
363374
}
364375
// Determine whether we can avoid iteration altogether...
365376
if ( ndims === 0 ) {
366-
if ( hasAccessors( x, y, o ) ) {
377+
if ( hasAccessors( c, x, y, o ) ) {
367378
return ACCESSOR_WHERE[ ndims ]( c, x, y, o );
368379
}
369380
return WHERE[ ndims ]( c, x, y, o );
@@ -390,7 +401,7 @@ function where( arrays ) {
390401
}
391402
// Determine whether the ndarrays are one-dimensional and thus readily translate to one-dimensional strided arrays...
392403
if ( ndims === 1 ) {
393-
if ( hasAccessors( x, y, o ) ) {
404+
if ( hasAccessors( c, x, y, o ) ) {
394405
return ACCESSOR_WHERE[ ndims ]( c, x, y, o );
395406
}
396407
return WHERE[ ndims ]( c, x, y, o );
@@ -416,7 +427,7 @@ function where( arrays ) {
416427
x.strides = [ sx[i] ];
417428
y.strides = [ sy[i] ];
418429
o.strides = [ so[i] ];
419-
if ( hasAccessors( x, y, o ) ) {
430+
if ( hasAccessors( c, x, y, o ) ) {
420431
return ACCESSOR_WHERE[ 1 ]( c, x, y, o );
421432
}
422433
return WHERE[ 1 ]( c, x, y, o );
@@ -470,7 +481,7 @@ function where( arrays ) {
470481
x.offset = ox;
471482
y.offset = oy;
472483
o.offset = oo;
473-
if ( hasAccessors( x, y, o ) ) {
484+
if ( hasAccessors( c, x, y, o ) ) {
474485
return ACCESSOR_WHERE[ 1 ]( c, x, y, o );
475486
}
476487
return WHERE[ 1 ]( c, x, y, o );
@@ -480,7 +491,7 @@ function where( arrays ) {
480491
// Determine whether we can use simple nested loops...
481492
if ( ndims <= MAX_DIMS ) {
482493
// So long as iteration for each respective array always moves in the same direction (i.e., no mixed sign strides), we can leverage cache-optimal (i.e., normal) nested loops without resorting to blocked iteration...
483-
if ( hasAccessors( x, y, o ) ) {
494+
if ( hasAccessors( c, x, y, o ) ) {
484495
return ACCESSOR_WHERE[ ndims ]( c, x, y, o, ord === 1 );
485496
}
486497
return WHERE[ ndims ]( c, x, y, o, ord === 1 );
@@ -491,13 +502,13 @@ function where( arrays ) {
491502

492503
// Determine whether we can perform blocked iteration...
493504
if ( ndims <= MAX_DIMS ) {
494-
if ( hasAccessors( x, y, o ) ) {
505+
if ( hasAccessors( c, x, y, o ) ) {
495506
return BLOCKED_ACCESSOR_WHERE[ ndims-2 ]( c, x, y, o );
496507
}
497508
return BLOCKED_WHERE[ ndims-2 ]( c, x, y, o );
498509
}
499510
// Fall-through to linear view iteration without regard for how data is stored in memory (i.e., take the slow path)...
500-
if ( hasAccessors( x, y, o ) ) {
511+
if ( hasAccessors( c, x, y, o ) ) {
501512
return accessorwherend( c, x, y, o );
502513
}
503514
wherend( c, x, y, o );

lib/node_modules/@stdlib/ndarray/base/where/test/test.1d.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ tape( 'the function conditionally assigns elements from a 1-dimensional input nd
191191
-3.0,
192192
-4.0
193193
];
194-
condition = ndarray( 'generic', toAccessorArray( cbuf ), [ 4 ], [ 1 ], 0, 'row-major' );
194+
condition = ndarray( 'uint8', toAccessorArray( cbuf ), [ 4 ], [ 1 ], 0, 'row-major' );
195195
x = ndarray( 'generic', toAccessorArray( xbuf ), [ 4 ], [ 1 ], 0, 'row-major' );
196196
y = ndarray( 'generic', toAccessorArray( ybuf ), [ 4 ], [ 1 ], 0, 'row-major' );
197197
out = ndarray( 'generic', toAccessorArray( zeros( 4, 'generic' ) ), [ 4 ], [ 1 ], 0, 'row-major' );

lib/node_modules/@stdlib/ndarray/base/where/test/test.js

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,37 @@ tape( 'main export is a function', function test( t ) {
3737
t.end();
3838
});
3939

40+
tape( 'the function throws an error if provided input ndarrays which doesnot have the same data types', function test( t ) {
41+
var condition;
42+
var values;
43+
var out;
44+
var x;
45+
var y;
46+
var i;
47+
48+
condition = ndarray( 'uint8', [ 1, 0, 0, 1 ], [ 4 ], [ 1 ], 0, 'row-major' );
49+
x = ndarray( 'float32', [ 1.0, 2.0, 3.0, 4.0 ], [ 4 ], [ 1 ], 0, 'row-major' );
50+
out = ndarray( 'generic', zeros( 8, 'generic' ), [ 4 ], [ 1 ], 0, 'row-major' );
51+
52+
values = [
53+
'float64',
54+
'complex64',
55+
'uint8',
56+
'bool'
57+
];
58+
for ( i = 0; i < values.length; i++ ) {
59+
t.throws( badValue( values[ i ] ), Error, 'throws an error when provided ' + values[ i ] );
60+
}
61+
t.end();
62+
63+
function badValue( value ) {
64+
return function badValue() {
65+
y = ndarray( value, [ -1.0, -2.0, -3.0, -4.0 ], [ 4 ], [ 1 ], 0, 'row-major' );
66+
where( [ condition, x, y, out ] );
67+
};
68+
}
69+
});
70+
4071
tape( 'the function throws an error if provided ndarrays which doesnot have the same shape', function test( t ) {
4172
var condition;
4273
var values;
@@ -68,6 +99,36 @@ tape( 'the function throws an error if provided ndarrays which doesnot have the
6899
}
69100
});
70101

102+
tape( 'the function throws an error if provided a non-boolean or uint8 condition ndarray', function test( t ) {
103+
var condition;
104+
var values;
105+
var out;
106+
var x;
107+
var y;
108+
var i;
109+
110+
x = ndarray( 'float32', [ 1.0, 2.0, 3.0, 4.0 ], [ 4 ], [ 1 ], 0, 'row-major' );
111+
y = ndarray( 'float32', [ -1.0, -2.0, -3.0, -4.0 ], [ 4 ], [ 1 ], 0, 'row-major' );
112+
out = ndarray( 'float32', zeros( 4, 'float32' ), [ 4 ], [ 1 ], 0, 'row-major' );
113+
114+
values = [
115+
'float64',
116+
'complex64',
117+
'binary'
118+
];
119+
for ( i = 0; i < values.length; i++ ) {
120+
t.throws( badValue( values[ i ] ), Error, 'throws an error when provided ' + values[ i ] );
121+
}
122+
t.end();
123+
124+
function badValue( value ) {
125+
return function badValue() {
126+
condition = ndarray( value, [ 1, 0, 0, 1 ], [ 4 ], [ 1 ], 0, 'row-major' );
127+
where( [ condition, x, y, out ] );
128+
};
129+
}
130+
});
131+
71132
tape( 'the function conditionally assigns elements from an input ndarray to an output ndarray (boolean)', function test( t ) {
72133
var condition;
73134
var expected;

0 commit comments

Comments
 (0)