Skip to content

Commit 6879537

Browse files
committed
Fix the hash weight of HashsEncoder
1 parent 92de9f9 commit 6879537

2 files changed

Lines changed: 9 additions & 1 deletion

File tree

docs/source/getting-started/quickstart_vllm.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ You may directly edit the example file at `unified-cache-management/examples/ucm
149149

150150
### Feature 2: Sparsity
151151

152-
The sparse module was not compiled by default. To enable it, set the environment variable `export ENABLE_SPARSE=TRUE` and re-compile the code you built. And uncomment `ucm_sparse_config` code block in `unified-cache-management/examples/ucm_config_example.yaml`. Additionally, if you want to run GSAOnDevice, you also need to set the environment variable `export VLLM_HASH_ATTENTION=1`.
152+
The sparse module was not compiled by default. To enable it, set the environment variable `export ENABLE_SPARSE=TRUE` and re-compile the code you built. And uncomment `ucm_sparse_config` code block in `unified-cache-management/examples/ucm_config_example.yaml`. Additionally, if you want to run GSAOnDevice, you also need to set the environment variable `export VLLM_HASH_ATTENTION=1`, and you can fix the hash weight of the HashEncoder by setting the environment variable `export HASHENCODERSEED=<the seed number>`.
153153

154154
## Step 3: Launching Inference
155155

ucm/sparse/gsa_on_device/hash_encoder.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
THE SOFTWARE.
2323
"""
2424

25+
import os
2526
import torch
2627

2728
if hasattr(torch, "npu") and torch.npu.is_available():
@@ -317,6 +318,13 @@ def __init__(
317318

318319
def _init_hash_weights(self):
319320
# Step 1: 随机高斯矩阵
321+
seed_str = os.getenv("HASHENCODERSEED")
322+
if seed_str is not None and seed_str.strip() != "":
323+
seed = int(seed_str)
324+
torch.manual_seed(seed)
325+
if self.device.type == "cuda":
326+
torch.cuda.manual_seed(seed)
327+
torch.cuda.manual_seed_all(seed)
320328
random_weights = torch.normal(
321329
mean=0,
322330
std=2,

0 commit comments

Comments
 (0)