@@ -66,7 +66,7 @@ class AdapterMetadata:
6666
6767
6868class AdapterTensorStore :
69- def __init__ (self , hbm_memory_budget : int , cpu_memory_budget : int ):
69+ def __init__ (self , hbm_memory_budget : int , cpu_memory_budget : int , total_slots : int ):
7070 self .hbm_memory_budget = hbm_memory_budget
7171 self .cpu_memory_budget = cpu_memory_budget
7272 self .adapter_registry : Dict [str , AdapterMetadata ] = {} # All known adapters
@@ -75,6 +75,8 @@ def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int):
7575 self .current_hbm_usage : int = 0
7676 self .current_cpu_usage : int = 0
7777 self .running_requests : int = 0 # Number of async tasks which are in "loading" state
78+ self .decoding_adapters_cache : Dict [str , Any ] = {}
79+ self .total_slots = total_slots
7880 self .lock = asyncio .Lock () # Use an asyncio Lock for thread safety
7981
8082
@@ -145,6 +147,76 @@ async def _transfer_to_cpu(self, adapter_id: str):
145147 metadata .last_accessed = time .time ()
146148
147149
150+ def _initialize_decoding_adapters_cache (self , adapter_weights ):
151+ """
152+ Create a new PyTree with zero tensors at the paths corresponding to non-None leaves
153+ in the input PyTree. The zero tensors have an added dimension of size `self.totol_slots`.
154+
155+ Args:
156+ adatper_weights: The input PyTree, whose structure will be mirrored.
157+
158+ Returns:
159+ A new PyTree with zero Tensors or None values, mirroring the structure of the input PyTree.
160+ """
161+ def create_zero_leaf (leaf ):
162+ if leaf is not None :
163+ original_shape = leaf .shape
164+ if not original_shape : # handle scalar case
165+ zero_tensor_shape = (self .total_slots ,)
166+ else :
167+ zero_tensor_shape = (self .total_slots ,) + original_shape # Prepend a new dimension
168+
169+ return jnp .zeros (zero_tensor_shape , dtype = leaf .dtype )
170+ else :
171+ return None # Maintain None structure for None leaves
172+
173+ return jax .tree_util .tree_map (create_zero_leaf , adapter_weights )
174+
175+
def insert_adapter_in_cache(self, adapter_id: str, slot_id: int) -> None:
    """Insert an adapter's tensors into one slot of `decoding_adapters_cache`.

    The adapter is loaded through `self.load_adapter(adapter_id, True)`, its
    weights are read from `self.loaded_adapters_hbm`, and each weight tensor
    is written at index `slot_id` of the matching cache tensor. The cache is
    created from the first inserted adapter if it is still empty.

    Args:
      adapter_id: The id of the adapter whose tensors will be inserted. An
        empty string is a no-op (nothing is inserted).
      slot_id: The index along the leading (slot) dimension of each cache
        tensor where the adapter tensors will be written. Must satisfy
        `0 <= slot_id < self.total_slots`.

    Raises:
      ValueError: If `slot_id` is outside `[0, self.total_slots)`.
      KeyError: If `adapter_id` is not present in `self.adapter_registry`.
    """

    def insert_leaf(dest_leaf, source_leaf):
        if source_leaf is None:
            # Nothing to insert here: keep the cache entry (tensor or None).
            return dest_leaf
        if dest_leaf is None:
            # The cache has no tensor here yet, i.e. the adapters have
            # different target modules. Allocate a zero tensor with the slot
            # dimension prepended; `() + shape` also covers scalar leaves.
            # NOTE(review): assumes source_leaf is a single array, not a
            # nested subtree -- confirm against the adapter weight layout.
            dest_leaf = jnp.zeros(
                (self.total_slots,) + tuple(source_leaf.shape),
                dtype=source_leaf.dtype,
            )
        return dest_leaf.at[slot_id].set(source_leaf)

    if adapter_id == "":
        logging.info("Empty adapter id. So no LoRA tensors inserted into the cache in adapter_tensorStore.")
        return

    # An out-of-range index would be silently dropped by `.at[...].set`,
    # leaving the slot unchanged, so reject it explicitly.
    if not 0 <= slot_id < self.total_slots:
        raise ValueError(
            f"slot_id {slot_id} is out of range for {self.total_slots} slots."
        )

    # Fail early (with KeyError, as before) for adapters never registered.
    if adapter_id not in self.adapter_registry:
        raise KeyError(adapter_id)

    # NOTE: asyncio.run() starts a fresh event loop and raises RuntimeError
    # when called from a thread that already has a running loop, so this
    # method must be called from synchronous code.
    asyncio.run(self.load_adapter(adapter_id, True))

    adapter_weights = self.loaded_adapters_hbm[adapter_id]

    if not self.decoding_adapters_cache:
        self.decoding_adapters_cache = self._initialize_decoding_adapters_cache(adapter_weights)

    # Treat None as a leaf: by default tree_map sees None in the cache as an
    # empty subtree and raises a structure-mismatch error when the incoming
    # adapter has a tensor at that position.
    self.decoding_adapters_cache = jax.tree_util.tree_map(
        insert_leaf,
        self.decoding_adapters_cache,
        adapter_weights,
        is_leaf=lambda node: node is None,
    )
218+
219+
148220 async def get_hbm_loaded_adapters (self ):
149221 """Returns a comma separated list of adapters loaded into HBM."""
150222
0 commit comments