2727from typing import Dict , Optional , Any
2828import numpy as np
2929from jetstream .engine import engine_api
30+ from enum import Enum
3031
3132
3233def _get_size_of_pytree (params ):
@@ -54,12 +55,18 @@ def convert_if_np(leaf):
5455
5556 return jax .tree_util .tree_map (convert_if_np , params )
5657
58+ class AdapterStatus (str , Enum ):
59+ UNLOADED = "unloaded"
60+ LOADING = "loading"
61+ LOADED_HBM = "loaded_hbm"
62+ LOADED_CPU = "loaded_cpu"
63+
5764
5865@dataclasses .dataclass
5966class AdapterMetadata :
6067 adapter_id : str
6168 adapter_path : str
62- status : str = "unloaded" # "loaded_hbm", "loaded_cpu", "loading", "unloading"
69+ status : AdapterStatus = AdapterStatus . UNLOADED
6370 size_hbm : int = 0 # Size in HBM (bytes)
6471 size_cpu : int = 0 # Size in CPU RAM (bytes)
6572 last_accessed : float = 0.0 # timestamp
@@ -155,7 +162,7 @@ async def _transfer_to_hbm(self, adapter_id: str):
155162 async with self .lock : #Acquire lock
156163 metadata = self .adapter_registry [adapter_id ]
157164
158- if metadata .status == "loaded_hbm" :
165+ if metadata .status == AdapterStatus . LOADED_HBM :
159166 return
160167
161168 # Check if we have enough space in HBM; evict if necessary
@@ -172,7 +179,7 @@ async def _transfer_to_hbm(self, adapter_id: str):
172179 self .current_cpu_usage -= metadata .size_cpu
173180 self .current_hbm_usage += metadata .size_hbm
174181
175- metadata .status = "loaded_hbm"
182+ metadata .status = AdapterStatus . LOADED_HBM
176183 metadata .last_accessed = time .time ()
177184
178185
@@ -185,7 +192,7 @@ async def _transfer_to_cpu(self, adapter_id: str):
185192 async with self .lock :
186193 metadata = self . adapter_registry [adapter_id ]
187194
188- if metadata .status == "loaded_cpu" :
195+ if metadata .status == AdapterStatus . LOADED_CPU :
189196 return
190197
191198 # Check if we have enough space in CPU; evict if necessary.
@@ -200,7 +207,7 @@ async def _transfer_to_cpu(self, adapter_id: str):
200207 self .current_hbm_usage -= metadata .size_hbm
201208 self .current_cpu_usage += metadata .size_cpu
202209
203- metadata .status = "loaded_cpu"
210+ metadata .status = AdapterStatus . LOADED_CPU
204211 metadata .last_accessed = time .time ()
205212
206213
@@ -211,7 +218,7 @@ async def get_hbm_loaded_adapters(self):
211218
212219 async with self .lock :
213220 for adapter_id , metadata in self .adapter_registry .items ():
214- if metadata .status == "loaded_hbm" :
221+ if metadata .status == AdapterStatus . LOADED_HBM :
215222 hbm_loaded_adapters .append (adapter_id )
216223
217224 return ", " .join (hbm_loaded_adapters )
@@ -250,33 +257,33 @@ async def load_adapter(
250257 metadata = self .adapter_registry [adapter_id ]
251258
252259 async with self .lock : # Acquire lock for thread safety
253- if metadata .status in ("loaded_hbm" , "loaded_cpu" ):
260+ if metadata .status in (AdapterStatus . LOADED_HBM , AdapterStatus . LOADED_CPU ):
254261 metadata .last_accessed = time .time ()
255262
256263 # if already loaded in HBM and we want HBM, or
257264 # already loaded in CPU and we want CPU, we're done.
258- if ((to_hbm and metadata .status == "loaded_hbm" ) or
259- not to_hbm and metadata .status == "loaded_cpu" ):
265+ if ((to_hbm and metadata .status == AdapterStatus . LOADED_HBM ) or
266+ not to_hbm and metadata .status == AdapterStatus . LOADED_CPU ):
260267 return
261- elif to_hbm and metadata .status == "loaded_cpu" :
268+ elif to_hbm and metadata .status == AdapterStatus . LOADED_CPU :
262269 # Transfer from cpu to hbm
263270 self ._transfer_to_hbm (adapter_id )
264271 return
265- elif not to_hbm and metadata .status == "loaded_hbm" :
272+ elif not to_hbm and metadata .status == AdapterStatus . LOADED_HBM :
266273 # Transfer from hbm to cpu
267274 self ._transfer_to_cpu (adapter_id )
268275 return
269276
270- if metadata .status == "loading" :
277+ if metadata .status == AdapterStatus . LOADING :
271278 # Wait until loading is done.
272- while metadata .status == "loading" :
279+ while metadata .status == AdapterStatus . LOADING :
273280 await asyncio .sleep (0.1 ) # Short sleep to avoid busy-waiting
274281
275282 # Make recursive call to load_adapter to copy to device
276283 await self .load_adapter (adapter_id , adapter_weights , to_hbm )
277284 return
278285
279- metadata .status = "loading"
286+ metadata .status = AdapterStatus . LOADING
280287 self .running_requests += 1
281288
282289 # Load the adapter (asynchronous)
@@ -319,18 +326,18 @@ async def load_adapter(
319326 if to_hbm :
320327 self .loaded_adapters_hbm [adapter_id ] = adapter_weights_as_jnp_array # Convert the PyTree to Jax Array
321328 self .current_hbm_usage += adapter_size_hbm
322- metadata .status = "loaded_hbm"
329+ metadata .status = AdapterStatus . LOADED_HBM
323330
324331 else : #to cpu
325332 self .loaded_adapters_cpu [adapter_id ] = adapter_weights_as_np_array # Convert the PyTree to NumPy Array
326333 self .current_cpu_usage += adapter_size_cpu
327- metadata .status = "loaded_cpu"
334+ metadata .status = AdapterStatus . LOADED_CPU
328335
329336 metadata .last_accessed = time .time ()
330337
331338 except Exception as e :
332339 async with self .lock :
333- metadata .status = "unloaded" # Mark as unloaded on error
340+ metadata .status = AdapterStatus . UNLOADED # Mark as unloaded on error
334341 raise e # Re-Raise the exception
335342 finally :
336343 async with self .lock :
@@ -368,11 +375,11 @@ def get_lora_weights(self,
368375 if metadata is None :
369376 raise ValueError (f"LoRA adapter with id={ adapter_id } is not loaded." )
370377
371- if metadata .status != "loaded_hbm" and metadata .status != "loaded_cpu" :
378+ if metadata .status != AdapterStatus . LOADED_HBM and metadata .status != AdapterStatus . LOADED_CPU :
372379 asyncio .run (self .load_adapter (adapter_id , None , to_hbm )) # Start loading (async)
373- elif to_hbm and metadata .status == "loaded_cpu" :
380+ elif to_hbm and metadata .status == AdapterStatus . LOADED_CPU :
374381 asyncio .run (self ._transfer_to_hbm (adapter_id ))
375- elif not to_hbm and metadata .status == "loaded_hbm" :
382+ elif not to_hbm and metadata .status == AdapterStatus . LOADED_HBM :
376383 asyncio .run (self ._transfer_to_cpu (adapter_id ))
377384
378385 # Wait till all the running requests are completed
@@ -397,21 +404,21 @@ async def unload_adapter(self, adapter_id: str):
397404 metadata = self .adapter_registry [adapter_id ]
398405
399406 async with self .lock :
400- if metadata .status == "unloaded" :
407+ if metadata .status == AdapterStatus . UNLOADED :
401408 return # Already unloaded
402- if metadata .status == "loading" :
409+ if metadata .status == AdapterStatus . LOADING :
403410 # Wait for the loading to complete.
404- while metadata .status == "loading" :
411+ while metadata .status == AdapterStatus . LOADING :
405412 await asyncio .sleep (0.1 )
406413
407- if metadata .status == "loaded_hbm" :
414+ if metadata .status == AdapterStatus . LOADED_HBM :
408415 del self .loaded_adapters_hbm [adapter_id ]
409416 self .current_hbm_usage -= metadata .size_hbm
410- metadata .status = "unloaded"
411- elif metadata .status == "loaded_cpu" :
417+ metadata .status = AdapterStatus . UNLOADED
418+ elif metadata .status == AdapterStatus . LOADED_CPU :
412419 del self .loaded_adapters_cpu [adapter_id ]
413420 self .current_cpu_usage -= metadata .size_cpu
414- metadata .status = "unloaded"
421+ metadata .status = AdapterStatus . UNLOADED
415422
416423 metadata .last_accessed = time .time () # Unload time
417424 metadata .size_hbm = 0
@@ -431,7 +438,7 @@ def _evict(self, from_hbm: bool = True) -> bool:
431438 lru_time = float ('inf' )
432439
433440 for adapter_id , metadata in self .adapter_registry .items ():
434- if metadata .status == "loaded_hbm" if from_hbm else metadata .status == "loaded_cpu" :
441+ if metadata .status == AdapterStatus . LOADED_HBM if from_hbm else metadata .status == AdapterStatus . LOADED_CPU :
435442 if metadata .last_accessed < lru_time :
436443 lru_time = metadata .last_accessed
437444 lru_adapter_id = adapter_id
0 commit comments