Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,17 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest build pytest_asyncio pytest-mock openyuanrong-datasystem
python -m build --wheel
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install dist/*.whl
pip install -e ".[test,build,yuanrong]"
Comment thread
0oshowero0 marked this conversation as resolved.
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test Build
run: |
python -m build --wheel
- name: Test with pytest
run: |
Comment thread
0oshowero0 marked this conversation as resolved.
Outdated
pytest
7 changes: 0 additions & 7 deletions .github/workflows/sanity.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,6 @@ jobs:
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install build
python -m build --wheel
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install dist/*.whl
- name: Run license test
run: |
python3 tests/sanity/check_license.py --directories .
Expand Down
17 changes: 13 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,7 @@ filterwarnings = [

[[tool.mypy.overrides]]
module = [
"transfer_queue.data_system.*",
"transfer_queue.utils.utils.*",
"transfer_queue.utils.zmq_utils.*",
"transfer_queue.utils.serial_utils.*",
"transfer_queue.*",
]
ignore_errors = false

Expand All @@ -108,11 +105,23 @@ version = {file = "transfer_queue/version/version"}
dependencies = {file = "requirements.txt"}

[project.optional-dependencies]

Comment thread
0oshowero0 marked this conversation as resolved.
Outdated
build = [
"build"
]

test = [
"pytest>=7.0.0",
"pytest-asyncio>=0.20.0",
"flake8",
"pytest-mock",
]

yuanrong = [
"openyuanrong-datasystem"
]


Comment thread
0oshowero0 marked this conversation as resolved.
Outdated
# If you need to mimic `package_dir={'': '.'}`:
[tool.setuptools.package-dir]
"" = "."
Expand Down
12 changes: 6 additions & 6 deletions tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_controller_with_single_partition(self, ray_setup):
ProductionStatus.NOT_PRODUCED
)
partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
assert partition_index_range == set(range(gbs * num_n_samples))
assert partition_index_range == list(range(gbs * num_n_samples))

print("✓ Initial get metadata correct")

Expand Down Expand Up @@ -194,7 +194,7 @@ def test_controller_with_single_partition(self, ray_setup):
ray.get(tq_controller.clear_partition.remote(partition_id))
partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
assert partition_index_range == set()
assert partition_index_range == []
assert partition is None
print("✓ Clear partition correct")

Expand Down Expand Up @@ -307,7 +307,7 @@ def test_controller_with_multi_partitions(self, ray_setup):
[int(sample.fields.get("attention_mask").production_status) for sample in val_metadata.samples]
) == int(ProductionStatus.NOT_PRODUCED)
partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
assert partition_index_range == set(range(part1_index_range, part2_index_range + part1_index_range))
assert partition_index_range == list(range(part1_index_range, part2_index_range + part1_index_range))

# Update production status
dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in val_metadata.global_indexes}
Expand Down Expand Up @@ -359,11 +359,11 @@ def test_controller_with_multi_partitions(self, ray_setup):

assert not partition_index_range_1_after_clear
assert partition_1_after_clear is None
assert partition_index_range_1_after_clear == set()
assert partition_index_range_1_after_clear == []

partition_2 = ray.get(tq_controller.get_partition_snapshot.remote(partition_id_2))
partition_index_range_2 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
assert partition_index_range_2 == set([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])
assert partition_index_range_2 == [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]
assert torch.all(
partition_2.production_status[list(partition_index_range_2), : len(val_metadata.field_names)] == 1
)
Expand All @@ -387,7 +387,7 @@ def test_controller_with_multi_partitions(self, ray_setup):
[int(sample.fields.get("attention_mask").production_status) for sample in metadata_2.samples]
) == int(ProductionStatus.NOT_PRODUCED)
partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_3))
assert partition_index_range == set(list(range(32)) + list(range(48, 80)))
assert partition_index_range == list(range(32)) + list(range(48, 80))
print("✓ Correctly assign partition_3")

def test_controller_clear_meta(self, ray_setup):
Expand Down
14 changes: 10 additions & 4 deletions tests/test_kv_storage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@ def get_meta(data, global_indexes=None):
@pytest.fixture
def test_data():
"""Fixture providing test configuration, data, and metadata."""
cfg = {"client_name": "YuanrongStorageClient", "host": "127.0.0.1", "port": 31501, "device_id": 0}
cfg = {
"controller_info": MagicMock(),
"client_name": "YuanrongStorageClient",
"host": "127.0.0.1",
"port": 31501,
"device_id": 0,
}
global_indexes = [8, 9, 10]

data = TensorDict(
Expand Down Expand Up @@ -288,7 +294,7 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo
mock_storage_client.put.return_value = mock_custom_meta

# Create manager with mocked dependencies
config = {"client_name": "MockClient"}
config = {"controller_info": MagicMock(), "client_name": "MockClient"}
with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
manager = KVStorageManager(config)

Expand Down Expand Up @@ -338,7 +344,7 @@ def test_put_data_without_custom_meta(mock_notify, test_data_for_put_data):
mock_storage_client.put.return_value = None

# Create manager with mocked dependencies
config = {"client_name": "MockClient"}
config = {"controller_info": MagicMock(), "client_name": "MockClient"}
with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
manager = KVStorageManager(config)

Expand All @@ -361,7 +367,7 @@ def test_put_data_custom_meta_length_mismatch_raises_error(test_data_for_put_dat
mock_storage_client.put.return_value = [{"key": "1"}, {"key": "2"}, {"key": "3"}]

# Create manager with mocked dependencies
config = {"client_name": "MockClient"}
config = {"controller_info": MagicMock(), "client_name": "MockClient"}
with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
manager = KVStorageManager(config)

Expand Down
5 changes: 3 additions & 2 deletions transfer_queue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,7 @@ async def async_get_partition_list(
)

try:
assert socket is not None
await socket.send_multipart(request_msg.serialize())
response_serialized = await socket.recv_multipart()
response_msg = ZMQMessage.deserialize(response_serialized)
Expand Down Expand Up @@ -991,10 +992,10 @@ def process_zmq_server_info(
>>> info_dict = process_zmq_server_info(handlers)"""
# Handle single handler object case
if not isinstance(handlers, dict):
return ray.get(handlers.get_zmq_server_info.remote()) # type: ignore[attr-defined]
return ray.get(handlers.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined]
else:
# Handle dictionary case
server_info = {}
for name, handler in handlers.items():
server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[attr-defined]
server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined]
return server_info
25 changes: 15 additions & 10 deletions transfer_queue/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,17 +188,17 @@ def release_indexes(self, partition_id: str, indexes_to_release: list[int]):
if not partition_indexes:
self.partition_to_indexes.pop(partition_id, None)

def get_indexes_for_partition(self, partition_id) -> set[int]:
def get_indexes_for_partition(self, partition_id) -> list[int]:
"""
Get all global_indexes for the specified partition.

Args:
partition_id: Partition ID

Returns:
set: Set of global_indexes for this partition
list: List of global_indexes for this partition
"""
return self.partition_to_indexes.get(partition_id, set()).copy()
return list(self.partition_to_indexes.get(partition_id, set()).copy())


@dataclass
Expand All @@ -216,7 +216,7 @@ class DataPartitionStatus:

# Production status tensor - dynamically expandable
# Values: 0 = not produced, 1 = ready for consumption
production_status: Optional[Tensor] = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8)
production_status: Tensor = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8)
Comment thread
0oshowero0 marked this conversation as resolved.
Comment thread
0oshowero0 marked this conversation as resolved.

# Consumption status per task - task_name -> consumption_tensor
# Each tensor tracks which samples have been consumed by that task
Expand Down Expand Up @@ -260,7 +260,7 @@ def allocated_samples_num(self) -> int:

# ==================== Dynamic Expansion Methods ====================

def ensure_samples_capacity(self, required_samples: int) -> bool:
def ensure_samples_capacity(self, required_samples: int):
Comment thread
0oshowero0 marked this conversation as resolved.
Outdated
"""
Ensure the production status tensor has enough rows for the required samples.
Dynamically expands if needed using unified minimum expansion size.
Expand Down Expand Up @@ -498,7 +498,9 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te
return partition_global_index, consumption_status

# ==================== Production Status Interface ====================
def get_production_status_for_fields(self, field_names: list[str], mask: bool = False) -> tuple[Tensor, Tensor]:
def get_production_status_for_fields(
self, field_names: list[str], mask: bool = False
) -> tuple[Optional[Tensor], Optional[Tensor]]:
"""
Check if all samples for specified fields are fully produced and ready.

Expand All @@ -512,12 +514,12 @@ def get_production_status_for_fields(self, field_names: list[str], mask: bool =
- Production status tensor for the specified task. 1 for ready, 0 for not ready.
"""
if self.production_status is None or field_names is None or len(field_names) == 0:
return False
return None, None

# Check if all requested fields are registered
for field_name in field_names:
if field_name not in self.field_name_mapping:
return False
return None, None

# Create column mask for requested fields
col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool)
Expand Down Expand Up @@ -837,15 +839,15 @@ def list_partitions(self) -> list[str]:

# ==================== Partition Index Management API ====================

def get_partition_index_range(self, partition: DataPartitionStatus) -> set:
def get_partition_index_range(self, partition: DataPartitionStatus) -> list[int]:
"""
Get all indexes for a specific partition.

Args:
partition: Partition identifier

Returns:
Set of indexes allocated to the partition
List of indexes allocated to the partition
"""
return self.index_manager.get_indexes_for_partition(partition)

Expand Down Expand Up @@ -980,6 +982,9 @@ def get_metadata(
if mode == "fetch":
# Find ready samples within current data partition and package into BatchMeta when reading

if batch_size is None:
raise ValueError("must provide batch_size in fetch mode")

start_time = time.time()
while True:
# ready_for_consume_indexes: samples where all required fields are produced
Expand Down
1 change: 1 addition & 0 deletions transfer_queue/dataloader/streaming_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
if self._tq_client is None:
self._create_client()

assert self._tq_client is not None, "Failed to create TransferQueue client"
# TODO: need to consider async scenario where the samples in partition is dynamically increasing
while not self._tq_client.check_consumption_status(self.task_name, self.partition_id):
try:
Expand Down
2 changes: 1 addition & 1 deletion transfer_queue/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def get_all_custom_meta(self) -> dict[int, dict[str, Any]]:
"""Get the entire custom meta dictionary"""
return copy.deepcopy(self._custom_meta)

def update_custom_meta(self, new_custom_meta: dict[int, dict[str, Any]] = None):
def update_custom_meta(self, new_custom_meta: Optional[dict[int, dict[str, Any]]]):
"""Update custom meta with a new dictionary"""
if new_custom_meta:
self._custom_meta.update(new_custom_meta)
Expand Down
9 changes: 9 additions & 0 deletions transfer_queue/storage/clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ class TransferQueueStorageKVClient(ABC):
Subclasses must implement the core methods: put, get, and clear.
"""

@abstractmethod
def __init__(self, config: dict[str, Any]):
"""
Initialize the storage client with configuration.
Args:
config (dict[str, Any]): Configuration dictionary for the storage client.
"""
...
Comment thread
0oshowero0 marked this conversation as resolved.
Outdated

@abstractmethod
def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
"""
Expand Down
5 changes: 3 additions & 2 deletions transfer_queue/storage/clients/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from transfer_queue.storage.clients.base import TransferQueueStorageKVClient


Expand All @@ -23,7 +24,7 @@ class StorageClientFactory:
"""

# Class variable: maps client names to their corresponding classes
_registry: dict[str, TransferQueueStorageKVClient] = {}
_registry: dict[str, type[TransferQueueStorageKVClient]] = {}

@classmethod
def register(cls, client_type: str):
Expand All @@ -35,7 +36,7 @@ def register(cls, client_type: str):
Callable: The decorator function that returns the original class
"""

def decorator(client_class: TransferQueueStorageKVClient) -> TransferQueueStorageKVClient:
def decorator(client_class: type[TransferQueueStorageKVClient]) -> type[TransferQueueStorageKVClient]:
cls._registry[client_type] = client_class
return client_class

Expand Down
Loading
Loading