Skip to content

Commit c0c34d8

Browse files
Fix ty issues
1 parent 5039ef8 commit c0c34d8

25 files changed

Lines changed: 111 additions & 96 deletions

File tree

tilebox-datasets/tests/data/datapoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def example_datapoints(draw: DrawFn, generated_fields: bool = False, missing_fie
6262
some_time=draw(datetime_messages() | maybe_none),
6363
some_duration=draw(duration_messages() | maybe_none),
6464
some_bytes=draw(binary(min_size=1, max_size=10) | maybe_none),
65-
some_bool=draw(booleans() | maybe_none), # type: ignore[arg-type]
65+
some_bool=draw(booleans() | maybe_none),
6666
# well-known types
6767
some_identifier=draw(uuid_messages() | maybe_none),
6868
some_vec3=draw(vec3_messages() | maybe_none),

tilebox-datasets/tilebox/datasets/message_pool.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from google.protobuf import descriptor_pb2, duration_pb2, timestamp_pb2
22
from google.protobuf.descriptor_pool import Default
3+
from google.protobuf.message import Message
34
from google.protobuf.message_factory import GetMessageClass, GetMessages
45

56
from tilebox.datasets.data.datasets import AnnotatedType
@@ -25,5 +26,5 @@ def register_message_types(descriptor_set: descriptor_pb2.FileDescriptorSet) ->
2526
GetMessages(descriptor_set.file, pool=Default())
2627

2728

28-
def get_message_type(type_url: str) -> type:
29+
def get_message_type(type_url: str) -> type[Message]:
2930
return GetMessageClass(Default().FindMessageTypeByName(type_url))

tilebox-datasets/tilebox/datasets/progress.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def _calc_progress_seconds(self, time: datetime) -> int:
6262

6363
def set_progress(self, time: datetime) -> None:
6464
"""Set the progress of the progress bar to the given time"""
65-
done = min(self._calc_progress_seconds(time), self._progress_bar.total)
65+
total = self._calc_progress_seconds(self._interval.end)
66+
done = min(self._calc_progress_seconds(time), total)
6667
self._progress_bar.update(done - self._progress_bar.n)
6768

6869
def set_download_info(self, datapoints: int, byte_size: int, download_time: float) -> None:
@@ -79,7 +80,8 @@ def __exit__(
7980
) -> None:
8081
try:
8182
if traceback is None:
82-
self._progress_bar.update(self._progress_bar.total - self._progress_bar.n) # set to 100%
83+
total = self._calc_progress_seconds(self._interval.end)
84+
self._progress_bar.update(total - self._progress_bar.n) # set to 100%
8385

8486
self._progress_bar.close() # mark as completed or failed
8587
except AttributeError:

tilebox-datasets/tilebox/datasets/protobuf_conversion/field_types.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Sized
1+
from collections.abc import Sequence
22
from datetime import timedelta
33
from typing import Any
44
from uuid import UUID
@@ -16,9 +16,10 @@
1616
from tilebox.datasets.datasets.v1.well_known_types_pb2 import UUID as UUIDMessage # noqa: N811
1717
from tilebox.datasets.datasets.v1.well_known_types_pb2 import Geometry, LatLon, LatLonAlt, Quaternion, Vec3
1818

19-
ProtoFieldValue = Message | float | str | bool | bytes | Sized | None
19+
ScalarProtoFieldValue = Message | float | str | bool | bytes
20+
ProtoFieldValue = ScalarProtoFieldValue | Sequence[ScalarProtoFieldValue] | None
2021

21-
_FILL_VALUES_BY_DTYPE = {
22+
_FILL_VALUES_BY_DTYPE: dict[type[np.dtype[Any]], Any] = {
2223
npdtypes.Int8DType: np.int8(0),
2324
npdtypes.Int16DType: np.int16(0),
2425
npdtypes.Int32DType: np.int32(0),

tilebox-datasets/tilebox/datasets/protobuf_conversion/protobuf_xarray.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import contextlib
6-
from collections.abc import Sized
6+
from collections.abc import Sequence
77
from typing import Any, TypeVar
88

99
import numpy as np
@@ -231,10 +231,10 @@ def resize(self, buffer_size: int) -> None:
231231
elif buffer_size > len(self._data):
232232
# resize the data buffer to the new capacity, by just padding it with zeros at the end
233233
missing = buffer_size - len(self._data)
234-
self._data = np.pad(
234+
self._data = np.pad( # ty: ignore[no-matching-overload]
235235
self._data,
236236
((0, missing), (0, 0)),
237-
constant_values=self._type.fill_value, # type: ignore[arg-type]
237+
constant_values=self._type.fill_value,
238238
)
239239

240240

@@ -258,13 +258,13 @@ def __init__(
258258
self._array_dim: int | None = None
259259

260260
def __call__(self, index: int, value: ProtoFieldValue) -> None:
261-
if not isinstance(value, Sized):
261+
if not isinstance(value, Sequence):
262262
raise TypeError(f"Expected array field but got {type(value)}")
263263

264264
if self._array_dim is None or len(value) > self._array_dim:
265265
self._resize_array_dim(len(value))
266266

267-
for i, v in enumerate(value): # type: ignore[arg-type] # somehow the isinstance(value, Sized) isn't used here
267+
for i, v in enumerate(value): # somehow the isinstance(value, Sized) isn't used here
268268
self._data[index, i, :] = self._type.from_proto(v)
269269

270270
def finalize(
@@ -309,10 +309,10 @@ def _resize(self) -> None:
309309
else: # resize the data buffer to the new capacity, by just padding it with zeros at the end
310310
missing_capacity = self._capacity - self._data.shape[0]
311311
missing_array_dim = self._array_dim - self._data.shape[1]
312-
self._data = np.pad(
312+
self._data = np.pad( # ty: ignore[no-matching-overload]
313313
self._data,
314314
((0, missing_capacity), (0, missing_array_dim), (0, 0)),
315-
constant_values=self._type.fill_value, # type: ignore[arg-type]
315+
constant_values=self._type.fill_value,
316316
)
317317

318318

@@ -374,13 +374,13 @@ def _create_field_converter(field: FieldDescriptor) -> _FieldConverter:
374374
"""
375375
# special handling for enums:
376376
if field.type == FieldDescriptor.TYPE_ENUM:
377-
if field.is_repeated: # type: ignore[attr-defined]
377+
if field.is_repeated:
378378
raise NotImplementedError("Repeated enum fields are not supported")
379379

380380
return _EnumFieldConverter(field.name, enum_mapping_from_field_descriptor(field))
381381

382382
field_type = infer_field_type(field)
383-
if field.is_repeated: # type: ignore[attr-defined]
383+
if field.is_repeated:
384384
return _ArrayFieldConverter(field.name, field_type)
385385

386386
return _SimpleFieldConverter(field.name, field_type)

tilebox-datasets/tilebox/datasets/protobuf_conversion/to_protobuf.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
def to_messages( # noqa: C901, PLR0912
2222
data: IngestionData,
23-
message_type: type,
23+
message_type: type[Message],
2424
required_fields: list[str] | None = None,
2525
ignore_fields: list[str] | None = None,
2626
) -> list[Message]:
@@ -44,9 +44,9 @@ def to_messages( # noqa: C901, PLR0912
4444
# let's validate our fields, to make sure that they are all known fields for the given protobuf message
4545
# and that they are all lists of the same length
4646
field_lengths = defaultdict(list)
47-
fields: dict[str, pd.Series | np.ndarray] = {}
47+
fields: dict[str, pd.Series | np.ndarray | list[ProtoFieldValue]] = {}
4848

49-
field_names = list(map(str, data))
49+
field_names = [str(field) for field in data]
5050
if isinstance(data, xr.Dataset):
5151
# list(dataset) only returns the variables, not the coords, so for xarray we need to add the coords as well
5252
# but not all coords, we only care abou time for now
@@ -84,7 +84,7 @@ def to_messages( # noqa: C901, PLR0912
8484
else:
8585
values = convert_values_to_proto(values, field_type, filter_none=False)
8686

87-
fields[field_name] = values # type: ignore[assignment]
87+
fields[field_name] = values
8888

8989
# now convert every datapoint to a protobuf message
9090
if len(field_lengths) == 0: # early return, no actual data to convert
@@ -103,7 +103,7 @@ def marshal_messages(messages: list[Message]) -> list[bytes]:
103103

104104

105105
def columnar_to_row_based(
106-
data: dict[str, pd.Series | np.ndarray],
106+
data: dict[str, pd.Series | np.ndarray | list[ProtoFieldValue]],
107107
) -> Iterator[dict[str, Any]]:
108108
if len(data) == 0:
109109
return
@@ -126,12 +126,12 @@ def convert_values_to_proto(
126126

127127
def convert_repeated_values_to_proto(
128128
values: np.ndarray | pd.Series | list[np.ndarray], field_type: ProtobufFieldType
129-
) -> Any:
129+
) -> list[ProtoFieldValue]:
130130
if isinstance(values, np.ndarray): # it was an xarray, with potentially padded fill values at the end
131131
values = trim_trailing_fill_values(values, field_type.fill_value)
132132

133133
# since repeated fields can have different lengths between datapoints, we can filter out None values here
134-
return [convert_values_to_proto(repeated_values, field_type, filter_none=True) for repeated_values in values]
134+
return [convert_values_to_proto(repeated_values, field_type, filter_none=True) for repeated_values in values] # ty: ignore[invalid-return-type]
135135

136136

137137
def trim_trailing_fill_values(values: np.ndarray, fill_value: Any) -> list[np.ndarray]:

tilebox-datasets/tilebox/datasets/query/id_interval.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,22 @@ def parse(cls, arg: IDIntervalLike, start_exclusive: bool = False, end_inclusive
5555
case IDInterval(_, _, _, _):
5656
return arg
5757
case (UUID(), UUID()):
58-
start, end = arg
58+
start: UUID = arg[0]
59+
end: UUID = arg[1]
5960
return IDInterval(
6061
start_id=start,
6162
end_id=end,
6263
start_exclusive=start_exclusive,
6364
end_inclusive=end_inclusive,
6465
)
6566
case (str(), str()):
66-
start, end = arg
67+
start: str = arg[0]
68+
end: str = arg[1]
6769
return IDInterval(
6870
start_id=UUID(start),
6971
end_id=UUID(end),
7072
start_exclusive=start_exclusive,
7173
end_inclusive=end_inclusive,
7274
)
75+
76+
raise ValueError(f"Failed to convert {arg} ({type(arg)}) to IDInterval")

tilebox-datasets/tilebox/datasets/service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def _client_info() -> ClientInfo:
271271
def _environment_info() -> str:
272272
python_version = sys.version.split(" ")[0]
273273
try:
274-
shell = str(get_ipython()) # type: ignore[name-defined]
274+
shell = str(get_ipython()) # ty: ignore[unresolved-reference]
275275
except NameError:
276276
return f"Python {python_version}" # Probably standard Python interpreter
277277

tilebox-grpc/_tilebox/grpc/error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def with_pythonic_errors(stub: Stub, async_funcs: bool = False) -> Stub:
5757
wrap_func = _wrap_rpc if not async_funcs else _async_wrap_rpc
5858
for name, rpc in stub.__dict__.items():
5959
if callable(rpc):
60-
setattr(stub, name, wrap_func(rpc)) # type: ignore[assignment]
60+
setattr(stub, name, wrap_func(rpc))
6161
return stub
6262

6363

tilebox-grpc/_tilebox/grpc/replay.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def open_recording_channel(url: str, auth_token: str | None, recording: str | Pa
4040

4141

4242
def open_replay_channel(recording: str | Path, assert_request_matches: bool = True) -> Channel:
43-
return _ReplayChannel(recording, assert_request_matches) # type: ignore[return-value]
43+
return _ReplayChannel(recording, assert_request_matches) # ty: ignore[invalid-return-type] # not a subclass, but same interface so works
4444

4545

4646
class _ConcreteValue(Future):
@@ -87,7 +87,7 @@ def intercept_unary_unary(
8787
client_call_details: ClientCallDetails,
8888
request: RequestType,
8989
) -> Future:
90-
request_data = base64.b64encode(request.SerializeToString()) # type: ignore[attr-defined]
90+
request_data = base64.b64encode(request.SerializeToString()) # ty: ignore[unresolved-attribute]
9191
with self.recording.open("ab") as file:
9292
method = client_call_details.method
9393
if isinstance(method, str):
@@ -162,7 +162,7 @@ def unary_unary_call(
162162

163163
if recorded_status != StatusCode.OK.value[0]: # the recorded call was an error, so raise it again
164164
code = _STATUS_CODES[recorded_status]
165-
error = AioRpcError(code, None, None, recorded_response.decode()) # type: ignore[arg-type]
165+
error = AioRpcError(code, None, None, recorded_response.decode()) # ty: ignore[invalid-argument-type]
166166
raise error
167167

168168
return response_deserializer(base64.b64decode(recorded_response))

0 commit comments

Comments
 (0)