Skip to content

Commit 2cc678f

Browse files
[gsaonDevice] bugfix & optimization
1 parent b1e95c6 commit 2cc678f

3 files changed

Lines changed: 356 additions & 69 deletions

File tree

examples/offline_inference_kvcomphbm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
7777
},
7878
}
7979
],
80-
"ucm_sparse_config": {"KvCompOnDevice": {}},
80+
"ucm_sparse_config": {"GSAOnDevice": {}},
8181
},
8282
)
8383

ucm/sparse/base.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,77 @@ class UcmSparseMetadata(ABC): # noqa: B024
5353
pass
5454

5555

56+
class UcmSparseCpuGpuBuffer:
    """Buffer to easily copy tensors between CPU and GPU. Inferred by vLLM.

    Holds a pinned CPU tensor, a device-side mirror of the same shape, and
    (optionally) a NumPy view that shares storage with the CPU tensor. The
    first dimension acts as an append buffer: ``self.n`` tracks how many
    rows are currently valid.

    Args:
        *size: Tensor shape (ints or ``torch.SymInt``), as for ``torch.zeros``.
        dtype: Element dtype for both CPU and GPU tensors.
        device: Target device for the GPU-side tensor.
        pin_memory: Allocate the CPU tensor in pinned (page-locked) memory
            so the non-blocking copies below can be truly asynchronous.
        with_numpy: Also expose the CPU tensor as a NumPy array view.

    Raises:
        ValueError: If ``with_numpy=True`` with ``dtype=torch.bfloat16``
            (bfloat16 has no NumPy equivalent).
    """

    def __init__(
        self,
        *size: Union[int, torch.SymInt],
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool = True,
        with_numpy: bool = True,
    ) -> None:
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=pin_memory)
        self.gpu = self.cpu.to(device)
        # Initialize to None so append_numpy's guard fails with the intended
        # AssertionError (not AttributeError) when with_numpy=False.
        self.np: Optional[np.ndarray] = None
        # Number of valid rows currently stored along dim 0.
        self.n = 0

        if with_numpy:
            if dtype == torch.bfloat16:
                raise ValueError(
                    "Bfloat16 torch tensors cannot be directly cast to a "
                    "numpy array, so call UcmSparseCpuGpuBuffer with with_numpy=False"
                )
            # Zero-copy view: writes through self.np land in self.cpu.
            self.np = self.cpu.numpy()

    def copy_to_gpu(self, n: Optional[int] = None) -> None:
        """Copy the first ``n`` rows (default: ``self.n``) from CPU to GPU.

        Non-blocking; the caller is responsible for any needed stream
        synchronization before reading the GPU tensor.
        """
        # TODO: replace with esa_copy
        if n is None:
            n = self.n
        if n <= 0:
            return
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)

    def copy_to_cpu(self, n: Optional[int] = None) -> None:
        """Copy the first ``n`` rows (default: ``self.n``) from GPU to CPU.

        NOTE: Because this method is non-blocking, explicit synchronization
        is needed to ensure the data is copied to CPU.
        """
        # TODO: replace with esa_copy
        if n is None:
            n = self.n
        if n <= 0:
            return
        self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)

    def append_numpy(self, data: List[Any]) -> None:
        """Append ``data`` after the current valid region of the CPU buffer.

        Requires the buffer to have been built with ``with_numpy=True``.
        """
        size = len(data)
        assert (
            self.np is not None
        ), "append_numpy need to be initialized by with_numpy=True."
        # <= (not <): appending may fill the buffer exactly to capacity.
        assert self.n + size <= self.cpu.shape[0], "append_numpy data out of range."
        self.np[self.n : self.n + size] = data
        self.n += size

    def clear(self) -> None:
        """Reset the valid-row counter; the underlying storage is kept."""
        self.n = 0

    @property
    def size(self) -> int:
        """Number of valid rows currently stored."""
        return self.n

    @property
    def valid_np(self) -> np.ndarray:
        """NumPy view of the valid rows (shares storage with ``cpu``)."""
        return self.np[: self.n]

    @property
    def valid_cpu(self) -> torch.Tensor:
        """CPU tensor view of the valid rows."""
        return self.cpu[: self.n]

    @property
    def valid_gpu(self) -> torch.Tensor:
        """GPU tensor view of the valid rows."""
        return self.gpu[: self.n]
56127
class UcmSparseBase(ABC):
57128
"""
58129
An general interface for impl sparse attention algorithm in vLLM

0 commit comments

Comments
 (0)