|
16 | 16 | import asyncio |
17 | 17 | import os |
18 | 18 | import threading |
19 | | -from typing import Any, Callable, Optional |
| 19 | +from typing import Any, Callable |
20 | 20 |
|
21 | 21 | import torch |
22 | 22 | import zmq |
@@ -104,9 +104,9 @@ async def async_get_meta( |
104 | 104 | batch_size: int, |
105 | 105 | partition_id: str, |
106 | 106 | mode: str = "fetch", |
107 | | - task_name: Optional[str] = None, |
108 | | - sampling_config: Optional[dict[str, Any]] = None, |
109 | | - socket: Optional[zmq.asyncio.Socket] = None, |
| 107 | + task_name: str | None = None, |
| 108 | + sampling_config: dict[str, Any] | None = None, |
| 109 | + socket: zmq.asyncio.Socket | None = None, |
110 | 110 | ) -> BatchMeta: |
111 | 111 | """Asynchronously fetch data metadata from the controller via ZMQ. |
112 | 112 |
|
@@ -191,7 +191,7 @@ async def async_get_meta( |
191 | 191 | async def async_set_custom_meta( |
192 | 192 | self, |
193 | 193 | metadata: BatchMeta, |
194 | | - socket: Optional[zmq.asyncio.Socket] = None, |
| 194 | + socket: zmq.asyncio.Socket | None = None, |
195 | 195 | ) -> None: |
196 | 196 | """ |
197 | 197 | Asynchronously send custom metadata to the controller. |
@@ -264,9 +264,9 @@ async def async_set_custom_meta( |
264 | 264 | async def async_put( |
265 | 265 | self, |
266 | 266 | data: TensorDict, |
267 | | - metadata: Optional[BatchMeta] = None, |
268 | | - partition_id: Optional[str] = None, |
269 | | - data_parser: Optional[Callable[[Any], Any]] = None, |
| 267 | + metadata: BatchMeta | None = None, |
| 268 | + partition_id: str | None = None, |
| 269 | + data_parser: Callable[[Any], Any] | None = None, |
270 | 270 | ) -> BatchMeta: |
271 | 271 | """Asynchronously write data to storage units based on metadata. |
272 | 272 |
|
@@ -575,8 +575,8 @@ async def async_get_consumption_status( |
575 | 575 | self, |
576 | 576 | task_name: str, |
577 | 577 | partition_id: str, |
578 | | - socket: Optional[zmq.asyncio.Socket] = None, |
579 | | - ) -> tuple[Optional[Tensor], Optional[Tensor]]: |
| 578 | + socket: zmq.asyncio.Socket | None = None, |
| 579 | + ) -> tuple[Tensor | None, Tensor | None]: |
580 | 580 | """Get consumption status for current partition in a specific task. |
581 | 581 |
|
582 | 582 | Args: |
@@ -638,8 +638,8 @@ async def async_get_production_status( |
638 | 638 | self, |
639 | 639 | data_fields: list[str], |
640 | 640 | partition_id: str, |
641 | | - socket: Optional[zmq.asyncio.Socket] = None, |
642 | | - ) -> tuple[Optional[Tensor], Optional[Tensor]]: |
| 641 | + socket: zmq.asyncio.Socket | None = None, |
| 642 | + ) -> tuple[Tensor | None, Tensor | None]: |
643 | 643 | """Get production status for specific data fields and partition. |
644 | 644 |
|
645 | 645 | Args: |
@@ -769,8 +769,8 @@ async def async_check_production_status( |
769 | 769 | async def async_reset_consumption( |
770 | 770 | self, |
771 | 771 | partition_id: str, |
772 | | - task_name: Optional[str] = None, |
773 | | - socket: Optional[zmq.asyncio.Socket] = None, |
| 772 | + task_name: str | None = None, |
| 773 | + socket: zmq.asyncio.Socket | None = None, |
774 | 774 | ) -> bool: |
775 | 775 | """Asynchronously reset consumption status for a partition. |
776 | 776 |
|
@@ -830,7 +830,7 @@ async def async_reset_consumption( |
830 | 830 | @with_controller_socket |
831 | 831 | async def async_get_partition_list( |
832 | 832 | self, |
833 | | - socket: Optional[zmq.asyncio.Socket] = None, |
| 833 | + socket: zmq.asyncio.Socket | None = None, |
834 | 834 | ) -> list[str]: |
835 | 835 | """Asynchronously fetch the list of partition ids from the controller. |
836 | 836 |
|
@@ -879,7 +879,7 @@ async def async_kv_retrieve_meta( |
879 | 879 | keys: list[str] | str, |
880 | 880 | partition_id: str, |
881 | 881 | create: bool = False, |
882 | | - socket: Optional[zmq.asyncio.Socket] = None, |
| 882 | + socket: zmq.asyncio.Socket | None = None, |
883 | 883 | ) -> BatchMeta: |
884 | 884 | """Asynchronously retrieve BatchMeta from the controller using user-specified keys. |
885 | 885 |
|
@@ -944,7 +944,7 @@ async def async_kv_retrieve_keys( |
944 | 944 | self, |
945 | 945 | global_indexes: list[int] | int, |
946 | 946 | partition_id: str, |
947 | | - socket: Optional[zmq.asyncio.Socket] = None, |
| 947 | + socket: zmq.asyncio.Socket | None = None, |
948 | 948 | ) -> list[str]: |
949 | 949 | """Asynchronously retrieve keys according to global_indexes from the controller. |
950 | 950 |
|
@@ -1005,8 +1005,8 @@ async def async_kv_retrieve_keys( |
1005 | 1005 | @with_controller_socket |
1006 | 1006 | async def async_kv_list( |
1007 | 1007 | self, |
1008 | | - partition_id: Optional[str] = None, |
1009 | | - socket: Optional[zmq.asyncio.Socket] = None, |
| 1008 | + partition_id: str | None = None, |
| 1009 | + socket: zmq.asyncio.Socket | None = None, |
1010 | 1010 | ) -> dict[str, dict[str, Any]]: |
1011 | 1011 | """Asynchronously retrieve keys and custom_meta from the controller for one or all partitions. |
1012 | 1012 |
|
@@ -1145,8 +1145,8 @@ def get_meta( |
1145 | 1145 | batch_size: int, |
1146 | 1146 | partition_id: str, |
1147 | 1147 | mode: str = "fetch", |
1148 | | - task_name: Optional[str] = None, |
1149 | | - sampling_config: Optional[dict[str, Any]] = None, |
| 1148 | + task_name: str | None = None, |
| 1149 | + sampling_config: dict[str, Any] | None = None, |
1150 | 1150 | ) -> BatchMeta: |
1151 | 1151 | """Synchronously fetch data metadata from the controller via ZMQ. |
1152 | 1152 |
|
@@ -1234,9 +1234,9 @@ def set_custom_meta(self, metadata: BatchMeta) -> None: |
1234 | 1234 | def put( |
1235 | 1235 | self, |
1236 | 1236 | data: TensorDict, |
1237 | | - metadata: Optional[BatchMeta] = None, |
1238 | | - partition_id: Optional[str] = None, |
1239 | | - data_parser: Optional[Callable[[Any], Any]] = None, |
| 1237 | + metadata: BatchMeta | None = None, |
| 1238 | + partition_id: str | None = None, |
| 1239 | + data_parser: Callable[[Any], Any] | None = None, |
1240 | 1240 | ) -> BatchMeta: |
1241 | 1241 | """Synchronously write data to storage units based on metadata. |
1242 | 1242 |
|
@@ -1356,7 +1356,7 @@ def get_consumption_status( |
1356 | 1356 | self, |
1357 | 1357 | task_name: str, |
1358 | 1358 | partition_id: str, |
1359 | | - ) -> tuple[Optional[Tensor], Optional[Tensor]]: |
| 1359 | + ) -> tuple[Tensor | None, Tensor | None]: |
1360 | 1360 | """Synchronously get consumption status for a specific task and partition. |
1361 | 1361 |
|
1362 | 1362 | Args: |
@@ -1384,7 +1384,7 @@ def get_production_status( |
1384 | 1384 | self, |
1385 | 1385 | data_fields: list[str], |
1386 | 1386 | partition_id: str, |
1387 | | - ) -> tuple[Optional[Tensor], Optional[Tensor]]: |
| 1387 | + ) -> tuple[Tensor | None, Tensor | None]: |
1388 | 1388 | """Synchronously get production status for specific data fields and partition. |
1389 | 1389 |
|
1390 | 1390 | Args: |
@@ -1454,7 +1454,7 @@ def check_production_status(self, data_fields: list[str], partition_id: str) -> |
1454 | 1454 | """ |
1455 | 1455 | return self._check_production_status(data_fields=data_fields, partition_id=partition_id) |
1456 | 1456 |
|
1457 | | - def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool: |
| 1457 | + def reset_consumption(self, partition_id: str, task_name: str | None = None) -> bool: |
1458 | 1458 | """Synchronously reset consumption status for a partition. |
1459 | 1459 |
|
1460 | 1460 | This allows the same data to be re-consumed, useful for debugging scenarios |
@@ -1540,7 +1540,7 @@ def kv_retrieve_keys( |
1540 | 1540 |
|
1541 | 1541 | def kv_list( |
1542 | 1542 | self, |
1543 | | - partition_id: Optional[str] = None, |
| 1543 | + partition_id: str | None = None, |
1544 | 1544 | ) -> dict[str, dict[str, Any]]: |
1545 | 1545 | """Synchronously retrieve keys and custom_meta from the controller for one or all partitions. |
1546 | 1546 |
|
|
0 commit comments