Skip to content

Commit 8fb5a0b

Browse files
committed
test: add up to 3d tests and benchmarks
1 parent f522872 commit 8fb5a0b

15 files changed

Lines changed: 3994 additions & 41 deletions
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
/**
2+
* @license Apache-2.0
3+
*
4+
* Copyright (c) 2026 The Stdlib Authors.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
'use strict';
20+
21+
// MODULES //
22+
23+
var bench = require( '@stdlib/bench' );
24+
var discreteUniform = require( '@stdlib/random/base/discrete-uniform' ).factory;
25+
var bernoulli = require( '@stdlib/random/base/bernoulli' ).factory;
26+
var isnan = require( '@stdlib/math/base/assert/is-nan' );
27+
var pow = require( '@stdlib/math/base/special/pow' );
28+
var floor = require( '@stdlib/math/base/special/floor' );
29+
var filledarray = require( '@stdlib/array/filled' );
30+
var filledarrayBy = require( '@stdlib/array/filled-by' );
31+
var shape2strides = require( '@stdlib/ndarray/base/shape2strides' );
32+
var format = require( '@stdlib/string/format' );
33+
var pkg = require( './../package.json' ).name;
34+
var where = require( './../lib/nd.js' );
35+
36+
37+
// VARIABLES //
38+
39+
var conditionTypes = [ 'uint8' ];
40+
var types = [ 'float64' ];
41+
var order = 'column-major';
42+
43+
44+
// FUNCTIONS //
45+
46+
/**
47+
* Creates a benchmark function.
48+
*
49+
* @private
50+
* @param {PositiveInteger} len - ndarray length
51+
* @param {NonNegativeIntegerArray} shape - ndarray shape
52+
* @param {string} xtype - first input ndarray data type
53+
* @param {string} ytype - second input ndarray data type
54+
* @returns {Function} benchmark function
55+
*/
56+
function createBenchmark( len, shape, ctype, xtype, ytype, otype ) {
57+
var condition;
58+
var x;
59+
var y;
60+
var out;
61+
62+
condition = filledarrayBy( len, ctype, bernoulli( 0.5 ) );
63+
x = filledarrayBy( len, xtype, discreteUniform( -100, 100 ) );
64+
y = filledarrayBy( len, ytype, discreteUniform( -100, 100 ) );
65+
out = filledarray( 0.0, len, otype );
66+
67+
condition = {
68+
'dtype': ctype,
69+
'data': condition,
70+
'shape': shape,
71+
'strides': shape2strides( shape, order ),
72+
'offset': 0,
73+
'order': order
74+
};
75+
x = {
76+
'dtype': xtype,
77+
'data': x,
78+
'shape': shape,
79+
'strides': shape2strides( shape, order ),
80+
'offset': 0,
81+
'order': order
82+
};
83+
y = {
84+
'dtype': ytype,
85+
'data': y,
86+
'shape': shape,
87+
'strides': shape2strides( shape, order ),
88+
'offset': 0,
89+
'order': order
90+
};
91+
out = {
92+
'dtype': otype,
93+
'data': out,
94+
'shape': shape,
95+
'strides': shape2strides( shape, order ),
96+
'offset': 0,
97+
'order': order
98+
};
99+
return benchmark;
100+
101+
/**
102+
* Benchmark function.
103+
*
104+
* @private
105+
* @param {Benchmark} b - benchmark instance
106+
*/
107+
function benchmark( b ) {
108+
var i;
109+
110+
b.tic();
111+
for ( i = 0; i < b.iterations; i++ ) {
112+
where( condition, x, y, out );
113+
if ( isnan( out.data[ i%len ] ) ) {
114+
b.fail( 'should not return NaN' );
115+
}
116+
}
117+
b.toc();
118+
if ( isnan( out.data[ i%len ] ) ) {
119+
b.fail( 'should not return NaN' );
120+
}
121+
b.pass( 'benchmark finished' );
122+
b.end();
123+
}
124+
}
125+
126+
127+
// MAIN //
128+
129+
/**
130+
* Main execution sequence.
131+
*
132+
* @private
133+
*/
134+
function main() {
135+
var len;
136+
var min;
137+
var max;
138+
var sh;
139+
var t1;
140+
var t2;
141+
var t3;
142+
var tc;
143+
var f;
144+
var i;
145+
var j;
146+
var k;
147+
148+
min = 1; // 10^min
149+
max = 6; // 10^max
150+
151+
for ( k = 0; k < types.length; k++ ) {
152+
t1 = types[ k ];
153+
t2 = types[ k ];
154+
t3 = types[ k ];
155+
for ( j = 0; j < conditionTypes.length; j++ ) {
156+
tc = conditionTypes[ j ];
157+
for ( i = min; i <= max; i++ ) {
158+
len = pow( 10, i );
159+
160+
sh = [ len/2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1 ];
161+
f = createBenchmark( len, sh, tc, t1, t2, t3 );
162+
bench( format( '%s:ndims=%d,len=%d,shape=[%s],order=%s,ctype=%s,xtype=%s,ytype=%s,otype=%s', pkg, sh.length, len, sh.join( ',' ), order, tc, t1, t2, t3 ), f );
163+
164+
sh = [ 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, len/2 ];
165+
f = createBenchmark( len, sh, tc, t1, t2, t3 );
166+
bench( format( '%s:ndims=%d,len=%d,shape=[%s],order=%s,ctype=%s,xtype=%s,ytype=%s,otype=%s', pkg, sh.length, len, sh.join( ',' ), order, tc, t1, t2, t3 ), f );
167+
168+
len = floor( pow( len, 1.0/11.0 ) );
169+
sh = [ len, len, len, len, len, len, len, len, len, len, len ];
170+
len *= pow( len, 10 );
171+
f = createBenchmark( len, sh, tc, t1, t2, t3 );
172+
bench( format( '%s:ndims=%d,len=%d,shape=[%s],order=%s,ctype=%s,xtype=%s,ytype=%s,otype=%s', pkg, sh.length, len, sh.join( ',' ), order, tc, t1, t2, t3 ), f );
173+
}
174+
}
175+
}
176+
}
177+
178+
main();
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
/**
2+
* @license Apache-2.0
3+
*
4+
* Copyright (c) 2026 The Stdlib Authors.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
'use strict';
20+
21+
// MODULES //
22+
23+
var bench = require( '@stdlib/bench' );
24+
var discreteUniform = require( '@stdlib/random/base/discrete-uniform' ).factory;
25+
var bernoulli = require( '@stdlib/random/base/bernoulli' ).factory;
26+
var isnan = require( '@stdlib/math/base/assert/is-nan' );
27+
var pow = require( '@stdlib/math/base/special/pow' );
28+
var floor = require( '@stdlib/math/base/special/floor' );
29+
var filledarray = require( '@stdlib/array/filled' );
30+
var filledarrayBy = require( '@stdlib/array/filled-by' );
31+
var shape2strides = require( '@stdlib/ndarray/base/shape2strides' );
32+
var format = require( '@stdlib/string/format' );
33+
var pkg = require( './../package.json' ).name;
34+
var where = require( './../lib/nd.js' );
35+
36+
37+
// VARIABLES //
38+
39+
var conditionTypes = [ 'uint8' ];
40+
var types = [ 'float64' ];
41+
var order = 'row-major';
42+
43+
44+
// FUNCTIONS //
45+
46+
/**
47+
* Creates a benchmark function.
48+
*
49+
* @private
50+
* @param {PositiveInteger} len - ndarray length
51+
* @param {NonNegativeIntegerArray} shape - ndarray shape
52+
* @param {string} xtype - first input ndarray data type
53+
* @param {string} ytype - second input ndarray data type
54+
* @returns {Function} benchmark function
55+
*/
56+
function createBenchmark( len, shape, ctype, xtype, ytype, otype ) {
57+
var condition;
58+
var x;
59+
var y;
60+
var out;
61+
62+
condition = filledarrayBy( len, ctype, bernoulli( 0.5 ) );
63+
x = filledarrayBy( len, xtype, discreteUniform( -100, 100 ) );
64+
y = filledarrayBy( len, ytype, discreteUniform( -100, 100 ) );
65+
out = filledarray( 0.0, len, otype );
66+
67+
condition = {
68+
'dtype': ctype,
69+
'data': condition,
70+
'shape': shape,
71+
'strides': shape2strides( shape, order ),
72+
'offset': 0,
73+
'order': order
74+
};
75+
x = {
76+
'dtype': xtype,
77+
'data': x,
78+
'shape': shape,
79+
'strides': shape2strides( shape, order ),
80+
'offset': 0,
81+
'order': order
82+
};
83+
y = {
84+
'dtype': ytype,
85+
'data': y,
86+
'shape': shape,
87+
'strides': shape2strides( shape, order ),
88+
'offset': 0,
89+
'order': order
90+
};
91+
out = {
92+
'dtype': otype,
93+
'data': out,
94+
'shape': shape,
95+
'strides': shape2strides( shape, order ),
96+
'offset': 0,
97+
'order': order
98+
};
99+
return benchmark;
100+
101+
/**
102+
* Benchmark function.
103+
*
104+
* @private
105+
* @param {Benchmark} b - benchmark instance
106+
*/
107+
function benchmark( b ) {
108+
var i;
109+
110+
b.tic();
111+
for ( i = 0; i < b.iterations; i++ ) {
112+
where( condition, x, y, out );
113+
if ( isnan( out.data[ i%len ] ) ) {
114+
b.fail( 'should not return NaN' );
115+
}
116+
}
117+
b.toc();
118+
if ( isnan( out.data[ i%len ] ) ) {
119+
b.fail( 'should not return NaN' );
120+
}
121+
b.pass( 'benchmark finished' );
122+
b.end();
123+
}
124+
}
125+
126+
127+
// MAIN //
128+
129+
/**
130+
* Main execution sequence.
131+
*
132+
* @private
133+
*/
134+
function main() {
135+
var len;
136+
var min;
137+
var max;
138+
var sh;
139+
var t1;
140+
var t2;
141+
var t3;
142+
var tc;
143+
var f;
144+
var i;
145+
var j;
146+
var k;
147+
148+
min = 1; // 10^min
149+
max = 6; // 10^max
150+
151+
for ( k = 0; k < types.length; k++ ) {
152+
t1 = types[ k ];
153+
t2 = types[ k ];
154+
t3 = types[ k ];
155+
for ( j = 0; j < conditionTypes.length; j++ ) {
156+
tc = conditionTypes[ j ];
157+
for ( i = min; i <= max; i++ ) {
158+
len = pow( 10, i );
159+
160+
sh = [ len/2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1 ];
161+
f = createBenchmark( len, sh, tc, t1, t2, t3 );
162+
bench( format( '%s:ndims=%d,len=%d,shape=[%s],order=%s,ctype=%s,xtype=%s,ytype=%s,otype=%s', pkg, sh.length, len, sh.join( ',' ), order, tc, t1, t2, t3 ), f );
163+
164+
sh = [ 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, len/2 ];
165+
f = createBenchmark( len, sh, tc, t1, t2, t3 );
166+
bench( format( '%s:ndims=%d,len=%d,shape=[%s],order=%s,ctype=%s,xtype=%s,ytype=%s,otype=%s', pkg, sh.length, len, sh.join( ',' ), order, tc, t1, t2, t3 ), f );
167+
168+
len = floor( pow( len, 1.0/11.0 ) );
169+
sh = [ len, len, len, len, len, len, len, len, len, len, len ];
170+
len *= pow( len, 10 );
171+
f = createBenchmark( len, sh, tc, t1, t2, t3 );
172+
bench( format( '%s:ndims=%d,len=%d,shape=[%s],order=%s,ctype=%s,xtype=%s,ytype=%s,otype=%s', pkg, sh.length, len, sh.join( ',' ), order, tc, t1, t2, t3 ), f );
173+
}
174+
}
175+
}
176+
}
177+
178+
main();

0 commit comments

Comments
 (0)