@@ -29,7 +29,9 @@ def __init__(self):
2929 self ._block_tensor_nums = {} # offload tensors per block
3030
3131 def get_cnt (self , block_idx ):
32+ prev_block_idx = None if self ._block_idx == - 1 else self ._block_idx
3233 after_block = False
34+
3335 if block_idx > self ._block_idx :
3436 self ._block_tensor_nums [block_idx ] = 1
3537 if block_idx != 0 :
@@ -41,9 +43,9 @@ def get_cnt(self, block_idx):
4143 # one step end
4244 self ._block_idx = block_idx
4345 self ._block_tensor_nums = {block_idx : 1 }
44-
46+
4547 offload_tensor_key = f"{ self ._block_idx } _{ self ._block_tensor_nums [self ._block_idx ] - 1 } "
46- return offload_tensor_key , after_block
48+ return offload_tensor_key , after_block , prev_block_idx
4749
4850 def get_prefetch_keys (self , block_idx , tensor_idx ):
4951 prefetch_block_idx = max ((idx for idx in self ._block_tensor_nums .keys () if idx < block_idx ), default = None )
def __init__(self, check=False):
    """Initialize the offload manager's bookkeeping state.

    Args:
        check: when truthy, enables extra validation elsewhere in the class.
    """
    # key -> offloaded SwapTensor currently managed by this instance.
    self.items = {}
    # Tensors that may still be resident on the device, keyed like `items`.
    self.may_npu_tensors = {}
    self.check = check
    self.device_item = []
    # One GetCnt counter per offload group, created lazily by get_cnt().
    self.getcnt = {}
def get_cnt(self, block_idx, group="default"):
    """Advance the per-group counter and return its result for `block_idx`.

    A fresh GetCnt counter is created the first time a group is seen, so
    independent groups (e.g. separate model partitions) never share state.
    """
    counter = self.getcnt.get(group)
    if counter is None:
        counter = GetCnt()
        self.getcnt[group] = counter
    return counter.get_cnt(block_idx)
202206 def assert_exist (self , key ):
203207 if key not in self .items :
@@ -249,16 +253,17 @@ def get(self, key):
249253 self .may_npu_tensors .update ({key : self .items .pop (key )})
250254 return act
251255
def prefetch_get(self, block_idx, tensor_idx, h2d_stream, d2h_stream, group="default"):
    """Kick off host-to-device prefetches for tensors the given position will need.

    Args:
        block_idx: index of the block currently being unpacked.
        tensor_idx: index of the tensor within that block.
        h2d_stream: stream used for host-to-device copies.
        d2h_stream: stream used for device-to-host offload copies.
        group: offload group namespace; unknown groups are a no-op.
    """
    counter = self.getcnt.get(group)
    if counter is None:
        # Nothing was ever offloaded under this group.
        return
    for candidate in counter.get_prefetch_keys(block_idx, tensor_idx):
        namespaced_key = f"{group}_{candidate}"
        # A key may legitimately be absent (e.g. never offloaded); skip quietly.
        if not self.exist(namespaced_key):
            continue
        swap_tensor = self.get(namespaced_key)
        # Order the H2D copy after any in-flight D2H traffic on the offload stream.
        h2d_stream.wait_stream(d2h_stream)
        swap_tensor.prefetch_launch_h2d(h2d_stream, True)
def empty(self):
    """Return True when no offloaded tensors are currently tracked."""
    return not self.items
@@ -291,7 +296,8 @@ def __init__(
291296 h2d_stream : torch .cuda .Stream ,
292297 d2h_stream : torch .cuda .Stream ,
293298 block_idx : int ,
294- depth : int ,
299+ depth : int | None = None ,
300+ group : str = "default" ,
295301 custom_check_fn = None ,
296302 prefetch = True ,
297303 ) -> None :
@@ -302,19 +308,21 @@ def _pack_to_cpu(tensor):
302308 if (custom_check_fn is not None ) and (not custom_check_fn (tensor )):
303309 return tensor
304310
305- key , after_block = OffloadManager ().get_cnt (block_idx )
311+ key , after_block , prev_block_idx = OffloadManager ().get_cnt (block_idx , group = group )
306312
307- if after_block :
308- OffloadManager ().del_npu_tensor (f"{ block_idx - 1 } _" , d2h_stream )
313+ if after_block and ( prev_block_idx is not None ) :
314+             OffloadManager ().del_npu_tensor (f"{ group } _{ prev_block_idx } _" , d2h_stream )
309315
310316 swap_tensor = SwapTensor (tensor , key )
317+ full_key = f"{ group } _{ key } "
311318
312- if block_idx <= depth - 1 :
319+ should_offload = depth is None or block_idx <= depth - 1
320+ if should_offload :
313321 working_stream = torch .cuda .current_stream ()
314322 d2h_stream .wait_stream (working_stream )
315323 swap_tensor .launch_d2h (d2h_stream )
316324
317- OffloadManager ().put (key , swap_tensor )
325+ OffloadManager ().put (full_key , swap_tensor )
318326 return swap_tensor
319327
320328 def _unpack_from_cpu (swap_tensor ) -> torch .Tensor :
@@ -328,14 +336,14 @@ def _unpack_from_cpu(swap_tensor) -> torch.Tensor:
328336
329337 block_idx , tensor_idx = swap_tensor .key .split ("_" )
330338
331- OffloadManager ().del_may_npu_tensor (f"{ int (block_idx ) + 1 } _" , h2d_stream )
339         OffloadManager ().del_may_npu_tensor (f"{ group } _{ int (block_idx ) + 1 } _" , h2d_stream )
332340 swap_tensor .launch_h2d (h2d_stream , True , working_stream )
333341 # if block_idx in ["0", "2", "3"]:
334342 # if block_idx in ["0"]:
335343 # torch.cuda.synchronize()
336344
337345 if prefetch and block_idx != 0 :
338- OffloadManager ().prefetch_get (int (block_idx ), int (tensor_idx ), h2d_stream , d2h_stream )
346+ OffloadManager ().prefetch_get (int (block_idx ), int (tensor_idx ), h2d_stream , d2h_stream , group = group )
339347
340348 # if block_idx in ["0"] and tensor_idx == "1":
341349 # swap_tensor.load()
0 commit comments