@@ -19,7 +19,7 @@ import {Column} from '../column';
1919import { DataFrame } from '../data_frame' ;
2020import { Series } from '../series' ;
2121import { Table } from '../table' ;
22- import { Struct } from '../types/dtypes' ;
22+ import { DataType , Int32 , Struct } from '../types/dtypes' ;
2323import { ColumnsMap , Interpolation , TypeMap } from '../types/mappings' ;
2424
2525import { GroupByBase , GroupByBaseProps } from './base' ;
@@ -37,19 +37,16 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
3737 this . index_key = props . index_key ;
3838 }
3939
40- protected prepare_results ( results : { keys : Table , cols : Column [ ] } ) {
41- const { keys, cols} = results ;
40+ protected prepare_results < U extends { [ P in keyof T ] : DataType } > ( results :
41+ { keys : Table , cols : Column [ ] } ) {
42+ const { index_key, _values : { names} } = this ;
43+ const { keys, cols} = results ;
4244
43- type Subset = Pick < T , R > ;
44- type Index = Struct < Subset > ;
45- type RestTypeMap = Omit < T , R > ;
46- type RestSeriesMap = ColumnsMap < RestTypeMap > ;
45+ const rest_map =
46+ names . reduce ( ( xs , key , index ) => ( { ...xs , [ key ] : cols [ index ] } ) , { } as ColumnsMap < Omit < U , R > > ) ;
4747
48- const rest_map = this . _values . names . reduce ( ( xs , key , index ) => ( { ...xs , [ key ] : cols [ index ] } ) ,
49- { } as RestSeriesMap ) ;
50-
51- if ( this . index_key in rest_map ) {
52- throw new Error ( `Groupby column name ${ this . index_key } already exists` ) ;
48+ if ( index_key in rest_map ) {
49+ throw new Error ( `Groupby column name ${ index_key } already exists` ) ;
5350 }
5451
5552 const fields = [ ] ;
@@ -60,9 +57,16 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
6057 children . push ( series ) ;
6158 }
6259
63- const index = Series . new < Index > ( { type : new Struct ( fields ) , children : children } ) ;
60+ const index_map : any = {
61+ [ index_key ] : Series . new ( { type : new Struct ( fields ) , children : children } ) . _col ,
62+ } ;
6463
65- return new DataFrame ( { [ this . index_key ] : index . _col , ...rest_map } ) ;
64+ return new DataFrame (
65+ { ...index_map , ...rest_map } as ColumnsMap < //
66+ { [ P in IndexKey ] : Struct < { [ P in keyof Pick < U , R > ] : Pick < U , R > [ P ] } > } & //
67+ Omit < U , R > > //
68+ )
69+ . select ( [ index_key , ...names ] ) ;
6670 }
6771
6872 /**
@@ -72,7 +76,8 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
7276 * device memory.
7377 */
7478 argmax ( memoryResource ?: MemoryResource ) {
75- return this . prepare_results ( this . _cudf_groupby . _argmax ( this . _values . asTable ( ) , memoryResource ) ) ;
79+ return this . prepare_results < { [ P in keyof T ] : P extends R ? T [ P ] : Int32 } > (
80+ this . _cudf_groupby . _argmax ( this . _values . asTable ( ) , memoryResource ) ) ;
7681 }
7782
7883 /**
@@ -82,7 +87,8 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
8287 * device memory.
8388 */
8489 argmin ( memoryResource ?: MemoryResource ) {
85- return this . prepare_results ( this . _cudf_groupby . _argmin ( this . _values . asTable ( ) , memoryResource ) ) ;
90+ return this . prepare_results < { [ P in keyof T ] : P extends R ? T [ P ] : Int32 } > (
91+ this . _cudf_groupby . _argmin ( this . _values . asTable ( ) , memoryResource ) ) ;
8692 }
8793
8894 /**
@@ -92,7 +98,8 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
9298 * device memory.
9399 */
94100 count ( memoryResource ?: MemoryResource ) {
95- return this . prepare_results ( this . _cudf_groupby . _count ( this . _values . asTable ( ) , memoryResource ) ) ;
101+ return this . prepare_results < { [ P in keyof T ] : P extends R ? T [ P ] : Int32 } > (
102+ this . _cudf_groupby . _count ( this . _values . asTable ( ) , memoryResource ) ) ;
96103 }
97104
98105 /**
@@ -102,7 +109,7 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
102109 * device memory.
103110 */
104111 max ( memoryResource ?: MemoryResource ) {
105- return this . prepare_results ( this . _cudf_groupby . _max ( this . _values . asTable ( ) , memoryResource ) ) ;
112+ return this . prepare_results < T > ( this . _cudf_groupby . _max ( this . _values . asTable ( ) , memoryResource ) ) ;
106113 }
107114
108115 /**
@@ -112,7 +119,8 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
112119 * device memory.
113120 */
114121 mean ( memoryResource ?: MemoryResource ) {
115- return this . prepare_results ( this . _cudf_groupby . _mean ( this . _values . asTable ( ) , memoryResource ) ) ;
122+ return this . prepare_results < T > (
123+ this . _cudf_groupby . _mean ( this . _values . asTable ( ) , memoryResource ) ) ;
116124 }
117125
118126 /**
@@ -122,7 +130,8 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
122130 * device memory.
123131 */
124132 median ( memoryResource ?: MemoryResource ) {
125- return this . prepare_results ( this . _cudf_groupby . _median ( this . _values . asTable ( ) , memoryResource ) ) ;
133+ return this . prepare_results < T > (
134+ this . _cudf_groupby . _median ( this . _values . asTable ( ) , memoryResource ) ) ;
126135 }
127136
128137 /**
@@ -132,7 +141,7 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
132141 * device memory.
133142 */
134143 min ( memoryResource ?: MemoryResource ) {
135- return this . prepare_results ( this . _cudf_groupby . _min ( this . _values . asTable ( ) , memoryResource ) ) ;
144+ return this . prepare_results < T > ( this . _cudf_groupby . _min ( this . _values . asTable ( ) , memoryResource ) ) ;
136145 }
137146
138147 /**
@@ -143,7 +152,8 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
143152 * device memory.
144153 */
145154 nth ( n : number , memoryResource ?: MemoryResource ) {
146- return this . prepare_results ( this . _cudf_groupby . _nth ( n , this . _values . asTable ( ) , memoryResource ) ) ;
155+ return this . prepare_results < T > (
156+ this . _cudf_groupby . _nth ( n , this . _values . asTable ( ) , memoryResource ) ) ;
147157 }
148158
149159 /**
@@ -153,7 +163,7 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
153163 * device memory.
154164 */
155165 nunique ( memoryResource ?: MemoryResource ) {
156- return this . prepare_results (
166+ return this . prepare_results < { [ P in keyof T ] : P extends R ? T [ P ] : Int32 } > (
157167 this . _cudf_groupby . _nunique ( this . _values . asTable ( ) , memoryResource ) ) ;
158168 }
159169
@@ -164,7 +174,7 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
164174 * device memory.
165175 */
166176 std ( memoryResource ?: MemoryResource ) {
167- return this . prepare_results ( this . _cudf_groupby . _std ( this . _values . asTable ( ) , memoryResource ) ) ;
177+ return this . prepare_results < T > ( this . _cudf_groupby . _std ( this . _values . asTable ( ) , memoryResource ) ) ;
168178 }
169179
170180 /**
@@ -174,7 +184,7 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
174184 * device memory.
175185 */
176186 sum ( memoryResource ?: MemoryResource ) {
177- return this . prepare_results ( this . _cudf_groupby . _sum ( this . _values . asTable ( ) , memoryResource ) ) ;
187+ return this . prepare_results < T > ( this . _cudf_groupby . _sum ( this . _values . asTable ( ) , memoryResource ) ) ;
178188 }
179189
180190 /**
@@ -184,7 +194,7 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
184194 * device memory.
185195 */
186196 var ( memoryResource ?: MemoryResource ) {
187- return this . prepare_results ( this . _cudf_groupby . _var ( this . _values . asTable ( ) , memoryResource ) ) ;
197+ return this . prepare_results < T > ( this . _cudf_groupby . _var ( this . _values . asTable ( ) , memoryResource ) ) ;
188198 }
189199
190200 /**
@@ -199,7 +209,7 @@ export class GroupByMultiple<T extends TypeMap, R extends keyof T, IndexKey exte
199209 quantile ( q = 0.5 ,
200210 interpolation : keyof typeof Interpolation = 'linear' ,
201211 memoryResource ?: MemoryResource ) {
202- return this . prepare_results ( this . _cudf_groupby . _quantile (
212+ return this . prepare_results < T > ( this . _cudf_groupby . _quantile (
203213 q , this . _values . asTable ( ) , Interpolation [ interpolation ] , memoryResource ) ) ;
204214 }
205215}
0 commit comments