@@ -66,12 +66,29 @@ class AdapterMetadata:
6666
6767
6868class AdapterTensorStore :
69+ """
70+ Manages the storage and retrieval of LoRA adapter weights, handling
71+ placement in either HBM (High Bandwidth Memory, on the TPU/GPU) or CPU RAM.
72+
73+ This class implements an LRU (Least Recently Used) eviction policy
74+ to manage memory usage. It supports asynchronous loading and unloading
75+ of adapters to avoid blocking the main inference thread.
76+
77+ Args:
78+ hbm_memory_budget (int): The maximum amount of HBM (in bytes) to use for
79+ storing LoRA adapter weights.
80+ cpu_memory_budget (int): The maximum amount of CPU RAM (in bytes) to use
81+ for storing LoRA adapter weights.
82+ """
83+
84+
6985 def __init__ (self , hbm_memory_budget : int , cpu_memory_budget : int ):
86+ """Initializes the AdapterTensorStore."""
7087 self .hbm_memory_budget = hbm_memory_budget
7188 self .cpu_memory_budget = cpu_memory_budget
7289 self .adapter_registry : Dict [str , AdapterMetadata ] = {} # All known adapters
73- self .loaded_adapters_hbm : Dict [str , jnp .ndarray ] = {} # adapter_id -> Unified LoRA params (in HBM)
74- self .loaded_adapters_cpu : Dict [str , np .ndarray ] = {} # adapter_id -> Unified LoRA params (in CPU RAM)
90+ self .loaded_adapters_hbm : Dict [str , jnp .ndarray ] = {} # adapter_id -> LoRA params (in HBM)
91+ self .loaded_adapters_cpu : Dict [str , np .ndarray ] = {} # adapter_id -> LoRA params (in CPU RAM)
7592 self .current_hbm_usage : int = 0
7693 self .current_cpu_usage : int = 0
7794 self .running_requests : int = 0 # Number of async tasks which are in "loading" state
@@ -80,6 +97,18 @@ def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int):
8097
8198 def register_adapter (self , adapter_id : str , adapter_path : str , config : Dict [str , Any ]):
8299 """Registers a new LoRA adatper."""
100+ """
101+ Registers a LoRA adapter with the TensorStore. This does *not* load
102+ the adapter; it simply adds metadata about the adapter to the registry.
103+
104+ Args:
105+ adapter_id (str): A unique identifier for the adapter.
106+ adapter_path (str): The path to the adapter weights (file or directory).
107+ config (dict): Config of the loRA adapter.
108+
109+ Raises:
110+ ValueError: If an adapter with the same ID is already registered.
111+ """
83112 if adapter_id in self .adapter_registry :
84113 raise ValueError (f"Adapter with ID '{ adapter_id } ' already registered." )
85114 self .adapter_registry [adapter_id ] = AdapterMetadata (
@@ -162,17 +191,36 @@ async def load_adapter(
162191 self ,
163192 adapter_id : str ,
164193 adapter_weights = None ,
165- to_hbm : bool = True ,
166- force_load : bool = False ):
167- """Loads a LoRA adapter's weights, managing HBM and CPU memory."""
194+ to_hbm : bool = True ):
195+ """
196+ Loads a LoRA adapter's weights into memory (either HBM or CPU RAM).
197+
198+ This method is asynchronous to avoid blocking the main thread during
199+ potentially slow I/O operations. It handles:
200+ - Checking if the adapter is already loaded.
201+ - Checking if there's enough memory (and evicting if necessary).
202+ - Loading the weights (in a separate thread).
203+ - Updating the adapter's status and metadata.
204+
205+ Args:
206+ adapter_id (str): The ID of the adapter to load.
207+ adapter_weights: In the form of a PyTree.
208+ to_hbm (bool): Whether to load the adapter into HBM (True) or
209+ CPU RAM (False). Defaults to True (HBM).
210+
211+ Raises:
212+ ValueError: If the adapter ID is not registered.
213+ RuntimeError: If there is not enough memory to load the adapter,
214+ and eviction fails to free up enough space.
215+ """
168216
169217 if adapter_id not in self .adapter_registry :
170218 raise ValueError (f"Adapter with ID '{ adapter_id } ' not registered." )
171219
172220 metadata = self .adapter_registry [adapter_id ]
173221
174222 async with self .lock : # Acquire lock for thread safety
175- if not force_load and metadata .status in ("loaded_hbm" , "loaded_cpu" ):
223+ if metadata .status in ("loaded_hbm" , "loaded_cpu" ):
176224 metadata .last_accessed = time .time ()
177225
178226 # if already loaded in HBM and we want HBM, or
@@ -195,7 +243,7 @@ async def load_adapter(
195243 await asyncio .sleep (0.1 ) # Short sleep to avoid busy-waiting
196244
197245 # Make recursive call to load_adapter to copy to device
198- await self .load_adapter (adapter_id , adapter_weights , to_hbm , force_load )
246+ await self .load_adapter (adapter_id , adapter_weights , to_hbm )
199247 return
200248
201249 metadata .status = "loading"
0 commit comments