Skip to content

Commit db82604

Browse files
authored
[FIX] Fix return types for different numeric reductions (#368)
* Fix return types for different numeric reductions, add min/max/minmax tests
* Update copyright years
* Fix "Type instantiation is excessively deep and possibly infinite" error
* Fix GroupByMultiple type inference
1 parent 832d776 commit db82604

10 files changed

Lines changed: 509 additions & 166 deletions

File tree

modules/cudf/src/groupby/multiple.ts

Lines changed: 37 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ import {Column} from '../column';
1919
import {DataFrame} from '../data_frame';
2020
import {Series} from '../series';
2121
import {Table} from '../table';
22-
import {Struct} from '../types/dtypes';
22+
import {DataType, Int32, Struct} from '../types/dtypes';
2323
import {ColumnsMap, Interpolation, TypeMap} from '../types/mappings';
2424

2525
import {GroupByBase, GroupByBaseProps} from './base';
@@ -37,19 +37,16 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
3737
this.index_key = props.index_key;
3838
}
3939

40-
protected prepare_results(results: {keys: Table, cols: Column[]}) {
41-
const {keys, cols} = results;
40+
protected prepare_results<U extends {[P in keyof T]: DataType}>(results:
41+
{keys: Table, cols: Column[]}) {
42+
const {index_key, _values: {names}} = this;
43+
const {keys, cols} = results;
4244

43-
type Subset = Pick<T, R>;
44-
type Index = Struct<Subset>;
45-
type RestTypeMap = Omit<T, R>;
46-
type RestSeriesMap = ColumnsMap<RestTypeMap>;
45+
const rest_map =
46+
names.reduce((xs, key, index) => ({...xs, [key]: cols[index]}), {} as ColumnsMap<Omit<U, R>>);
4747

48-
const rest_map = this._values.names.reduce((xs, key, index) => ({...xs, [key]: cols[index]}),
49-
{} as RestSeriesMap);
50-
51-
if (this.index_key in rest_map) {
52-
throw new Error(`Groupby column name ${this.index_key} already exists`);
48+
if (index_key in rest_map) {
49+
throw new Error(`Groupby column name ${index_key} already exists`);
5350
}
5451

5552
const fields = [];
@@ -60,9 +57,16 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
6057
children.push(series);
6158
}
6259

63-
const index = Series.new<Index>({type: new Struct(fields), children: children});
60+
const index_map: any = {
61+
[index_key]: Series.new({type: new Struct(fields), children: children})._col,
62+
};
6463

65-
return new DataFrame({[this.index_key]: index._col, ...rest_map});
64+
return new DataFrame(
65+
{...index_map, ...rest_map} as ColumnsMap< //
66+
{[P in IndexKey]: Struct<{[P in keyof Pick<U, R>]: Pick<U, R>[P]}>}& //
67+
Omit<U, R>> //
68+
)
69+
.select([index_key, ...names]);
6670
}
6771

6872
/**
@@ -72,7 +76,8 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
7276
* device memory.
7377
*/
7478
argmax(memoryResource?: MemoryResource) {
75-
return this.prepare_results(this._cudf_groupby._argmax(this._values.asTable(), memoryResource));
79+
return this.prepare_results<{[P in keyof T]: P extends R ? T[P] : Int32}>(
80+
this._cudf_groupby._argmax(this._values.asTable(), memoryResource));
7681
}
7782

7883
/**
@@ -82,7 +87,8 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
8287
* device memory.
8388
*/
8489
argmin(memoryResource?: MemoryResource) {
85-
return this.prepare_results(this._cudf_groupby._argmin(this._values.asTable(), memoryResource));
90+
return this.prepare_results<{[P in keyof T]: P extends R ? T[P] : Int32}>(
91+
this._cudf_groupby._argmin(this._values.asTable(), memoryResource));
8692
}
8793

8894
/**
@@ -92,7 +98,8 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
9298
* device memory.
9399
*/
94100
count(memoryResource?: MemoryResource) {
95-
return this.prepare_results(this._cudf_groupby._count(this._values.asTable(), memoryResource));
101+
return this.prepare_results<{[P in keyof T]: P extends R ? T[P] : Int32}>(
102+
this._cudf_groupby._count(this._values.asTable(), memoryResource));
96103
}
97104

98105
/**
@@ -102,7 +109,7 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
102109
* device memory.
103110
*/
104111
max(memoryResource?: MemoryResource) {
105-
return this.prepare_results(this._cudf_groupby._max(this._values.asTable(), memoryResource));
112+
return this.prepare_results<T>(this._cudf_groupby._max(this._values.asTable(), memoryResource));
106113
}
107114

108115
/**
@@ -112,7 +119,8 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
112119
* device memory.
113120
*/
114121
mean(memoryResource?: MemoryResource) {
115-
return this.prepare_results(this._cudf_groupby._mean(this._values.asTable(), memoryResource));
122+
return this.prepare_results<T>(
123+
this._cudf_groupby._mean(this._values.asTable(), memoryResource));
116124
}
117125

118126
/**
@@ -122,7 +130,8 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
122130
* device memory.
123131
*/
124132
median(memoryResource?: MemoryResource) {
125-
return this.prepare_results(this._cudf_groupby._median(this._values.asTable(), memoryResource));
133+
return this.prepare_results<T>(
134+
this._cudf_groupby._median(this._values.asTable(), memoryResource));
126135
}
127136

128137
/**
@@ -132,7 +141,7 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
132141
* device memory.
133142
*/
134143
min(memoryResource?: MemoryResource) {
135-
return this.prepare_results(this._cudf_groupby._min(this._values.asTable(), memoryResource));
144+
return this.prepare_results<T>(this._cudf_groupby._min(this._values.asTable(), memoryResource));
136145
}
137146

138147
/**
@@ -143,7 +152,8 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
143152
* device memory.
144153
*/
145154
nth(n: number, memoryResource?: MemoryResource) {
146-
return this.prepare_results(this._cudf_groupby._nth(n, this._values.asTable(), memoryResource));
155+
return this.prepare_results<T>(
156+
this._cudf_groupby._nth(n, this._values.asTable(), memoryResource));
147157
}
148158

149159
/**
@@ -153,7 +163,7 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
153163
* device memory.
154164
*/
155165
nunique(memoryResource?: MemoryResource) {
156-
return this.prepare_results(
166+
return this.prepare_results<{[P in keyof T]: P extends R ? T[P] : Int32}>(
157167
this._cudf_groupby._nunique(this._values.asTable(), memoryResource));
158168
}
159169

@@ -164,7 +174,7 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
164174
* device memory.
165175
*/
166176
std(memoryResource?: MemoryResource) {
167-
return this.prepare_results(this._cudf_groupby._std(this._values.asTable(), memoryResource));
177+
return this.prepare_results<T>(this._cudf_groupby._std(this._values.asTable(), memoryResource));
168178
}
169179

170180
/**
@@ -174,7 +184,7 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
174184
* device memory.
175185
*/
176186
sum(memoryResource?: MemoryResource) {
177-
return this.prepare_results(this._cudf_groupby._sum(this._values.asTable(), memoryResource));
187+
return this.prepare_results<T>(this._cudf_groupby._sum(this._values.asTable(), memoryResource));
178188
}
179189

180190
/**
@@ -184,7 +194,7 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
184194
* device memory.
185195
*/
186196
var(memoryResource?: MemoryResource) {
187-
return this.prepare_results(this._cudf_groupby._var(this._values.asTable(), memoryResource));
197+
return this.prepare_results<T>(this._cudf_groupby._var(this._values.asTable(), memoryResource));
188198
}
189199

190200
/**
@@ -199,7 +209,7 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
199209
quantile(q = 0.5,
200210
interpolation: keyof typeof Interpolation = 'linear',
201211
memoryResource?: MemoryResource) {
202-
return this.prepare_results(this._cudf_groupby._quantile(
212+
return this.prepare_results<T>(this._cudf_groupby._quantile(
203213
q, this._values.asTable(), Interpolation[interpolation], memoryResource));
204214
}
205215
}

modules/cudf/src/series.ts

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1668,6 +1668,7 @@ function asColumn<T extends DataType>(value: any) {
16681668
if (Array.isArray(data)) {
16691669
return fromArrow<T>(arrow.Vector.from({
16701670
highWaterMark: Infinity,
1671+
nullValues: [undefined, null, NaN],
16711672
type: value.type ?? inferType(data),
16721673
// Slice `offset` from the Array before converting so
16731674
// we don't write unnecessary values with the Arrow builders.

modules/cudf/src/series/bool.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2021, NVIDIA CORPORATION.
1+
// Copyright (c) 2021-2022, NVIDIA CORPORATION.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -142,17 +142,17 @@ export class Bool8Series extends NumericSeries<Bool8> {
142142

143143
/** @inheritdoc */
144144
min(skipNulls = true, memoryResource?: MemoryResource) {
145-
return super.min(skipNulls, memoryResource) as number;
145+
return super.min(skipNulls, memoryResource) as boolean;
146146
}
147147

148148
/** @inheritdoc */
149149
max(skipNulls = true, memoryResource?: MemoryResource) {
150-
return super.max(skipNulls, memoryResource) as number;
150+
return super.max(skipNulls, memoryResource) as boolean;
151151
}
152152

153153
/** @inheritdoc */
154154
minmax(skipNulls = true, memoryResource?: MemoryResource) {
155-
return super.minmax(skipNulls, memoryResource) as [number, number];
155+
return super.minmax(skipNulls, memoryResource) as [boolean, boolean];
156156
}
157157

158158
/** @inheritdoc */

modules/cudf/src/series/float.ts

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -334,6 +334,21 @@ abstract class FloatSeries<T extends FloatingPoint> extends NumericSeries<T> {
334334
return (dropna) ? this._col.nansToNulls(memoryResource).nunique(dropna, memoryResource)
335335
: this._col.nunique(dropna, memoryResource);
336336
}
337+
338+
/** @inheritdoc */
339+
min(skipNulls = true, memoryResource?: MemoryResource) {
340+
return super.min(skipNulls, memoryResource) as number;
341+
}
342+
343+
/** @inheritdoc */
344+
max(skipNulls = true, memoryResource?: MemoryResource) {
345+
return super.max(skipNulls, memoryResource) as number;
346+
}
347+
348+
/** @inheritdoc */
349+
minmax(skipNulls = true, memoryResource?: MemoryResource) {
350+
return super.minmax(skipNulls, memoryResource) as [number, number];
351+
}
337352
}
338353

339354
/**

0 commit comments

Comments
 (0)