Skip to content

Commit 79d1e67

Browse files
committed
set target-verion py310 and enable modern type annotation rules (Optional/Union → X | None)
Signed-off-by: ji-huazhong <hzji210@gmail.com>
1 parent 0f5649a commit 79d1e67

16 files changed

Lines changed: 126 additions & 131 deletions

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ requires-python = ">=3.10"
3131
# Note: While the formatter will attempt to format lines such that they remain within the line-length,
3232
# it isn't a hard upper bound, and formatted lines may exceed the line-length.
3333
line-length = 120
34+
target-version = "py310"
3435

3536
[tool.ruff.lint]
3637
isort = {known-first-party = ["transfer_queue"]}
@@ -60,7 +61,7 @@ ignore = [
6061
# `.log()` statement uses f-string
6162
"G004",
6263
# X | None for type annotations
63-
"UP045",
64+
# "UP045",
6465
# deprecated import
6566
"UP035",
6667
]

transfer_queue/client.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import asyncio
1717
import os
1818
import threading
19-
from typing import Any, Callable, Optional
19+
from typing import Any, Callable
2020

2121
import torch
2222
import zmq
@@ -104,9 +104,9 @@ async def async_get_meta(
104104
batch_size: int,
105105
partition_id: str,
106106
mode: str = "fetch",
107-
task_name: Optional[str] = None,
108-
sampling_config: Optional[dict[str, Any]] = None,
109-
socket: Optional[zmq.asyncio.Socket] = None,
107+
task_name: str | None = None,
108+
sampling_config: dict[str, Any] | None = None,
109+
socket: zmq.asyncio.Socket | None = None,
110110
) -> BatchMeta:
111111
"""Asynchronously fetch data metadata from the controller via ZMQ.
112112
@@ -191,7 +191,7 @@ async def async_get_meta(
191191
async def async_set_custom_meta(
192192
self,
193193
metadata: BatchMeta,
194-
socket: Optional[zmq.asyncio.Socket] = None,
194+
socket: zmq.asyncio.Socket | None = None,
195195
) -> None:
196196
"""
197197
Asynchronously send custom metadata to the controller.
@@ -264,9 +264,9 @@ async def async_set_custom_meta(
264264
async def async_put(
265265
self,
266266
data: TensorDict,
267-
metadata: Optional[BatchMeta] = None,
268-
partition_id: Optional[str] = None,
269-
data_parser: Optional[Callable[[Any], Any]] = None,
267+
metadata: BatchMeta | None = None,
268+
partition_id: str | None = None,
269+
data_parser: Callable[[Any], Any] | None = None,
270270
) -> BatchMeta:
271271
"""Asynchronously write data to storage units based on metadata.
272272
@@ -575,8 +575,8 @@ async def async_get_consumption_status(
575575
self,
576576
task_name: str,
577577
partition_id: str,
578-
socket: Optional[zmq.asyncio.Socket] = None,
579-
) -> tuple[Optional[Tensor], Optional[Tensor]]:
578+
socket: zmq.asyncio.Socket | None = None,
579+
) -> tuple[Tensor | None, Tensor | None]:
580580
"""Get consumption status for current partition in a specific task.
581581
582582
Args:
@@ -638,8 +638,8 @@ async def async_get_production_status(
638638
self,
639639
data_fields: list[str],
640640
partition_id: str,
641-
socket: Optional[zmq.asyncio.Socket] = None,
642-
) -> tuple[Optional[Tensor], Optional[Tensor]]:
641+
socket: zmq.asyncio.Socket | None = None,
642+
) -> tuple[Tensor | None, Tensor | None]:
643643
"""Get production status for specific data fields and partition.
644644
645645
Args:
@@ -769,8 +769,8 @@ async def async_check_production_status(
769769
async def async_reset_consumption(
770770
self,
771771
partition_id: str,
772-
task_name: Optional[str] = None,
773-
socket: Optional[zmq.asyncio.Socket] = None,
772+
task_name: str | None = None,
773+
socket: zmq.asyncio.Socket | None = None,
774774
) -> bool:
775775
"""Asynchronously reset consumption status for a partition.
776776
@@ -830,7 +830,7 @@ async def async_reset_consumption(
830830
@with_controller_socket
831831
async def async_get_partition_list(
832832
self,
833-
socket: Optional[zmq.asyncio.Socket] = None,
833+
socket: zmq.asyncio.Socket | None = None,
834834
) -> list[str]:
835835
"""Asynchronously fetch the list of partition ids from the controller.
836836
@@ -879,7 +879,7 @@ async def async_kv_retrieve_meta(
879879
keys: list[str] | str,
880880
partition_id: str,
881881
create: bool = False,
882-
socket: Optional[zmq.asyncio.Socket] = None,
882+
socket: zmq.asyncio.Socket | None = None,
883883
) -> BatchMeta:
884884
"""Asynchronously retrieve BatchMeta from the controller using user-specified keys.
885885
@@ -944,7 +944,7 @@ async def async_kv_retrieve_keys(
944944
self,
945945
global_indexes: list[int] | int,
946946
partition_id: str,
947-
socket: Optional[zmq.asyncio.Socket] = None,
947+
socket: zmq.asyncio.Socket | None = None,
948948
) -> list[str]:
949949
"""Asynchronously retrieve keys according to global_indexes from the controller.
950950
@@ -1005,8 +1005,8 @@ async def async_kv_retrieve_keys(
10051005
@with_controller_socket
10061006
async def async_kv_list(
10071007
self,
1008-
partition_id: Optional[str] = None,
1009-
socket: Optional[zmq.asyncio.Socket] = None,
1008+
partition_id: str | None = None,
1009+
socket: zmq.asyncio.Socket | None = None,
10101010
) -> dict[str, dict[str, Any]]:
10111011
"""Asynchronously retrieve keys and custom_meta from the controller for one or all partitions.
10121012
@@ -1145,8 +1145,8 @@ def get_meta(
11451145
batch_size: int,
11461146
partition_id: str,
11471147
mode: str = "fetch",
1148-
task_name: Optional[str] = None,
1149-
sampling_config: Optional[dict[str, Any]] = None,
1148+
task_name: str | None = None,
1149+
sampling_config: dict[str, Any] | None = None,
11501150
) -> BatchMeta:
11511151
"""Synchronously fetch data metadata from the controller via ZMQ.
11521152
@@ -1234,9 +1234,9 @@ def set_custom_meta(self, metadata: BatchMeta) -> None:
12341234
def put(
12351235
self,
12361236
data: TensorDict,
1237-
metadata: Optional[BatchMeta] = None,
1238-
partition_id: Optional[str] = None,
1239-
data_parser: Optional[Callable[[Any], Any]] = None,
1237+
metadata: BatchMeta | None = None,
1238+
partition_id: str | None = None,
1239+
data_parser: Callable[[Any], Any] | None = None,
12401240
) -> BatchMeta:
12411241
"""Synchronously write data to storage units based on metadata.
12421242
@@ -1356,7 +1356,7 @@ def get_consumption_status(
13561356
self,
13571357
task_name: str,
13581358
partition_id: str,
1359-
) -> tuple[Optional[Tensor], Optional[Tensor]]:
1359+
) -> tuple[Tensor | None, Tensor | None]:
13601360
"""Synchronously get consumption status for a specific task and partition.
13611361
13621362
Args:
@@ -1384,7 +1384,7 @@ def get_production_status(
13841384
self,
13851385
data_fields: list[str],
13861386
partition_id: str,
1387-
) -> tuple[Optional[Tensor], Optional[Tensor]]:
1387+
) -> tuple[Tensor | None, Tensor | None]:
13881388
"""Synchronously get production status for specific data fields and partition.
13891389
13901390
Args:
@@ -1454,7 +1454,7 @@ def check_production_status(self, data_fields: list[str], partition_id: str) ->
14541454
"""
14551455
return self._check_production_status(data_fields=data_fields, partition_id=partition_id)
14561456

1457-
def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool:
1457+
def reset_consumption(self, partition_id: str, task_name: str | None = None) -> bool:
14581458
"""Synchronously reset consumption status for a partition.
14591459
14601460
This allows the same data to be re-consumed, useful for debugging scenarios
@@ -1540,7 +1540,7 @@ def kv_retrieve_keys(
15401540

15411541
def kv_list(
15421542
self,
1543-
partition_id: Optional[str] = None,
1543+
partition_id: str | None = None,
15441544
) -> dict[str, dict[str, Any]]:
15451545
"""Synchronously retrieve keys and custom_meta from the controller for one or all partitions.
15461546

transfer_queue/controller.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from itertools import groupby
2222
from operator import itemgetter
2323
from threading import Lock, Thread
24-
from typing import Any, Optional
24+
from typing import Any
2525
from uuid import uuid4
2626

2727
import numpy as np
@@ -195,10 +195,10 @@ class FieldMeta:
195195
"""
196196

197197
global_indexes: set[int] = field(default_factory=set)
198-
dtype: Optional[Any] = None
199-
shape: Optional[tuple] = None # None when is_nested=True
200-
is_nested: Optional[bool] = None
201-
is_non_tensor: Optional[bool] = None
198+
dtype: Any | None = None
199+
shape: tuple | None = None # None when is_nested=True
200+
is_nested: bool | None = None
201+
is_non_tensor: bool | None = None
202202

203203
per_sample_shapes: dict[int, tuple] = field(default_factory=dict) # {global_idx: shape}
204204

@@ -495,7 +495,7 @@ def update_production_status(
495495
global_indices: list[int],
496496
field_names: list[str],
497497
field_schema: dict[str, dict[str, Any]],
498-
custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None,
498+
custom_backend_meta: dict[int, dict[str, Any]] | None = None,
499499
) -> bool:
500500
"""
501501
Update production status for specific samples and fields.
@@ -560,7 +560,7 @@ def _update_field_metadata(
560560
self,
561561
global_indexes: list[int],
562562
field_schema: dict[str, dict[str, Any]],
563-
custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None,
563+
custom_backend_meta: dict[int, dict[str, Any]] | None = None,
564564
):
565565
"""Update field metadata from columnar field_schema."""
566566
if not global_indexes:
@@ -645,7 +645,7 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te
645645
consumption_status = self.consumption_status[task_name]
646646
return partition_global_index, consumption_status
647647

648-
def reset_consumption(self, task_name: Optional[str] = None):
648+
def reset_consumption(self, task_name: str | None = None):
649649
"""
650650
Reset consumption status for a specific task or all tasks.
651651
@@ -670,7 +670,7 @@ def reset_consumption(self, task_name: Optional[str] = None):
670670
# ==================== Production Status Interface ====================
671671
def get_production_status_for_fields(
672672
self, field_names: list[str], mask: bool = False
673-
) -> tuple[Optional[Tensor], Optional[Tensor]]:
673+
) -> tuple[Tensor | None, Tensor | None]:
674674
"""
675675
Check if all samples for specified fields are fully produced and ready.
676676
@@ -1049,7 +1049,7 @@ def create_partition(self, partition_id: str) -> bool:
10491049
logger.info(f"Created partition {partition_id} with {TQ_PRE_ALLOC_SAMPLE_NUM} pre-allocated indexes")
10501050
return True
10511051

1052-
def _get_partition(self, partition_id: str) -> Optional[DataPartitionStatus]:
1052+
def _get_partition(self, partition_id: str) -> DataPartitionStatus | None:
10531053
"""
10541054
Get partition status information.
10551055
@@ -1061,7 +1061,7 @@ def _get_partition(self, partition_id: str) -> Optional[DataPartitionStatus]:
10611061
"""
10621062
return self.partitions.get(partition_id)
10631063

1064-
def get_partition_snapshot(self, partition_id: str) -> Optional[DataPartitionStatus]:
1064+
def get_partition_snapshot(self, partition_id: str) -> DataPartitionStatus | None:
10651065
"""
10661066
Get a copy of partition status information, without threading.Lock().
10671067
@@ -1109,7 +1109,7 @@ def update_production_status(
11091109
partition_id: str,
11101110
global_indexes: list[int],
11111111
field_schema: dict[str, dict[str, Any]],
1112-
custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None,
1112+
custom_backend_meta: dict[int, dict[str, Any]] | None = None,
11131113
) -> bool:
11141114
"""
11151115
Update production status for specific samples and fields in a partition.
@@ -1139,7 +1139,7 @@ def update_production_status(
11391139
return success
11401140

11411141
# ==================== Data Consumption API ====================
1142-
def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Optional[Tensor], Optional[Tensor]]:
1142+
def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Tensor | None, Tensor | None]:
11431143
"""
11441144
Get or create consumption status for a specific task and partition.
11451145
Delegates to the partition's own method.
@@ -1159,9 +1159,7 @@ def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Opt
11591159

11601160
return partition.get_consumption_status(task_name, mask=True)
11611161

1162-
def get_production_status(
1163-
self, partition_id: str, data_fields: list[str]
1164-
) -> tuple[Optional[Tensor], Optional[Tensor]]:
1162+
def get_production_status(self, partition_id: str, data_fields: list[str]) -> tuple[Tensor | None, Tensor | None]:
11651163
"""
11661164
Check if all samples for specified fields are fully produced in a partition.
11671165
@@ -1226,7 +1224,7 @@ def get_metadata(
12261224
mode: str = "fetch",
12271225
task_name: str | None = None,
12281226
batch_size: int | None = None,
1229-
sampling_config: Optional[dict[str, Any]] = None,
1227+
sampling_config: dict[str, Any] | None = None,
12301228
*args,
12311229
**kwargs,
12321230
) -> BatchMeta:
@@ -1494,7 +1492,7 @@ def clear_partition(self, partition_id: str, clear_consumption: bool = True):
14941492
self.partitions.pop(partition_id)
14951493
self.sampler.clear_cache(partition_id)
14961494

1497-
def reset_consumption(self, partition_id: str, task_name: Optional[str] = None):
1495+
def reset_consumption(self, partition_id: str, task_name: str | None = None):
14981496
"""
14991497
Reset consumption status for a partition without clearing the actual data.
15001498
@@ -1641,7 +1639,7 @@ def kv_retrieve_keys(
16411639
self,
16421640
global_indexes: list[int],
16431641
partition_id: str,
1644-
) -> list[Optional[str]]:
1642+
) -> list[str | None]:
16451643
"""
16461644
Retrieve keys from the controller using a list of global_indexes.
16471645

transfer_queue/dataloader/streaming_dataloader.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from typing import Optional
1716

1817
import torch
1918
from tensordict import TensorDict
@@ -80,7 +79,7 @@ def __init__(
8079
pin_memory: bool = False,
8180
worker_init_fn=None,
8281
multiprocessing_context=None,
83-
prefetch_factor: Optional[int] = None,
82+
prefetch_factor: int | None = None,
8483
persistent_workers: bool = False,
8584
pin_memory_device: str = "",
8685
):

0 commit comments

Comments
 (0)