Skip to content

Commit d178a88

Browse files
committed
fix
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent 20d3ba7 commit d178a88

2 files changed

Lines changed: 196 additions & 11 deletions

File tree

tests/test_tensor_utils.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2+
# Copyright 2025 The TransferQueue Team
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Unit tests for transfer_queue.utils.tensor_utils."""
17+
18+
import pytest
19+
import torch
20+
21+
from transfer_queue.utils.tensor_utils import (
22+
allocate_empty_tensors,
23+
compute_stride,
24+
get_nbytes,
25+
merge_contiguous_memory,
26+
)
27+
28+
29+
class TestComputeStride:
30+
"""Tests for compute_stride."""
31+
32+
def test_3d(self):
33+
assert compute_stride((2, 3, 4)) == (12, 4, 1)
34+
35+
def test_1d(self):
36+
assert compute_stride((5,)) == (1,)
37+
38+
def test_scalar(self):
39+
assert compute_stride(()) == ()
40+
41+
def test_2d(self):
42+
assert compute_stride((3, 5)) == (5, 1)
43+
44+
45+
class TestGetNbytes:
46+
"""Tests for get_nbytes."""
47+
48+
def test_basic(self):
49+
dtypes = [torch.float32, torch.int32]
50+
shapes = [(2, 3), (4,)]
51+
result = get_nbytes(dtypes, shapes)
52+
assert result == [2 * 3 * 4, 4 * 4] # float32=4, int32=4
53+
54+
def test_scalar(self):
55+
dtypes = [torch.float64]
56+
shapes = [()]
57+
result = get_nbytes(dtypes, shapes)
58+
assert result == [8] # scalar = 1 element
59+
60+
def test_list_shape(self):
61+
dtypes = [torch.float32]
62+
shapes = [[]] # list instead of tuple
63+
result = get_nbytes(dtypes, shapes)
64+
assert result == [4]
65+
66+
def test_mixed_dtypes(self):
67+
dtypes = [torch.float16, torch.float32, torch.int64]
68+
shapes = [(10,), (10,), (10,)]
69+
result = get_nbytes(dtypes, shapes)
70+
assert result == [10 * 2, 10 * 4, 10 * 8]
71+
72+
73+
class TestAllocateEmptyTensors:
74+
"""Tests for allocate_empty_tensors."""
75+
76+
def test_basic(self):
77+
dtypes = [torch.float32, torch.float32, torch.int32]
78+
shapes = [(2, 3), (4,), (5,)]
79+
tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes)
80+
81+
assert len(tensors) == 3
82+
assert len(ptrs) == 3
83+
assert len(region_ptrs) == 2 # float32 group + int32 group
84+
assert len(region_sizes) == 2
85+
86+
# Same dtype tensors share the same underlying storage
87+
assert tensors[0].untyped_storage().data_ptr() == region_ptrs[0]
88+
assert tensors[1].untyped_storage().data_ptr() == region_ptrs[0]
89+
assert tensors[2].untyped_storage().data_ptr() == region_ptrs[1]
90+
91+
# Shapes are correct
92+
assert list(tensors[0].shape) == [2, 3]
93+
assert list(tensors[1].shape) == [4]
94+
assert list(tensors[2].shape) == [5]
95+
96+
def test_scalar(self):
97+
dtypes = [torch.float32, torch.int32]
98+
shapes = [(), ()]
99+
tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes)
100+
101+
assert len(tensors) == 2
102+
assert tensors[0].numel() == 1
103+
assert tensors[1].numel() == 1
104+
assert len(region_ptrs) == 2
105+
106+
def test_empty(self):
107+
result = allocate_empty_tensors([], [])
108+
assert result == ([], [], [], [])
109+
110+
def test_regions_complex(self):
111+
"""Mixed dtypes and shapes: verify region counts, sizes, and per-tensor offsets."""
112+
dtypes = [
113+
torch.float32, # group 0: (2, 3) -> 6 elements
114+
torch.int32, # group 1: (4,) -> 4 elements
115+
torch.float32, # group 0: scalar -> 1 element
116+
torch.float64, # group 2: (2, 2) -> 4 elements
117+
torch.int32, # group 1: (3, 2) -> 6 elements
118+
]
119+
shapes = [(2, 3), (4,), (), (2, 2), (3, 2)]
120+
tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes)
121+
122+
# 3 dtype groups in insertion order: float32, int32, float64
123+
assert len(region_ptrs) == 3
124+
assert len(region_sizes) == 3
125+
assert len(set(region_ptrs)) == 3 # distinct allocations
126+
127+
# float32 region: 6 + 1 = 7 elements * 4 bytes = 28 bytes
128+
assert region_sizes[0] == 7 * 4
129+
# int32 region: 4 + 6 = 10 elements * 4 bytes = 40 bytes
130+
assert region_sizes[1] == 10 * 4
131+
# float64 region: 4 elements * 8 bytes = 32 bytes
132+
assert region_sizes[2] == 4 * 8
133+
134+
# Per-tensor ptrs must lie inside their respective regions
135+
# tensor 0 (float32, shape (2,3), offset 0)
136+
assert ptrs[0] == region_ptrs[0]
137+
# tensor 1 (int32, shape (4,), offset 0)
138+
assert ptrs[1] == region_ptrs[1]
139+
# tensor 2 (float32, scalar, offset 6)
140+
assert ptrs[2] == region_ptrs[0] + 6 * 4
141+
# tensor 3 (float64, shape (2,2), offset 0)
142+
assert ptrs[3] == region_ptrs[2]
143+
# tensor 4 (int32, shape (3,2), offset 4)
144+
assert ptrs[4] == region_ptrs[1] + 4 * 4
145+
146+
147+
class TestMergeContiguousMemory:
148+
"""Tests for merge_contiguous_memory."""
149+
150+
def test_basic_merge(self):
151+
ptrs = [0, 10, 30]
152+
sizes = [10, 20, 10]
153+
merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes)
154+
# 0+10=10 (contiguous with 10), 10+20=30 (contiguous with 30) -> all merge into [0]
155+
assert merged_ptrs == [0]
156+
assert merged_sizes == [40]
157+
158+
def test_no_contiguous(self):
159+
ptrs = [0, 100, 200]
160+
sizes = [50, 50, 50]
161+
merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes)
162+
assert merged_ptrs == [0, 100, 200]
163+
assert merged_sizes == [50, 50, 50]
164+
165+
def test_unsorted_input(self):
166+
ptrs = [100, 0, 50]
167+
sizes = [50, 50, 50]
168+
merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes)
169+
# After sorting: 0, 50, 100; all contiguous -> merge into [0]
170+
assert merged_ptrs == [0]
171+
assert merged_sizes == [150]
172+
173+
def test_single_region(self):
174+
ptrs = [10]
175+
sizes = [100]
176+
merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes)
177+
assert merged_ptrs == [10]
178+
assert merged_sizes == [100]
179+
180+
def test_empty(self):
181+
assert merge_contiguous_memory([], []) == ([], [])
182+
183+
def test_mismatched_lengths_both_empty_not_triggered(self):
184+
# If one is empty and other is not, should raise ValueError
185+
with pytest.raises(ValueError, match="ptrs and sizes must have the same length"):
186+
merge_contiguous_memory([], [10])
187+
188+
with pytest.raises(ValueError, match="ptrs and sizes must have the same length"):
189+
merge_contiguous_memory([0], [])
190+
191+
def test_three_continuous(self):
192+
ptrs = [0, 10, 20]
193+
sizes = [10, 10, 10]
194+
merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes)
195+
assert merged_ptrs == [0]
196+
assert merged_sizes == [30]

transfer_queue/utils/tensor_utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import logging
1717
import operator
1818
import os
19-
import warnings
2019
from functools import reduce
2120

2221
import torch
@@ -181,13 +180,3 @@ def merge_contiguous_memory(ptrs: list[int], sizes: list[int]) -> tuple[list[int
181180
merged_sizes.append(current_size)
182181

183182
return merged_ptrs, merged_sizes
184-
185-
186-
def merge_continues_memory(ptrs: list[int], sizes: list[int]) -> tuple[list[int], list[int]]:
187-
"""Deprecated alias for :func:`merge_contiguous_memory`."""
188-
warnings.warn(
189-
"merge_continues_memory is deprecated, use merge_contiguous_memory instead",
190-
DeprecationWarning,
191-
stacklevel=2,
192-
)
193-
return merge_contiguous_memory(ptrs, sizes)

0 commit comments

Comments
 (0)