@@ -99,6 +99,7 @@ def __init__(
9999 adapters_dir_path : str ,
100100 hbm_memory_budget : int ,
101101 cpu_memory_budget : int ,
102+ total_slots : int ,
102103 ):
103104 """Initializes the AdapterTensorStore."""
104105 self .engine = engine # Possibly MaxEngine object
@@ -119,6 +120,8 @@ def __init__(
119120 self .running_requests : int = (
120121 0 # Number of async tasks which are in "loading" state
121122 )
123+ self .decoding_adapters_cache : Dict [str , Any ] = {}
124+ self .total_slots = total_slots
122125 self .lock = asyncio .Lock () # Use an asyncio Lock for thread safety
123126
124127 # --- Unsafe Internal methods which assumes that lock is held ---
@@ -207,6 +210,70 @@ def _unsafe_unload_adapter(self, adapter_id: str):
207210 metadata .size_hbm = 0
208211 metadata .size_cpu = 0
209212
213+ def _initialize_decoding_adapters_cache (self , adapter_weights ):
214+ """
215+ Create a new PyTree with zero tensors at the paths corresponding to non-None leaves
216+ in the input PyTree. The zero tensors have an added dimension of size `self.totol_slots`.
217+ Args:
218+ adatper_weights: The input PyTree, whose structure will be mirrored.
219+ Returns:
220+ A new PyTree with zero Tensors or None values, mirroring the structure of the input PyTree.
221+ """
222+ def create_zero_leaf (leaf ):
223+ if leaf is not None :
224+ original_shape = leaf .shape
225+ if not original_shape : # handle scalar case
226+ zero_tensor_shape = (self .total_slots ,)
227+ else :
228+ zero_tensor_shape = (self .total_slots ,) + original_shape # Prepend a new dimension
229+
230+ return jnp .zeros (zero_tensor_shape , dtype = leaf .dtype )
231+ else :
232+ return None # Maintain None structure for None leaves
233+
234+ return jax .tree_util .tree_map (create_zero_leaf , adapter_weights )
235+
236+
237+ def insert_adapter_in_cache (self , adapter_id : str , slot_id : int ):
238+ """
239+ Insert the specific adapter tensors into a slot in the serving_adapters_cache.
240+ Args:
241+ adapter_id: The id of the adapter, whose tensors will be inserted
242+ slot_id: The id of slot, which represents the index in the serving_adapter_cache
243+ where the adapter tensors will be inserted.
244+ """
245+
246+ def insert_leaf (dest_leaf , source_leaf ):
247+ if dest_leaf is not None and source_leaf is not None :
248+ return dest_leaf .at [slot_id ].set (source_leaf ) # Insert at the specific index
249+ elif dest_leaf is not None :
250+ return dest_leaf # If source_leaf is None, keep the zero_leaf as is
251+ elif source_leaf is not None : # In this case the adapters have different target modules
252+ original_shape = source_leaf .shape
253+ if not original_shape : # Handle scalar case
254+ zero_tensor_shape = (self .total_slots ,)
255+ else :
256+ zero_tensor_shape = (self .total_slots ,) + original_shape
257+ new_dest_leaf = jnp .zeros (zero_tensor_shape , dtype = source_leaf .dtype )
258+ return new_dest_leaf .at [slot_id ].set (source_leaf )
259+ else :
260+ return None # If both are None, return None
261+
262+ if adapter_id == "" :
263+ logging .info ("Empty adapter id. So no LoRA tensors inserted into the cache in adapter_tensorStore." )
264+ return
265+
266+ asyncio .run (self .load_adapter (adapter_id , None , True ))
267+
268+ adapter_weights = self .loaded_adapters_hbm [adapter_id ]
269+
270+ if not self .decoding_adapters_cache :
271+ self .decoding_adapters_cache = self ._initialize_decoding_adapters_cache (adapter_weights )
272+
273+ self .decoding_adapters_cache = jax .tree_util .tree_map (insert_leaf ,
274+ self .decoding_adapters_cache ,
275+ adapter_weights )
276+
210277 # --- Public Methods (Acquire lock, then call unsafe methods) ---
211278
212279 async def register_adapter (
0 commit comments