|
| 1 | +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. |
| 2 | +# Copyright 2025 The TransferQueue Team |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +import sys |
| 17 | +from unittest import mock |
| 18 | + |
| 19 | +import pytest |
| 20 | +import torch |
| 21 | + |
| 22 | +try: |
| 23 | + import torch_npu # noqa: F401 |
| 24 | +except ImportError: |
| 25 | + pass |
| 26 | + |
| 27 | + |
| 28 | +# --- Mock Backend Implementation --- |
| 29 | +# In real scenarios, multiple DsTensorClients or KVClients share storage. |
| 30 | +# Here, each mockClient is implemented with independent storage using a simple dictionary, |
| 31 | +# and is only suitable for unit testing. |
| 32 | + |
| 33 | + |
| 34 | +class MockDsTensorClient: |
| 35 | + def __init__(self, host, port, device_id): |
| 36 | + self.storage = {} |
| 37 | + |
| 38 | + def init(self): |
| 39 | + pass |
| 40 | + |
| 41 | + def dev_mset(self, keys, values): |
| 42 | + for k, v in zip(keys, values, strict=True): |
| 43 | + assert v.device.type == "npu" |
| 44 | + self.storage[k] = v |
| 45 | + |
| 46 | + def dev_mget(self, keys, out_tensors): |
| 47 | + for i, k in enumerate(keys): |
| 48 | + # Note: If key is missing, tensor remains unchanged (mock limitation) |
| 49 | + if k in self.storage: |
| 50 | + out_tensors[i].copy_(self.storage[k]) |
| 51 | + |
| 52 | + def dev_delete(self, keys): |
| 53 | + for k in keys: |
| 54 | + self.storage.pop(k, None) |
| 55 | + |
| 56 | + |
| 57 | +class MockKVClient: |
| 58 | + def __init__(self, host, port): |
| 59 | + self.storage = {} |
| 60 | + |
| 61 | + def init(self): |
| 62 | + pass |
| 63 | + |
| 64 | + def mcreate(self, keys, sizes): |
| 65 | + class MockBuffer: |
| 66 | + def __init__(self, size): |
| 67 | + self._data = bytearray(size) |
| 68 | + |
| 69 | + def MutableData(self): |
| 70 | + return memoryview(self._data) |
| 71 | + |
| 72 | + self._current_keys = keys |
| 73 | + return [MockBuffer(s) for s in sizes] |
| 74 | + |
| 75 | + def mset_buffer(self, buffers): |
| 76 | + for key, buf in zip(self._current_keys, buffers, strict=True): |
| 77 | + self.storage[key] = bytes(buf.MutableData()) |
| 78 | + |
| 79 | + def get_buffers(self, keys): |
| 80 | + return [memoryview(self.storage[k]) if k in self.storage else None for k in keys] |
| 81 | + |
| 82 | + def delete(self, keys): |
| 83 | + for k in keys: |
| 84 | + self.storage.pop(k, None) |
| 85 | + |
| 86 | + |
| 87 | +# --- Fixtures --- |
| 88 | + |
| 89 | + |
| 90 | +@pytest.fixture |
| 91 | +def mock_yr_datasystem(): |
| 92 | + """Wipe real 'yr' modules and inject mocks.""" |
| 93 | + |
| 94 | + # 1. Clean up sys.modules to force a fresh import under mock conditions |
| 95 | + # This ensures top-level code in yuanrong_client.py is re-executed |
| 96 | + to_delete = [k for k in sys.modules if k.startswith("yr")] |
| 97 | + for mod in to_delete: |
| 98 | + del sys.modules[mod] |
| 99 | + |
| 100 | + # 2. Setup Mock Objects |
| 101 | + ds_mock = mock.MagicMock() |
| 102 | + ds_mock.DsTensorClient = MockDsTensorClient |
| 103 | + ds_mock.KVClient = MockKVClient |
| 104 | + |
| 105 | + yr_mock = mock.MagicMock(datasystem=ds_mock) |
| 106 | + |
| 107 | + # 3. Apply patches |
| 108 | + # - sys.modules: Redirects 'import yr' to our mocks |
| 109 | + # - YUANRONG_DATASYSTEM_IMPORTED: Forces the existence check to True so initialize the client successfully |
| 110 | + # - datasystem: Direct attribute patch for the module |
| 111 | + with ( |
| 112 | + mock.patch.dict("sys.modules", {"yr": yr_mock, "yr.datasystem": ds_mock}), |
| 113 | + mock.patch("transfer_queue.storage.clients.yuanrong_client.YUANRONG_DATASYSTEM_IMPORTED", True, create=True), |
| 114 | + mock.patch("transfer_queue.storage.clients.yuanrong_client.datasystem", ds_mock), |
| 115 | + ): |
| 116 | + yield |
| 117 | + |
| 118 | + |
| 119 | +@pytest.fixture |
| 120 | +def config(): |
| 121 | + return {"host": "127.0.0.1", "port": 12345, "enable_yr_npu_optimization": True} |
| 122 | + |
| 123 | + |
| 124 | +def assert_tensors_equal(a: torch.Tensor, b: torch.Tensor): |
| 125 | + assert a.shape == b.shape and a.dtype == b.dtype |
| 126 | + # Move to CPU for cross-device comparison |
| 127 | + assert torch.equal(a.cpu(), b.cpu()) |
| 128 | + |
| 129 | + |
| 130 | +# --- Test Suite --- |
| 131 | + |
| 132 | + |
| 133 | +class TestYuanrongStorageE2E: |
| 134 | + @pytest.fixture(autouse=True) |
| 135 | + def setup_client(self, mock_yr_datasystem, config): |
| 136 | + # Lazy import to ensure mocks are active |
| 137 | + from transfer_queue.storage.clients.yuanrong_client import YuanrongStorageClient |
| 138 | + |
| 139 | + self.client_cls = YuanrongStorageClient |
| 140 | + self.config = config |
| 141 | + |
| 142 | + def _create_data(self, mode="cpu"): |
| 143 | + if mode == "cpu": |
| 144 | + keys = ["t", "s", "i"] |
| 145 | + vals = [torch.randn(2), "hi", 1] |
| 146 | + elif mode == "npu": |
| 147 | + if not (hasattr(torch, "npu") and torch.npu.is_available()): |
| 148 | + pytest.skip("NPU required") |
| 149 | + keys = ["n1", "n2"] |
| 150 | + vals = [torch.randn(2).npu(), torch.tensor([1]).npu()] |
| 151 | + else: # mixed |
| 152 | + if not (hasattr(torch, "npu") and torch.npu.is_available()): |
| 153 | + pytest.skip("NPU required") |
| 154 | + keys = ["n1", "c1"] |
| 155 | + vals = [torch.randn(2).npu(), "cpu"] |
| 156 | + |
| 157 | + shapes = [list(v.shape) if isinstance(v, torch.Tensor) else [] for v in vals] |
| 158 | + dtypes = [v.dtype if isinstance(v, torch.Tensor) else None for v in vals] |
| 159 | + return keys, vals, shapes, dtypes |
| 160 | + |
| 161 | + def test_mock_can_work(self, config): |
| 162 | + mock_class = (MockDsTensorClient, MockKVClient) |
| 163 | + client = self.client_cls(config) |
| 164 | + for strategy in client._strategies: |
| 165 | + assert isinstance(strategy._ds_client, mock_class) |
| 166 | + |
| 167 | + def test_cpu_only_flow(self, config): |
| 168 | + client = self.client_cls(config) |
| 169 | + keys, vals, shp, dt = self._create_data("cpu") |
| 170 | + |
| 171 | + # Put & Verify Meta |
| 172 | + meta = client.put(keys, vals) |
| 173 | + # "2" is a tag added by YuanrongStorageClient, indicating that it is processed via General KV path. |
| 174 | + assert all(m == "2" for m in meta) |
| 175 | + |
| 176 | + # Get & Verify Values |
| 177 | + ret = client.get(keys, shp, dt, meta) |
| 178 | + for o, r in zip(vals, ret, strict=True): |
| 179 | + if isinstance(o, torch.Tensor): |
| 180 | + assert_tensors_equal(o, r) |
| 181 | + else: |
| 182 | + assert o == r |
| 183 | + |
| 184 | + # Clear & Verify |
| 185 | + client.clear(keys, meta) |
| 186 | + assert all(v is None for v in client.get(keys, shp, dt, meta)) |
| 187 | + |
| 188 | + def test_npu_only_flow(self, config): |
| 189 | + keys, vals, shp, dt = self._create_data("npu") |
| 190 | + client = self.client_cls(config) |
| 191 | + |
| 192 | + meta = client.put(keys, vals) |
| 193 | + # "1" is a tag added by YuanrongStorageClient, indicating that it is processed via NPU path. |
| 194 | + assert all(m == "1" for m in meta) |
| 195 | + |
| 196 | + ret = client.get(keys, shp, dt, meta) |
| 197 | + for o, r in zip(vals, ret, strict=True): |
| 198 | + assert_tensors_equal(o, r) |
| 199 | + |
| 200 | + client.clear(keys, meta) |
| 201 | + |
| 202 | + def test_mixed_flow(self, config): |
| 203 | + keys, vals, shp, dt = self._create_data("mixed") |
| 204 | + client = self.client_cls(config) |
| 205 | + |
| 206 | + meta = client.put(keys, vals) |
| 207 | + assert set(meta) == {"1", "2"} |
| 208 | + |
| 209 | + ret = client.get(keys, shp, dt, meta) |
| 210 | + for o, r in zip(vals, ret, strict=True): |
| 211 | + if isinstance(o, torch.Tensor): |
| 212 | + assert_tensors_equal(o, r) |
| 213 | + else: |
| 214 | + assert o == r |
0 commit comments