@@ -35,6 +35,7 @@ var isIntegerIndexDataType = require( './../../../base/assert/is-integer-index-d
3535var isBooleanIndexDataType = require ( './../../../base/assert/is-boolean-index-data-type' ) ;
3636var isMaskIndexDataType = require ( './../../../base/assert/is-mask-index-data-type' ) ;
3737var isDataType = require ( './../../../base/assert/is-data-type' ) ;
38+ var isString = require ( '@stdlib/assert/is-string' ) . isPrimitive ;
3839var promoteDataTypes = require ( './../../../base/promote-dtypes' ) ;
3940var defaults = require ( './../../../defaults' ) ;
4041var join = require ( '@stdlib/array/base/join' ) ;
@@ -48,6 +49,7 @@ var DEFAULT_INDEX_DTYPE = defaults.get( 'dtypes.default_index' );
4849var DEFAULT_SIGNED_INTEGER_DTYPE = defaults . get ( 'dtypes.signed_integer' ) ;
4950var DEFAULT_UNSIGNED_INTEGER_DTYPE = defaults . get ( 'dtypes.unsigned_integer' ) ;
5051var DEFAULT_REAL_FLOATING_POINT_DTYPE = defaults . get ( 'dtypes.real_floating_point' ) ;
52+ var DEFAULT_COMPLEX_FLOATING_POINT_DTYPE = defaults . get ( 'dtypes.complex_floating_point' ) ;
5153
5254// Table where, for each respective policy, the value is a function which applies the policy to an input data type...
5355var POLICY_TABLE1 = {
@@ -71,19 +73,19 @@ var POLICY_TABLE2 = {
7173 ] ,
7274 'real_floating_point' : [
7375 isRealFloatingPointDataType ,
74- DEFAULT_REAL_FLOATING_POINT_DTYPE
76+ resolveDefaultRealFloatingPoint
7577 ] ,
7678 'real_floating_point_and_generic' : [
7779 wrap ( isRealFloatingPointDataType ) ,
78- DEFAULT_REAL_FLOATING_POINT_DTYPE
80+ resolveDefaultRealFloatingPoint
7981 ] ,
8082 'complex_floating_point' : [
8183 isComplexFloatingPointDataType ,
82- defaults . get ( 'dtypes.complex_floating_point' )
84+ resolveDefaultComplexFloatingPoint
8385 ] ,
8486 'complex_floating_point_and_generic' : [
8587 wrap ( isComplexFloatingPointDataType ) ,
86- defaults . get ( 'dtypes.complex_floating_point' )
88+ resolveDefaultComplexFloatingPoint
8789 ] ,
8890
8991 // Integer policies...
@@ -169,6 +171,20 @@ var POLICY_TABLE2 = {
169171 ]
170172} ;
171173
174+ // Table mapping complex-valued floating-point data types to real-valued floating-point data types having the same precision:
175+ var COMPLEX2FLOAT = {
176+ 'complex128' : 'float64' ,
177+ 'complex64' : 'float32' ,
178+ 'complex32' : 'float16'
179+ } ;
180+
181+ // Table mapping real-valued floating-point data types to complex-valued floating-point data types having the same precision:
182+ var FLOAT2COMPLEX = {
183+ 'float64' : 'complex128' ,
184+ 'float32' : 'complex64' ,
185+ 'float16' : 'complex32'
186+ } ;
187+
172188
173189// FUNCTIONS //
174190
@@ -280,6 +296,28 @@ function accumulationPolicy( dtypes ) {
280296 return DEFAULT_REAL_FLOATING_POINT_DTYPE ;
281297}
282298
299+ /**
300+ * Resolves a default real-valued floating-point data type which preserves floating-point precision.
301+ *
302+ * @private
303+ * @param {string } dtype - input ndarray data type
304+ * @returns {string } output ndarray data type
305+ */
306+ function resolveDefaultRealFloatingPoint ( dtype ) {
307+ return COMPLEX2FLOAT [ dtype ] || DEFAULT_REAL_FLOATING_POINT_DTYPE ;
308+ }
309+
310+ /**
311+ * Resolves a default complex-valued floating-point data type which preserves floating-point precision.
312+ *
313+ * @private
314+ * @param {string } dtype - input ndarray data type
315+ * @returns {string } output ndarray data type
316+ */
317+ function resolveDefaultComplexFloatingPoint ( dtype ) {
318+ return FLOAT2COMPLEX [ dtype ] || DEFAULT_COMPLEX_FLOATING_POINT_DTYPE ;
319+ }
320+
283321
284322// MAIN //
285323
@@ -320,8 +358,11 @@ function resolve( dtypes, policy ) {
320358 // If so, we can just return the promoted data type:
321359 return dt ;
322360 }
323- // Otherwise, we need to fallback to a default data type belonging to that "kind":
324- return p [ 1 ] ;
361+ // Otherwise, we need to fallback to a default data type belonging to that "kind"...
362+ if ( isString ( p [ 1 ] ) ) {
363+ return p [ 1 ] ;
364+ }
365+ return p [ 1 ] ( dt ) ;
325366 }
326367 throw new TypeError ( format ( 'invalid argument. Second argument must be a supported data type policy. Value: `%s`.' , policy ) ) ;
327368}
0 commit comments