3232import torch
3333
3434from ucm .store .nfsstore .nfsstore_connector import UcmNfsStore
35+ from ucm .store .pcstore .pcstore_connector_v1 import UcmPcStoreV1
3536from ucm .store .ucmstore import UcmKVStoreBase
37+ from ucm .store .ucmstore_v1 import UcmKVStoreBaseV1
3638
3739
3840def setup (
@@ -42,7 +44,7 @@ def setup(
4244 io_size ,
4345 transferStreamNumber ,
4446 transferIoDirect ,
45- ) -> UcmKVStoreBase :
47+ ) -> UcmKVStoreBaseV1 :
4648 config = {
4749 "storage_backends" : storage_backends ,
4850 "kv_block_size" : block_size ,
@@ -51,8 +53,9 @@ def setup(
5153 "io_size" : io_size ,
5254 "transferStreamNumber" : transferStreamNumber ,
5355 "transferIoDirect" : transferIoDirect ,
56+ "unique_id" : secrets .token_hex (8 ),
5457 }
55- return UcmNfsStore (config )
58+ return UcmPcStoreV1 (config )
5659
5760
5861def make_aligned_tensor (shape , dtype , device , alignment = 4096 ):
@@ -79,64 +82,59 @@ def make_aligned_tensor(shape, dtype, device, alignment=4096):
def make_buffers(
    block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head, kv
):
    """Allocate one aligned KV-cache tensor per layer plus random block IDs.

    Returns ``(hashes, kvcaches)``: ``hashes`` is a list of ``block_number``
    random 16-byte block identifiers; ``kvcaches`` maps
    ``layer_id -> bfloat16 tensor`` of shape
    ``[kv, block_number, block_len, num_head, head_dim]`` on
    ``cuda:<device_id>``, filled in place with random values.

    NOTE(review): ``batch_size`` is accepted but unused in this body — kept
    only for signature compatibility with callers; confirm before removing.
    """
    block_ids = [secrets.token_bytes(16) for _ in range(block_number)]
    caches = {}
    for layer in range(block_layer):
        tensor = make_aligned_tensor(
            [kv, block_number, block_len, num_head, head_dim],
            dtype=torch.bfloat16,
            device=f"cuda:{device_id}",
        )
        tensor.random_()
        caches[layer] = tensor
    return block_ids, caches
9195
9296
def store_all_hashes(hashes: List[bytes]):
    """Persist block hashes next to this script, hex-encoded, one per line.

    Overwrites ``kvcache_block_hashes.txt`` in this file's directory; the
    companion loader decodes each line back with ``bytes.fromhex``.
    """
    path = os.path.join(os.path.dirname(__file__), "kvcache_block_hashes.txt")
    lines = [h.hex() + "\n" for h in hashes]
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(lines)
98102
99103
def load_hashes_from_file() -> List[bytes]:
    """Load previously persisted block hashes from this file's directory.

    Reads ``kvcache_block_hashes.txt`` (one hex string per line, as written
    by ``store_all_hashes``) and decodes each line back to ``bytes``.

    Returns:
        The decoded hash list, or ``[]`` when the file does not exist.
    """
    file_path = os.path.join(os.path.dirname(__file__), "kvcache_block_hashes.txt")
    # EAFP: opening directly avoids the exists()/open() race and one stat call;
    # iterating the handle streams lines instead of materializing readlines().
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return [bytes.fromhex(line.strip()) for line in f]
    except FileNotFoundError:
        return []
106110
107111
108112def embed (
109- store : UcmKVStoreBase ,
110- hashes : List [str ],
113+ store : UcmKVStoreBaseV1 ,
114+ hashes : List [bytes ],
111115 kvcaches : Dict [int , torch .Tensor ],
112116 mla : bool ,
113117):
114118 start_time = time .perf_counter ()
115119
116- total_block_ids , total_offsets , total_tensors = [], [], []
120+ total_tensors = []
117121 total_size = 0
118122
119123 for i , hash_val in enumerate (hashes ):
120- offset = 0
124+ tensors = []
121125 for layer_id , kv_layer in kvcaches .items ():
122- k_tensor = kv_layer [0 ][i ] # kv=1
123- total_tensors .append (k_tensor )
124- total_block_ids .append (hash_val )
125- total_offsets .append (offset )
126+ k_tensor = kv_layer [0 ][i ].contiguous ()
127+ tensors .append (k_tensor )
126128 sz = k_tensor .numel () * k_tensor .element_size ()
127- offset += sz
128129 total_size += sz
129130
130131 if not mla :
131- v_tensor = kv_layer [1 ][i ]
132- total_tensors .append (v_tensor )
133- total_block_ids .append (hash_val )
134- total_offsets .append (offset )
132+ v_tensor = kv_layer [1 ][i ].contiguous ()
133+ tensors .append (v_tensor )
135134 sz = v_tensor .numel () * v_tensor .element_size ()
136- offset += sz
137135 total_size += sz
138-
139- task = store .dump (total_block_ids , total_offsets , total_tensors )
136+ total_tensors . append ( tensors )
137+ task = store .dump (hashes , [] , total_tensors )
140138 store .wait (task )
141139
142140 elapsed_time = time .perf_counter () - start_time
@@ -151,8 +149,8 @@ def embed(
151149
152150
153151def fetch (
154- store : UcmKVStoreBase ,
155- hashes : List [str ],
152+ store : UcmKVStoreBaseV1 ,
153+ hashes : List [bytes ],
156154 kvcaches : Dict [int , torch .Tensor ],
157155 mla : bool ,
158156):
@@ -162,32 +160,33 @@ def fetch(
162160 for f in founds :
163161 assert f , "Cache block miss detected"
164162
165- block_ids , offsets , tensors = [], [], []
163+ totoal_tensors = []
166164 total_size = 0
167165
168166 for i , hash_val in enumerate (hashes ):
169- offset = 0
167+ tensors = []
170168 for layer_id , kv_layer in kvcaches .items ():
171- k_tensor = kv_layer [0 ][i ] # kv=1
172- block_ids .append (hash_val )
173- offsets .append (offset )
169+ k_tensor = kv_layer [0 ][i ].contiguous ()
174170 tensors .append (k_tensor )
175171 sz = k_tensor .numel () * k_tensor .element_size ()
176- offset += sz
177172 total_size += sz
178173
179174 if not mla :
180- v_tensor = kv_layer [1 ][i ]
181- block_ids .append (hash_val )
182- offsets .append (offset )
175+ v_tensor = kv_layer [1 ][i ].contiguous ()
183176 tensors .append (v_tensor )
184177 sz = v_tensor .numel () * v_tensor .element_size ()
185- offset += sz
186178 total_size += sz
179+ totoal_tensors .append (tensors )
187180
188- task = store .load (block_ids , offsets , tensors )
189- ret = store .wait (task )
190- assert ret == 0 , "Load operation failed"
181+ task = store .load (hashes , [], totoal_tensors )
182+ try :
183+ ret = store .wait (task )
184+ if ret is None :
185+ ret = 0
186+ except RuntimeError as e :
187+ print (f"Load operation failed with error: { e } " )
188+ raise
189+ assert ret == 0 , f"Load operation failed with return code: { ret } "
191190
192191 elapsed_time = time .perf_counter () - start_time
193192 throughput_gbps = (total_size / (1024 ** 3 )) / elapsed_time if elapsed_time > 0 else 0
@@ -226,6 +225,10 @@ def run(
226225 block_dim = head_size * num_head
227226 io_size = block_dim * block_len * block_elem_size
228227 block_size = io_size * block_layer
228+
229+ if not mla :
230+ block_size = block_size * 2
231+
229232 batch_size = int (num_tokens / block_len )
230233 real_blocks = batch_size + 10
231234
@@ -257,16 +260,13 @@ def run(
257260 kv ,
258261 )
259262
260- results = store .create (hashes [:batch_size ])
261- assert sum (results ) == 0 , "Create operation failed"
262-
263263 w_size , w_time , w_bw = embed (
264264 store ,
265265 hashes [:batch_size ],
266266 kvcaches ,
267267 mla ,
268268 )
269- store . commit ( hashes [: batch_size ], True )
269+ time . sleep ( 1 )
270270
271271 if r == 0 :
272272 store_all_hashes (hashes [:batch_size ])
@@ -349,10 +349,10 @@ def run(
349349 try :
350350 result = run (
351351 storage_backends = "." ,
352- device_id = 1 ,
353- repeat = 1 ,
352+ device_id = 6 ,
353+ repeat = 2 ,
354354 num_head = 1 ,
355- block_len = 128 ,
355+ block_len = 64 ,
356356 transferStreamNumber = 32 ,
357357 num_tokens = 4096 ,
358358 block_layer = 61 ,
0 commit comments