@@ -67,11 +67,31 @@ class AdapterMetadata:
6767
6868class AdapterTensorStore :
6969 def __init__ (self , hbm_memory_budget : int , cpu_memory_budget : int , total_slots : int ):
70+ """
71+ Manages the storage and retrieval of LoRA adapter weights, handling
72+ placement in either HBM (High Bandwidth Memory, on the TPU/GPU) or CPU RAM.
73+
74+ This class implements an LRU (Least Recently Used) eviction policy
75+ to manage memory usage. It supports asynchronous loading and unloading
76+ of adapters to avoid blocking the main inference thread.
77+
78+ This class also creates a unified_lora_weights of all the adapters which is being
79+ used at any time for decoding purposes. These unified weights allows the backend
80+ model to server multiple different LoRA adapters in a single batch.
81+
82+ Args:
83+ hbm_memory_budget (int): The maximum amount of HBM (in bytes) to use for
84+ storing LoRA adapter weights.
85+ cpu_memory_budget (int): The maximum amount of CPU RAM (in bytes) to use
86+ for storing LoRA adapter weights.
87+ total_slots: Number of generate slots. This is also equals to max_concurrent_decodes.
88+ """
89+
7090 self .hbm_memory_budget = hbm_memory_budget
7191 self .cpu_memory_budget = cpu_memory_budget
7292 self .adapter_registry : Dict [str , AdapterMetadata ] = {} # All known adapters
73- self .loaded_adapters_hbm : Dict [str , jnp .ndarray ] = {} # adapter_id -> Unified LoRA params (in HBM)
74- self .loaded_adapters_cpu : Dict [str , np .ndarray ] = {} # adapter_id -> Unified LoRA params (in CPU RAM)
93+ self .loaded_adapters_hbm : Dict [str , jnp .ndarray ] = {} # adapter_id -> LoRA params (in HBM)
94+ self .loaded_adapters_cpu : Dict [str , np .ndarray ] = {} # adapter_id -> LoRA params (in CPU RAM)
7595 self .current_hbm_usage : int = 0
7696 self .current_cpu_usage : int = 0
7797 self .running_requests : int = 0 # Number of async tasks which are in "loading" state
@@ -82,6 +102,18 @@ def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int, total_slots:
82102
83103 def register_adapter (self , adapter_id : str , adapter_path : str , config : Dict [str , Any ]):
84104 """Registers a new LoRA adatper."""
105+ """
106+ Registers a LoRA adapter with the TensorStore. This does *not* load
107+ the adapter; it simply adds metadata about the adapter to the registry.
108+
109+ Args:
110+ adapter_id (str): A unique identifier for the adapter.
111+ adapter_path (str): The path to the adapter weights (file or directory).
112+ config (dict): Config of the loRA adapter.
113+
114+ Raises:
115+ ValueError: If an adapter with the same ID is already registered.
116+ """
85117 if adapter_id in self .adapter_registry :
86118 raise ValueError (f"Adapter with ID '{ adapter_id } ' already registered." )
87119 self .adapter_registry [adapter_id ] = AdapterMetadata (
@@ -234,17 +266,36 @@ async def load_adapter(
234266 self ,
235267 adapter_id : str ,
236268 adapter_weights = None ,
237- to_hbm : bool = True ,
238- force_load : bool = False ):
239- """Loads a LoRA adapter's weights, managing HBM and CPU memory."""
269+ to_hbm : bool = True ):
270+ """
271+ Loads a LoRA adapter's weights into memory (either HBM or CPU RAM).
272+
273+ This method is asynchronous to avoid blocking the main thread during
274+ potentially slow I/O operations. It handles:
275+ - Checking if the adapter is already loaded.
276+ - Checking if there's enough memory (and evicting if necessary).
277+ - Loading the weights (in a separate thread).
278+ - Updating the adapter's status and metadata.
279+
280+ Args:
281+ adapter_id (str): The ID of the adapter to load.
282+ adapter_weights: In the form of a PyTree.
283+ to_hbm (bool): Whether to load the adapter into HBM (True) or
284+ CPU RAM (False). Defaults to True (HBM).
285+
286+ Raises:
287+ ValueError: If the adapter ID is not registered.
288+ RuntimeError: If there is not enough memory to load the adapter,
289+ and eviction fails to free up enough space.
290+ """
240291
241292 if adapter_id not in self .adapter_registry :
242293 raise ValueError (f"Adapter with ID '{ adapter_id } ' not registered." )
243294
244295 metadata = self .adapter_registry [adapter_id ]
245296
246297 async with self .lock : # Acquire lock for thread safety
247- if not force_load and metadata .status in ("loaded_hbm" , "loaded_cpu" ):
298+ if metadata .status in ("loaded_hbm" , "loaded_cpu" ):
248299 metadata .last_accessed = time .time ()
249300
250301 # if already loaded in HBM and we want HBM, or
@@ -267,7 +318,7 @@ async def load_adapter(
267318 await asyncio .sleep (0.1 ) # Short sleep to avoid busy-waiting
268319
269320 # Make recursive call to load_adapter to copy to device
270- await self .load_adapter (adapter_id , adapter_weights , to_hbm , force_load )
321+ await self .load_adapter (adapter_id , adapter_weights , to_hbm )
271322 return
272323
273324 metadata .status = "loading"
0 commit comments