1616import logging
1717import operator
1818import os
19+ import warnings
1920from functools import reduce
2021
2122import torch
2526logger .setLevel (os .getenv ("TQ_LOGGING_LEVEL" , logging .WARNING ))
2627
2728
28- def allocate_empty_tensors (dtypes : list [torch .dtype ], shapes : list [tuple ]) -> tuple [list [Tensor ], list [int ]]:
29+ def allocate_empty_tensors (
30+ dtypes : list [torch .dtype ], shapes : list [tuple ]
31+ ) -> tuple [list [Tensor ], list [int ], list [int ], list [int ]]:
2932 """Allocate empty tensors, grouping same dtypes into shared memory blocks.
3033
3134 Instead of allocating each tensor separately, this function groups tensors
@@ -40,17 +43,19 @@ def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tu
4043 A tuple containing:
4144 - List of tensors sharing memory within their dtype groups.
4245 - List of memory pointers (data_ptr) for each tensor.
46+ - List of base pointers for each allocated memory region (one per dtype).
47+ - List of total bytes for each allocated memory region (one per dtype).
4348
4449 Example:
4550 >>> dtypes = [torch.float32, torch.float32, torch.int32, torch.float32]
4651 >>> shapes = [(10,), (20,), (5,), (15,)]
47- >>> tensors, ptrs = allocate_empty_tensors(dtypes, shapes)
52+ >>> tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes)
4853 >>> # tensors[0], [1], [3] share the same dtype and memory block
4954 """
5055 assert len (dtypes ) == len (shapes ), "dtypes and shapes must have the same length"
5156
5257 if len (dtypes ) == 0 :
53- return [], []
58+ return [], [], [], []
5459
5560 # Group indices by dtype
5661 dtype_groups : dict [torch .dtype , list [int ]] = {}
@@ -61,6 +66,8 @@ def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tu
6166
6267 tensor_list = [torch .empty (()) for _ in range (len (dtypes ))]
6368 ptr_list = [0 ] * len (dtypes )
69+ region_ptrs : list [int ] = []
70+ region_sizes : list [int ] = []
6471
6572 # For each dtype group, allocate one big tensor and create views
6673 for dtype , indices in dtype_groups .items ():
@@ -69,13 +76,15 @@ def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tu
6976 shape_info = [] # Store (index, shape, num_elements, offset)
7077
7178 for idx in indices :
72- shape = shapes [idx ]
73- num_elements = reduce (operator .mul , shape )
79+ shape = tuple ( shapes [idx ])
80+ num_elements = reduce (operator .mul , shape , 1 )
7481 shape_info .append ((idx , shape , num_elements , total_elements ))
7582 total_elements += num_elements
7683
7784 # Allocate one big contiguous memory block for this dtype
7885 big_tensor = torch .empty (total_elements , dtype = dtype )
86+ region_ptrs .append (big_tensor .data_ptr ())
87+ region_sizes .append (big_tensor .nbytes )
7988
8089 # Create views into the big tensor for each small tensor
8190 for idx , shape , num_elements , offset in shape_info :
@@ -84,7 +93,7 @@ def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tu
8493 tensor_list [idx ] = small_tensor
8594 ptr_list [idx ] = small_tensor .data_ptr ()
8695
87- return tensor_list , ptr_list
96+ return tensor_list , ptr_list , region_ptrs , region_sizes
8897
8998
9099def compute_stride (shape : tuple [int , ...]) -> tuple [int , ...]:
@@ -115,36 +124,37 @@ def get_nbytes(dtypes, shapes) -> list[int]:
115124 nbytes = []
116125 for i in range (len (dtypes )):
117126 elem_size = torch .tensor ([], dtype = dtypes [i ]).element_size ()
118- numel = reduce (operator .mul , shapes [i ])
127+ shape = tuple (shapes [i ])
128+ numel = reduce (operator .mul , shape , 1 )
119129 nbytes .append (elem_size * numel )
120130
121131 return nbytes
122132
123133
124- def merge_continues_memory (ptrs : list [int ], sizes : list [int ]) -> tuple [list [int ], list [int ]]:
125- """Merge continuous memory regions to reduce register_buffer overhead
134+ def merge_contiguous_memory (ptrs : list [int ], sizes : list [int ]) -> tuple [list [int ], list [int ]]:
135+ """Merge contiguous memory regions to reduce register_buffer overhead
126136
127137 Args:
128138 ptrs: List of memory pointers (starting addresses).
129139 sizes: List of memory region sizes corresponding to each pointer.
130140
131141 Returns:
132- A tuple of (merged_ptrs, merged_sizes) where continuous regions
142+ A tuple of (merged_ptrs, merged_sizes) where contiguous regions
133143 have been merged into single regions.
134144
135145 Example:
136- >>> merge_continues_memory ([0, 10, 30], [10, 20, 10])
146+ >>> merge_contiguous_memory ([0, 10, 30], [10, 20, 10])
137147 ([0, 30], [30, 10])
138148
139- >>> merge_continues_memory ([0, 5, 20], [5, 5, 10])
149+ >>> merge_contiguous_memory ([0, 5, 20], [5, 5, 10])
140150 ([0, 20], [10, 10])
141151 """
142- if not ptrs or not sizes :
143- return [], []
144-
145152 if len (ptrs ) != len (sizes ):
146153 raise ValueError ("ptrs and sizes must have the same length" )
147154
155+ if not ptrs :
156+ return [], []
157+
148158 # Create list of (ptr, size) pairs and sort by pointer address
149159 regions = sorted (zip (ptrs , sizes , strict = False ), key = lambda x : x [0 ])
150160
@@ -171,3 +181,13 @@ def merge_continues_memory(ptrs: list[int], sizes: list[int]) -> tuple[list[int]
171181 merged_sizes .append (current_size )
172182
173183 return merged_ptrs , merged_sizes
184+
185+
186+ def merge_continues_memory (ptrs : list [int ], sizes : list [int ]) -> tuple [list [int ], list [int ]]:
187+ """Deprecated alias for :func:`merge_contiguous_memory`."""
188+ warnings .warn (
189+ "merge_continues_memory is deprecated, use merge_contiguous_memory instead" ,
190+ DeprecationWarning ,
191+ stacklevel = 2 ,
192+ )
193+ return merge_contiguous_memory (ptrs , sizes )
0 commit comments