Skip to content

Commit 5039ef8

Browse files
Introduce ty
1 parent 3d6c342 commit 5039ef8

5 files changed

Lines changed: 76 additions & 81 deletions

File tree

.pre-commit-config.yaml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@ repos:
1414
- id: ruff-check
1515
args: [--fix, --exit-non-zero-on-fix]
1616
- id: ruff-format
17-
- repo: https://github.com/RobertCraigie/pyright-python
18-
rev: v1.1.400
17+
- repo: local
1918
hooks:
20-
- id: pyright
19+
- id: ty-check
20+
name: ty-check
21+
language: python
22+
entry: ty check
23+
pass_filenames: false
24+
args: [--python=.venv/]
25+
additional_dependencies: [ty]

pyproject.toml

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,10 @@ dev = [
2424
"pyarrow>=17.0.0",
2525
# some dev tooling
2626
"ruff>=0.11.10",
27-
# pyright 1.1.401 reports many wrong false positives, let's wait until that is fixed before upgrading
28-
"pyright>=1.1.379,<1.1.401",
2927
"pre-commit>=3.8.0",
3028
"types-protobuf>=6.30",
3129
"junitparser>=3.2.0",
30+
"ty>=0.0.11",
3231
]
3332

3433
[project.scripts]
@@ -112,33 +111,10 @@ known-first-party = ["tilebox", "_tilebox"]
112111
[tool.ruff.lint.per-file-ignores]
113112
"*/tests/*" = ["INP001", "SLF001"]
114113

115-
[tool.pyright]
114+
[tool.ty.src]
116115
exclude = [
117-
"**/.ipynb_checkpoints",
118-
"**/__pycache__",
119-
".venv",
120-
"tilebox-datasets/tests/example_dataset/*", # auto-generated code
121-
"tilebox-workflows/tests/proto/*", # auto-generated code
116+
# auto-generated code
117+
"**/*_pb2.py",
118+
"**/*_pb2.pyi",
119+
"**/*pb2_grpc.py"
122120
]
123-
124-
# ignore warnings in those files, but still type check them when used as a dependency in other files
125-
ignore = [
126-
# it's auto generated
127-
"**/datasets/v1",
128-
"**/workflows/v1",
129-
"**/tilebox/v1",
130-
"**/buf/validate",
131-
]
132-
133-
# pyright needs to have all the dependencies installed to be able to type check
134-
# we can make sure of this by telling it to use the uv venv
135-
venvPath = "."
136-
venv = ".venv"
137-
extraPaths = [
138-
"tilebox-datasets",
139-
"tilebox-grpc",
140-
"tilebox-storage",
141-
"tilebox-workflows",
142-
]
143-
144-
reportPrivateImportUsage = false

tilebox-workflows/mise.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[tools]
2+
prek = "latest"

tilebox-workflows/tilebox/workflows/task.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __new__(cls, name: str, bases: tuple[type], attrs: dict[str, Any]) -> type:
5050
return task_class
5151

5252
# Convert the class to a dataclass
53-
task_class = dataclass(task_class) # type: ignore[arg-type]
53+
task_class = dataclass(task_class)
5454

5555
# we allow overriding the execute method, but we still want to validate it
5656
# so we search for the closest base class that has an execute method and use
@@ -118,7 +118,7 @@ def _serialize(self) -> bytes:
118118

119119
@classmethod
120120
def _deserialize(cls, task_input: bytes, context: RunnerContext | None = None) -> "Task": # noqa: ARG003
121-
return cast(Task, deserialize_task(cls, task_input))
121+
return deserialize_task(cls, task_input)
122122

123123

124124
def _validate_execute_method(
@@ -201,7 +201,7 @@ def identifier() -> tuple[str, str]:
201201
class_name = task_class.__name__
202202
if hasattr(task_class, "identifier"): # if the task class has an identifier method, we use that
203203
try:
204-
name, version = task_class.identifier()
204+
name, version = task_class.identifier() # ty: ignore[call-non-callable]
205205
except TypeError as err:
206206
raise ValueError(
207207
f"Failed to invoke {class_name}.identifier(). Is it a staticmethod or classmethod without parameters?"
@@ -422,7 +422,7 @@ def serialize_task(task: Task) -> bytes:
422422
field = json.dumps(field).encode()
423423
return field
424424

425-
return json.dumps(_serialize_as_dict(task)).encode() # type: ignore[arg-type]
425+
return json.dumps(_serialize_as_dict(task)).encode()
426426

427427

428428
def _serialize_as_dict(task: Task) -> dict[str, Any]:
@@ -452,7 +452,7 @@ def _serialize_value(value: Any, base64_encode_protobuf: bool) -> Any: # noqa:
452452
return b64encode(value.SerializeToString()).decode("ascii")
453453
return value.SerializeToString()
454454
if is_dataclass(value):
455-
return _serialize_as_dict(value) # type: ignore[arg-type]
455+
return _serialize_as_dict(value)
456456
return value
457457

458458

@@ -468,11 +468,11 @@ def deserialize_task(task_cls: type, task_input: bytes) -> Task:
468468
return task_cls() # empty task
469469
if len(task_fields) == 1:
470470
# if there is only one field, we deserialize it directly
471-
field_type = _get_deserialization_field_type(task_fields[0].type) # type: ignore[arg-type]
471+
field_type = _get_deserialization_field_type(task_fields[0].type) # ty: ignore[invalid-argument-type]
472472
if hasattr(field_type, "FromString"): # protobuf message
473473
value = field_type.FromString(task_input) # type: ignore[arg-type]
474474
else:
475-
value = _deserialize_value(field_type, json.loads(task_input.decode())) # type: ignore[arg-type]
475+
value = _deserialize_value(field_type, json.loads(task_input.decode()))
476476

477477
return task_cls(**{task_fields[0].name: value})
478478

@@ -483,7 +483,7 @@ def _deserialize_dataclass(cls: type, params: dict[str, Any]) -> Task:
483483
"""Deserialize a dataclass, while allowing recursively nested dataclasses or protobuf messages."""
484484
for param in list(params):
485485
# recursively deserialize nested dataclasses
486-
field = cls.__dataclass_fields__[param]
486+
field = cls.__dataclass_fields__[param] # ty: ignore[unresolved-attribute]
487487
params[field.name] = _deserialize_value(field.type, params[field.name])
488488

489489
return cls(**params)
@@ -495,7 +495,7 @@ def _deserialize_value(field_type: type, value: Any) -> Any: # noqa: PLR0911
495495

496496
field_type = _get_deserialization_field_type(field_type)
497497
if hasattr(field_type, "FromString"):
498-
return field_type.FromString(b64decode(value))
498+
return field_type.FromString(b64decode(value)) # ty: ignore[call-non-callable]
499499
if is_dataclass(field_type) and isinstance(value, dict):
500500
return _deserialize_dataclass(field_type, value)
501501

0 commit comments

Comments
 (0)