@@ -53,6 +53,77 @@ class UcmSparseMetadata(ABC): # noqa: B024
5353 pass
5454
5555
class UcmSparseCpuGpuBuffer:
    """Paired CPU/GPU buffer for easy tensor copies between host and device.
    Inspired by vLLM.

    Holds a (optionally pinned) CPU tensor, a same-shaped tensor on the target
    device, and — when ``with_numpy=True`` — a numpy view of the CPU tensor so
    plain Python lists can be appended cheaply. ``self.n`` tracks how many
    leading elements are currently valid.
    """

    def __init__(
        self,
        *size: Union[int, torch.SymInt],
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool = True,
        with_numpy: bool = True,
    ) -> None:
        """Allocate the CPU-side and device-side tensors.

        Args:
            *size: dimensions of the buffer tensor.
            dtype: element dtype of both tensors.
            device: target device for the device-side tensor.
            pin_memory: allocate the CPU tensor in pinned (page-locked) memory
                for faster async transfers (requires CUDA to be available).
            with_numpy: also expose the CPU tensor as a numpy array; not
                supported for bfloat16, which numpy cannot represent.

        Raises:
            ValueError: if ``with_numpy=True`` and ``dtype`` is bfloat16.
        """
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=pin_memory)
        self.gpu = self.cpu.to(device)
        # BUGFIX: initialize to None so the guard in append_numpy raises the
        # intended AssertionError (previously the attribute was only annotated,
        # never assigned, so access raised AttributeError when with_numpy=False).
        self.np: Optional[np.ndarray] = None
        # Number of valid leading elements currently stored in the buffer.
        self.n = 0

        if with_numpy:
            if dtype == torch.bfloat16:
                raise ValueError(
                    "Bfloat16 torch tensors cannot be directly cast to a "
                    "numpy array, so call UcmSparseCpuGpuBuffer with with_numpy=False"
                )
            # Zero-copy numpy view sharing memory with self.cpu.
            self.np = self.cpu.numpy()

    def copy_to_gpu(self, n: Optional[int] = None) -> None:
        """Copy the first ``n`` (default ``self.n``) elements CPU -> device."""
        # TODO: replace with esa_copy
        if n is None:
            n = self.n
        if n <= 0:
            return
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)

    def copy_to_cpu(self, n: Optional[int] = None) -> None:
        """Copy the first ``n`` (default ``self.n``) elements device -> CPU.

        NOTE: Because this method is non-blocking, explicit synchronization
        is needed to ensure the data is copied to CPU.
        """
        # TODO: replace with esa_copy
        if n is None:
            n = self.n
        if n <= 0:
            return
        self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)

    def append_numpy(self, data: List[Any]) -> None:
        """Append ``data`` after the current valid region via the numpy view."""
        size = len(data)
        assert (
            self.np is not None
        ), "append_numpy needs the buffer initialized with with_numpy=True."
        # BUGFIX: <= (not <) — the buffer may legitimately be filled to
        # exactly its capacity.
        assert self.n + size <= self.cpu.shape[0], "append_numpy data out of range."
        self.np[self.n : self.n + size] = data
        self.n += size

    def clear(self) -> None:
        """Reset the valid-element count (does not zero the underlying storage)."""
        self.n = 0

    @property
    def size(self) -> int:
        """Number of currently valid elements."""
        return self.n

    @property
    def valid_np(self) -> np.ndarray:
        """Numpy view of the valid CPU elements (requires with_numpy=True)."""
        return self.np[: self.n]

    @property
    def valid_cpu(self) -> torch.Tensor:
        """CPU tensor slice covering the valid elements."""
        return self.cpu[: self.n]

    @property
    def valid_gpu(self) -> torch.Tensor:
        """Device tensor slice covering the valid elements."""
        return self.gpu[: self.n]
125+
126+
56127class UcmSparseBase (ABC ):
57128 """
58129 An general interface for impl sparse attention algorithm in vLLM
0 commit comments