2626import functools
2727from typing import Dict , Optional , Any
2828import numpy as np
29+ from jetstream .engine import engine_api
30+ from enum import Enum
2931
3032
3133def _get_size_of_pytree (params ):
@@ -53,20 +55,25 @@ def convert_if_np(leaf):
5355
5456 return jax .tree_util .tree_map (convert_if_np , params )
5557
58+ class AdapterStatus (str , Enum ):
59+ UNLOADED = "unloaded"
60+ LOADING = "loading"
61+ LOADED_HBM = "loaded_hbm"
62+ LOADED_CPU = "loaded_cpu"
63+
5664
5765@dataclasses .dataclass
5866class AdapterMetadata :
5967 adapter_id : str
6068 adapter_path : str
61- status : str = "unloaded" # "loaded_hbm", "loaded_cpu", "loading", "unloading"
69+ status : AdapterStatus = AdapterStatus . UNLOADED
6270 size_hbm : int = 0 # Size in HBM (bytes)
6371 size_cpu : int = 0 # Size in CPU RAM (bytes)
6472 last_accessed : float = 0.0 # timestamp
6573 config : Dict [str , Any ] = dataclasses .field (default_factory = dict )
6674
6775
6876class AdapterTensorStore :
69- def __init__ (self , hbm_memory_budget : int , cpu_memory_budget : int , total_slots : int ):
7077 """
7178 Manages the storage and retrieval of LoRA adapter weights, handling
7279 placement in either HBM (High Bandwidth Memory, on the TPU/GPU) or CPU RAM.
@@ -87,6 +94,14 @@ def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int, total_slots:
8794 total_slots: Number of generate slots. This is also equals to max_concurrent_decodes.
8895 """
8996
97+ def __init__ (self ,
98+ engine : engine_api .Engine ,
99+ adapters_dir_path : str ,
100+ hbm_memory_budget : int ,
101+ cpu_memory_budget : int ):
102+ """Initializes the AdapterTensorStore."""
103+ self .engine = engine # Possibly MaxEngine object
104+ self .adapters_dir_path = adapters_dir_path .rstrip ("/" ) # All Adapters path without trailing `/`
90105 self .hbm_memory_budget = hbm_memory_budget
91106 self .cpu_memory_budget = cpu_memory_budget
92107 self .adapter_registry : Dict [str , AdapterMetadata ] = {} # All known adapters
@@ -100,26 +115,49 @@ def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int, total_slots:
100115 self .lock = asyncio .Lock () # Use an asyncio Lock for thread safety
101116
102117
103- def register_adapter (self , adapter_id : str , adapter_path : str , config : Dict [str , Any ]):
118+ def register_adapter (self ,
119+ adapter_id : str ,
120+ adapter_path : str | None = None ,
121+ adapter_config : Dict [str , Any ] | None = None ):
104122 """Registers a new LoRA adatper."""
105123 """
106- Registers a LoRA adapter with the TensorStore. This does *not* load
107- the adapter; it simply adds metadata about the adapter to the registry.
124+ Registers a LoRA adapter with the TensorStore. This also loads the adapter;
125+ IF called without adapter_config. Because in this case, it needs
126+ to get adapter_config from the engine's load_single_adapter() call, which
127+ also provides the adapter_params. So in that case it is beneficial to load
128+ the adapter to HBM. This call path is expected only from the direct inference
129+ request.
130+ OTHERWISE, it simply adds metadata about the adapter to the registry.
108131
109132 Args:
110133 adapter_id (str): A unique identifier for the adapter.
111134 adapter_path (str): The path to the adapter weights (file or directory).
112- config (dict): Config of the loRA adapter.
135+ adapter_config (dict): Config of the loRA adapter.
113136
114137 Raises:
115138 ValueError: If an adapter with the same ID is already registered.
116139 """
117140 if adapter_id in self .adapter_registry :
118- raise ValueError (f"Adapter with ID '{ adapter_id } ' already registered." )
141+ logging .warning (f"Adapter with ID '{ adapter_id } ' already registered." )
142+ return
143+
144+ if adapter_path is None :
145+ adapter_path = f"{ self .adapters_dir_path } /{ adapter_id } "
146+
147+ adapter_params = None
148+ if adapter_config is None :
149+ adapter_params , adapter_config = self .engine .load_single_adapter (adapter_path )
150+
151+ if adapter_config is None :
152+ raise ValueError (f"Failed to read adapter_config from { adapter_path } " )
153+
119154 self .adapter_registry [adapter_id ] = AdapterMetadata (
120155 adapter_id = adapter_id ,
121156 adapter_path = adapter_path ,
122- config = config )
157+ config = adapter_config )
158+
159+ if adapter_params is not None :
160+ asyncio .run (self .load_adapter (adapter_id , adapter_params , True ))
123161
124162
125163 async def _transfer_to_hbm (self , adapter_id : str ):
@@ -130,7 +168,7 @@ async def _transfer_to_hbm(self, adapter_id: str):
130168 async with self .lock : #Acquire lock
131169 metadata = self .adapter_registry [adapter_id ]
132170
133- if metadata .status == "loaded_hbm" :
171+ if metadata .status == AdapterStatus . LOADED_HBM :
134172 return
135173
136174 # Check if we have enough space in HBM; evict if necessary
@@ -147,7 +185,7 @@ async def _transfer_to_hbm(self, adapter_id: str):
147185 self .current_cpu_usage -= metadata .size_cpu
148186 self .current_hbm_usage += metadata .size_hbm
149187
150- metadata .status = "loaded_hbm"
188+ metadata .status = AdapterStatus . LOADED_HBM
151189 metadata .last_accessed = time .time ()
152190
153191
@@ -160,7 +198,7 @@ async def _transfer_to_cpu(self, adapter_id: str):
160198 async with self .lock :
161199 metadata = self . adapter_registry [adapter_id ]
162200
163- if metadata .status == "loaded_cpu" :
201+ if metadata .status == AdapterStatus . LOADED_CPU :
164202 return
165203
166204 # Check if we have enough space in CPU; evict if necessary.
@@ -175,7 +213,7 @@ async def _transfer_to_cpu(self, adapter_id: str):
175213 self .current_hbm_usage -= metadata .size_hbm
176214 self .current_cpu_usage += metadata .size_cpu
177215
178- metadata .status = "loaded_cpu"
216+ metadata .status = AdapterStatus . LOADED_CPU
179217 metadata .last_accessed = time .time ()
180218
181219
@@ -256,7 +294,7 @@ async def get_hbm_loaded_adapters(self):
256294
257295 async with self .lock :
258296 for adapter_id , metadata in self .adapter_registry .items ():
259- if metadata .status == "loaded_hbm" :
297+ if metadata .status == AdapterStatus . LOADED_HBM :
260298 hbm_loaded_adapters .append (adapter_id )
261299
262300 return ", " .join (hbm_loaded_adapters )
@@ -295,41 +333,45 @@ async def load_adapter(
295333 metadata = self .adapter_registry [adapter_id ]
296334
297335 async with self .lock : # Acquire lock for thread safety
298- if metadata .status in ("loaded_hbm" , "loaded_cpu" ):
336+ if metadata .status in (AdapterStatus . LOADED_HBM , AdapterStatus . LOADED_CPU ):
299337 metadata .last_accessed = time .time ()
300338
301339 # if already loaded in HBM and we want HBM, or
302340 # already loaded in CPU and we want CPU, we're done.
303- if ((to_hbm and metadata .status == "loaded_hbm" ) or
304- not to_hbm and metadata .status == "loaded_cpu" ):
341+ if ((to_hbm and metadata .status == AdapterStatus . LOADED_HBM ) or
342+ not to_hbm and metadata .status == AdapterStatus . LOADED_CPU ):
305343 return
306- elif to_hbm and metadata .status == "loaded_cpu" :
344+ elif to_hbm and metadata .status == AdapterStatus . LOADED_CPU :
307345 # Transfer from cpu to hbm
308346 self ._transfer_to_hbm (adapter_id )
309347 return
310- elif not to_hbm and metadata .status == "loaded_hbm" :
348+ elif not to_hbm and metadata .status == AdapterStatus . LOADED_HBM :
311349 # Transfer from hbm to cpu
312350 self ._transfer_to_cpu (adapter_id )
313351 return
314352
315- if metadata .status == "loading" :
353+ if metadata .status == AdapterStatus . LOADING :
316354 # Wait untill loading is done.
317- while metadata .status == "loading" :
355+ while metadata .status == AdapterStatus . LOADING :
318356 await asyncio .sleep (0.1 ) # Short sleep to avoid busy-waiting
319357
320358 # Make recursive call to load_adapter to copy to device
321359 await self .load_adapter (adapter_id , adapter_weights , to_hbm )
322360 return
323361
324- metadata .status = "loading"
362+ metadata .status = AdapterStatus . LOADING
325363 self .running_requests += 1
326364
327365 # Load the adapter (asynchronous)
328366 loop = asyncio .get_running_loop ()
329367
330368 try :
331369 if adapter_weights is None :
332- raise ValueError ("Adapter weights for adapter_id={adapter_id} is None." )
370+ adapter_path = f"{ self .adapters_dir_path } /{ adapter_id } "
371+ adapter_weights , adapter_config = self .engine .load_single_adapter (adapter_path )
372+
373+ if adapter_weights is None :
374+ raise ValueError ("Failed to load adapter_weights from {adapter_path}." )
333375
334376 async with self .lock : # Critical section for memory management
335377 adapter_weights_as_jnp_array = _as_jnp_array (adapter_weights )
@@ -360,45 +402,60 @@ async def load_adapter(
360402 if to_hbm :
361403 self .loaded_adapters_hbm [adapter_id ] = adapter_weights_as_jnp_array # Convert the PyTree to Jax Array
362404 self .current_hbm_usage += adapter_size_hbm
363- metadata .status = "loaded_hbm"
405+ metadata .status = AdapterStatus . LOADED_HBM
364406
365407 else : #to cpu
366408 self .loaded_adapters_cpu [adapter_id ] = adapter_weights_as_np_array # Convert the PyTree to NumPy Array
367409 self .current_cpu_usage += adapter_size_cpu
368- metadata .status = "loaded_cpu"
410+ metadata .status = AdapterStatus . LOADED_CPU
369411
370412 metadata .last_accessed = time .time ()
371413
372414 except Exception as e :
373415 async with self .lock :
374- metadata .status = "unloaded" # Mark as unloaded on error
416+ metadata .status = AdapterStatus . UNLOADED # Mark as unloaded on error
375417 raise e # Re-Raise the exception
376418 finally :
377419 async with self .lock :
378420 self .running_requests -= 1
379421
380422
381- def get_lora_config (self , adapter_id ):
423+ def get_lora_config (self , adapter_id : str , load_if_not_loaded : bool = False ):
382424 """Getter for the LoRA adapter config."""
383425 metadata = self .adapter_registry .get (adapter_id )
426+
427+ if load_if_not_loaded and metadata is None :
428+ self .register_adapter (adapter_id )
429+ metadata = self .adapter_registry .get (adapter_id )
430+
431+ if metadata is None :
432+ raise ValueError (f"LoRA adapter with id={ adapter_id } is not loaded." )
433+
384434 return metadata .config
385435
386436
387- def get_lora_weights (self , adapter_id , to_hbm : bool = True ):
437+ def get_lora_weights (self ,
438+ adapter_id ,
439+ to_hbm : bool = True ,
440+ load_if_not_loaded : bool = False ):
388441 """Retrieves the unified LoRA parameters for the given adapter IDs.
389442 Handles HBM/CPU placement.
390443 """
391444
392445 metadata = self .adapter_registry .get (adapter_id )
393446
447+ if load_if_not_loaded and metadata is None :
448+ self .register_adapter (adapter_id )
449+ metadata = self .adapter_registry .get (adapter_id )
450+
394451 if metadata is None :
395- raise ValueError (f"Adapter with ID ' { adapter_id } ' not registered ." )
452+ raise ValueError (f"LoRA adapter with id= { adapter_id } is not loaded ." )
396453
397- if metadata .status != "loaded_hbm" and metadata .status != "loaded_cpu" :
454+ if metadata .status != AdapterStatus . LOADED_HBM and metadata .status != AdapterStatus . LOADED_CPU :
398455 asyncio .run (self .load_adapter (adapter_id , None , to_hbm )) # Start loading (async)
399- elif to_hbm and metadata .status == "loaded_cpu" :
456+ elif to_hbm and metadata .status == AdapterStatus . LOADED_CPU :
400457 asyncio .run (self ._transfer_to_hbm (adapter_id ))
401- elif not to_hbm and metadata .status == "loaded_hbm" :
458+ elif not to_hbm and metadata .status == AdapterStatus . LOADED_HBM :
402459 asyncio .run (self ._transfer_to_cpu (adapter_id ))
403460
404461 # Wait till all the running requests are completed
@@ -423,21 +480,21 @@ async def unload_adapter(self, adapter_id: str):
423480 metadata = self .adapter_registry [adapter_id ]
424481
425482 async with self .lock :
426- if metadata .status == "unloaded" :
483+ if metadata .status == AdapterStatus . UNLOADED :
427484 return # Already unloaded
428- if metadata .status == "loading" :
485+ if metadata .status == AdapterStatus . LOADING :
429486 # Wait for the loading to get complete.
430- while metadata .status == "loading" :
487+ while metadata .status == AdapterStatus . LOADING :
431488 await asyncio .sleep (0.1 )
432489
433- if metadata .status == "loaded_hbm" :
490+ if metadata .status == AdapterStatus . LOADED_HBM :
434491 del self .loaded_adapters_hbm [adapter_id ]
435492 self .current_hbm_usage -= metadata .size_hbm
436- metadata .status = "unloaded"
437- elif metadata .status == "loaded_cpu" :
493+ metadata .status = AdapterStatus . UNLOADED
494+ elif metadata .status == AdapterStatus . LOADED_CPU :
438495 del self .loaded_adapters_cpu [adapter_id ]
439496 self .current_cpu_usage -= metadata .size_cpu
440- metadata .status = "unloaded"
497+ metadata .status = AdapterStatus . UNLOADED
441498
442499 metadata .last_accessed = time .time () # Unload time
443500 metadata .size_hbm = 0
@@ -457,7 +514,7 @@ def _evict(self, from_hbm: bool = True) -> bool:
457514 lru_time = float ('inf' )
458515
459516 for adapter_id , metadata in self .adapter_registry .items ():
460- if metadata .status == "loaded_hbm" if from_hbm else metadata .status == "loaded_cpu" :
517+ if metadata .status == AdapterStatus . LOADED_HBM if from_hbm else metadata .status == AdapterStatus . LOADED_CPU :
461518 if metadata .last_accessed < lru_time :
462519 lru_time = metadata .last_accessed
463520 lru_adapter_id = adapter_id
0 commit comments