@@ -121,9 +121,26 @@ def __init__(
121121 0 # Number of async tasks which are in "loading" state
122122 )
123123 self .decoding_adapters_cache : Dict [str , Any ] = {}
124+
125+ # TODO: Make dtype configurable for the scale factor array
126+ self .adapters_scale_factor = jnp .empty (1 , dtype = jnp .bfloat16 )
127+
124128 self .total_slots = total_slots
125129 self .lock = asyncio .Lock () # Use an asyncio Lock for thread safety
126130
def _get_adapter_scale_factor(self, adapter_id: str) -> float:
    """Internal: Get the LoRA scale factor for a registered adapter.

    The factor is ``lora_alpha / r`` when the adapter config contains both
    the ``"r"`` and ``"lora_alpha"`` keys; otherwise it defaults to ``1.0``.

    Args:
        adapter_id: Key used to look the adapter up in
            ``self.adapter_registry``.

    Returns:
        The scale factor as a Python float.

    Raises:
        ValueError: If the configured rank ``r`` is not a positive integer.
            Without this check a rank of 0 would surface as a bare
            ZeroDivisionError and a negative rank would silently produce a
            negative scale.
    """
    # Any error from the registry lookup (e.g. an unknown id) propagates.
    adapter_config = self.adapter_registry[adapter_id].config

    if "r" not in adapter_config or "lora_alpha" not in adapter_config:
        return 1.0

    lora_rank = int(adapter_config["r"])
    if lora_rank <= 0:
        raise ValueError(
            f"Adapter {adapter_id!r} has invalid LoRA rank r={lora_rank}; "
            "expected a positive integer."
        )

    return float(adapter_config["lora_alpha"]) / lora_rank
143+
127144 # --- Unsafe internal methods which assume that the lock is held ---
128145 def _unsafe_transfer_to_hbm (self , adapter_id : str ):
129146 """
@@ -212,67 +229,87 @@ def _unsafe_unload_adapter(self, adapter_id: str):
212229
def _initialize_decoding_adapters_cache(self, adapter_weights):
    """Build an all-zero, slot-batched mirror of ``adapter_weights``.

    Every non-None leaf of the input PyTree is replaced by a zero tensor of
    the same dtype whose shape is the leaf's shape with a new leading
    dimension of size ``self.total_slots``. None entries stay None.

    As a side effect, ``self.adapters_scale_factor`` is reset to a bfloat16
    vector of ones with one entry per slot.

    Args:
        adapter_weights: The input PyTree, whose structure will be mirrored.

    Returns:
        A new PyTree of zero tensors (or None values) with the same
        structure as ``adapter_weights``.
    """
    num_slots = self.total_slots

    def zeros_like_with_slot_axis(leaf):
        if leaf is None:
            # Keep the None placeholder so the structure is preserved.
            return None
        # For a scalar leaf the shape is (), so this yields (num_slots,).
        return jnp.zeros((num_slots, *leaf.shape), dtype=leaf.dtype)

    self.adapters_scale_factor = jnp.ones(num_slots, dtype=jnp.bfloat16)
    return jax.tree_util.tree_map(zeros_like_with_slot_axis, adapter_weights)
235258
236-
def insert_adapter_in_cache(self, adapter_id: str, slot_id: int):
    """Insert an adapter's tensors and scale factor into one cache slot.

    Loads the adapter, then writes each of its tensors at index ``slot_id``
    of the matching leaf in ``self.decoding_adapters_cache`` and stores the
    adapter's LoRA scale factor at ``self.adapters_scale_factor[slot_id]``.
    The cache is initialized from this adapter's weights if it is empty.

    An empty ``adapter_id`` is a no-op: nothing is loaded and neither the
    cache nor the scale factors are modified.

    Args:
        adapter_id: The id of the adapter whose tensors will be inserted.
        slot_id: Index along the leading (slot) dimension of every cache
            leaf at which the adapter tensors will be written.
    """
    if adapter_id == "":
        logging.info(
            "Empty adapter id. No LoRA tensors added to adapter_tensorstore cache"
        )
        return

    def insert_leaf(dest_leaf, source_leaf):
        if source_leaf is None:
            # Nothing to insert here: keep the existing leaf, which is
            # itself None when both trees have no tensor at this position.
            return dest_leaf
        if dest_leaf is None:
            # The adapters have different target modules: the cache has no
            # tensor here yet, so allocate a zeroed, slot-batched one.
            dest_leaf = jnp.zeros(
                (self.total_slots, *source_leaf.shape),
                dtype=source_leaf.dtype,
            )
        return dest_leaf.at[slot_id].set(source_leaf)

    # NOTE(review): asyncio.run() raises RuntimeError if an event loop is
    # already running in this thread; confirm callers are synchronous.
    asyncio.run(self.load_adapter(adapter_id, None, True))

    adapter_weights = self.loaded_adapters_hbm[adapter_id]

    if not self.decoding_adapters_cache:
        self.decoding_adapters_cache = self._initialize_decoding_adapters_cache(
            adapter_weights
        )

    # tree_map treats None as an empty subtree by default, so a None in the
    # cache paired with a tensor in the adapter would raise a structure
    # mismatch and insert_leaf would never see it. Marking None as a leaf
    # makes the "different target modules" branch reachable.
    updated_cache = jax.tree_util.tree_map(
        insert_leaf,
        self.decoding_adapters_cache,
        adapter_weights,
        is_leaf=lambda leaf: leaf is None,
    )
    adapter_scale_factor = jnp.bfloat16(
        self._get_adapter_scale_factor(adapter_id)
    )

    # Commit both updates only after both were computed, so a failure above
    # cannot leave the scale factors and the tensor cache out of step.
    self.adapters_scale_factor = self.adapters_scale_factor.at[slot_id].set(
        adapter_scale_factor
    )
    self.decoding_adapters_cache = updated_cache
276313
277314 # --- Public Methods (Acquire lock, then call unsafe methods) ---
278315
0 commit comments