Skip to content

Commit 56efa1b

Browse files
(strict) mypy annotations for DataTable (#175)
* mypy annotations for DataTable
* skip imports for strict check
* typed-ast fixes
1 parent 09dd44a commit 56efa1b

File tree

2 files changed

+40
-30
lines changed

2 files changed

+40
-30
lines changed

pytools/datatable.py

Lines changed: 36 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
from typing import IO, Any, Callable, Iterator, List, Optional, Sequence, Tuple
2+
13
from pytools import Record
24

35

@@ -9,7 +11,8 @@
911
"""
1012

1113

12-
class Row(Record):
14+
# type-ignore-reason: Record is untyped
15+
class Row(Record): # type: ignore[misc]
1316
pass
1417

1518

@@ -22,7 +25,8 @@ class DataTable:
2225
.. automethod:: join
2326
"""
2427

25-
def __init__(self, column_names, column_data=None):
28+
def __init__(self, column_names: Sequence[str],
29+
column_data: Optional[List[Any]] = None) -> None:
2630
"""Construct a new table, with the given C{column_names}.
2731
2832
:arg column_names: An indexable of column name strings.
@@ -41,26 +45,26 @@ def __init__(self, column_names, column_data=None):
4145
if len(self.column_indices) != len(self.column_names):
4246
raise RuntimeError("non-unique column names encountered")
4347

44-
def __bool__(self):
48+
def __bool__(self) -> bool:
4549
return bool(self.data)
4650

47-
def __len__(self):
51+
def __len__(self) -> int:
4852
return len(self.data)
4953

50-
def __iter__(self):
54+
def __iter__(self) -> Iterator[List[Any]]:
5155
return self.data.__iter__()
5256

53-
def __str__(self):
57+
def __str__(self) -> str:
5458
"""Return a pretty-printed version of the table."""
5559

56-
def col_width(i):
60+
def col_width(i: int) -> int:
5761
width = len(self.column_names[i])
5862
if self:
5963
width = max(width, max(len(str(row[i])) for row in self.data))
6064
return width
6165
col_widths = [col_width(i) for i in range(len(self.column_names))]
6266

63-
def format_row(row):
67+
def format_row(row: Sequence[str]) -> str:
6468
return "|".join([str(cell).ljust(col_width)
6569
for cell, col_width in zip(row, col_widths)])
6670

@@ -69,24 +73,24 @@ def format_row(row):
6973
[format_row(row) for row in self.data]
7074
return "\n".join(lines)
7175

72-
def insert(self, **kwargs):
76+
def insert(self, **kwargs: Any) -> None:
7377
values = [None for i in range(len(self.column_names))]
7478

7579
for key, val in kwargs.items():
7680
values[self.column_indices[key]] = val
7781

7882
self.insert_row(tuple(values))
7983

80-
def insert_row(self, values):
84+
def insert_row(self, values: Tuple[Any, ...]) -> None:
8185
assert isinstance(values, tuple)
8286
assert len(values) == len(self.column_names)
8387
self.data.append(values)
8488

85-
def insert_rows(self, rows):
89+
def insert_rows(self, rows: Sequence[Tuple[Any, ...]]) -> None:
8690
for row in rows:
8791
self.insert_row(row)
8892

89-
def filtered(self, **kwargs):
93+
def filtered(self, **kwargs: Any) -> "DataTable":
9094
if not kwargs:
9195
return self
9296

@@ -108,7 +112,7 @@ def filtered(self, **kwargs):
108112

109113
return DataTable(self.column_names, result_data)
110114

111-
def get(self, **kwargs):
115+
def get(self, **kwargs: Any) -> Row:
112116
filtered = self.filtered(**kwargs)
113117
if not filtered:
114118
raise RuntimeError("no matching entry for get()")
@@ -117,34 +121,35 @@ def get(self, **kwargs):
117121

118122
return Row(dict(list(zip(self.column_names, filtered.data[0]))))
119123

120-
def clear(self):
124+
def clear(self) -> None:
121125
del self.data[:]
122126

123-
def copy(self):
127+
def copy(self) -> "DataTable":
124128
"""Make a copy of the instance, but leave individual rows untouched.
125129
126130
If the rows are modified later, they will also be modified in the copy.
127131
"""
128132
return DataTable(self.column_names, self.data[:])
129133

130-
def deep_copy(self):
134+
def deep_copy(self) -> "DataTable":
131135
"""Make a copy of the instance down to the row level.
132136
133137
The copy's rows may be modified independently from the original.
134138
"""
135139
return DataTable(self.column_names, [row[:] for row in self.data])
136140

137-
def sort(self, columns, reverse=False):
141+
def sort(self, columns: Sequence[str], reverse: bool = False) -> None:
138142
col_indices = [self.column_indices[col] for col in columns]
139143

140-
def mykey(row):
144+
def mykey(row: Sequence[Any]) -> Tuple[Any, ...]:
141145
return tuple(
142146
row[col_index]
143147
for col_index in col_indices)
144148

145149
self.data.sort(reverse=reverse, key=mykey)
146150

147-
def aggregated(self, groupby, agg_column, aggregate_func):
151+
def aggregated(self, groupby: Sequence[str], agg_column: str,
152+
aggregate_func: Callable[[Sequence[Any]], Any]) -> "DataTable":
148153
gb_indices = [self.column_indices[col] for col in groupby]
149154
agg_index = self.column_indices[agg_column]
150155

@@ -153,8 +158,8 @@ def aggregated(self, groupby, agg_column, aggregate_func):
153158
result_data = []
154159

155160
# to pacify pyflakes:
156-
last_values = None
157-
agg_values = None
161+
last_values: Tuple[Any, ...] = ()
162+
agg_values: List[Row] = []
158163

159164
for row in self.data:
160165
this_values = tuple(row[i] for i in gb_indices)
@@ -175,8 +180,9 @@ def aggregated(self, groupby, agg_column, aggregate_func):
175180
[self.column_names[i] for i in gb_indices] + [agg_column],
176181
result_data)
177182

178-
def join(self, column, other_column, other_table, outer=False):
179-
"""Return a tabled joining this and the C{other_table} on C{column}.
183+
def join(self, column: str, other_column: str, other_table: "DataTable",
184+
outer: bool = False) -> "DataTable":
185+
"""Return a table joining this and the C{other_table} on C{column}.
180186
181187
The new table has the following columns:
182188
- C{column}, titled the same as in this table.
@@ -187,7 +193,7 @@ def join(self, column, other_column, other_table, outer=False):
187193
by which they are joined.
188194
""" # pylint:disable=too-many-locals,too-many-branches
189195

190-
def without(indexable, idx):
196+
def without(indexable: Tuple[str, ...], idx: int) -> Tuple[str, ...]:
191197
return indexable[:idx] + indexable[idx+1:]
192198

193199
this_key_idx = self.column_indices[column]
@@ -196,9 +202,9 @@ def without(indexable, idx):
196202
this_iter = self.data.__iter__()
197203
other_iter = other_table.data.__iter__()
198204

199-
result_columns = [self.column_names[this_key_idx]] + \
200-
without(self.column_names, this_key_idx) + \
201-
without(other_table.column_names, other_key_idx)
205+
result_columns = tuple(self.column_names[this_key_idx]) + \
206+
without(tuple(self.column_names), this_key_idx) + \
207+
without(tuple(other_table.column_names), other_key_idx)
202208

203209
result_data = []
204210

@@ -266,17 +272,17 @@ def without(indexable, idx):
266272

267273
return DataTable(result_columns, result_data)
268274

269-
def restricted(self, columns):
275+
def restricted(self, columns: Sequence[str]) -> "DataTable":
270276
col_indices = [self.column_indices[col] for col in columns]
271277

272278
return DataTable(columns,
273279
[[row[i] for i in col_indices] for row in self.data])
274280

275-
def column_data(self, column):
281+
def column_data(self, column: str) -> List[Tuple[Any, ...]]:
276282
col_index = self.column_indices[column]
277283
return [row[col_index] for row in self.data]
278284

279-
def write_csv(self, filelike, **kwargs):
285+
def write_csv(self, filelike: IO[Any], **kwargs: Any) -> None:
280286
from csv import writer
281287
csvwriter = writer(filelike, **kwargs)
282288
csvwriter.writerow(self.column_names)

run-mypy.sh

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,7 @@
11
#! /bin/bash
22

3+
set -ex
4+
35
mypy --show-error-codes pytools
6+
7+
mypy --strict --follow-imports=skip pytools/datatable.py

0 commit comments

Comments
 (0)