diff --git a/jetstream/core/lora/adapter_tensorstore.py b/jetstream/core/lora/adapter_tensorstore.py index 1d303c1f..e54c7261 100644 --- a/jetstream/core/lora/adapter_tensorstore.py +++ b/jetstream/core/lora/adapter_tensorstore.py @@ -99,6 +99,7 @@ def __init__( adapters_dir_path: str, hbm_memory_budget: int, cpu_memory_budget: int, + total_slots: int, ): """Initializes the AdapterTensorStore.""" self.engine = engine # Possibly MaxEngine object @@ -119,8 +120,27 @@ def __init__( self.running_requests: int = ( 0 # Number of async tasks which are in "loading" state ) + self.decoding_adapters_cache: Dict[str, Any] = {} + + # TODO: Make dtype configurable for the scale factor array + self.adapters_scale_factor = jnp.empty(1, dtype=jnp.bfloat16) + + self.total_slots = total_slots self.lock = asyncio.Lock() # Use an asyncio Lock for thread safety + def _get_adapter_scale_factor(self, adapter_id: str): + """ + Internal: Get the LoRA scale_factor using the adapter_id. Returns `lora_alpha / r` from the registered adapter config, or 1.0 when either key is missing. + """ + adapter_config = self.adapter_registry[adapter_id].config + lora_scale_factor = float(1) + + if "r" in adapter_config and "lora_alpha" in adapter_config: + lora_rank = int(adapter_config["r"]) + lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank + + return lora_scale_factor + # --- Unsafe Internal methods which assumes that lock is held --- def _unsafe_transfer_to_hbm(self, adapter_id: str): """ @@ -207,6 +227,90 @@ def _unsafe_unload_adapter(self, adapter_id: str): metadata.size_hbm = 0 metadata.size_cpu = 0 + def _initialize_decoding_adapters_cache(self, adapter_weights): + """ + Create a new PyTree with zero tensors at the paths corresponding to + non-None leaves in the input PyTree. The zero tensors have an added + leading dimension of size `self.total_slots`. Also resets `self.adapters_scale_factor` to ones of length `self.total_slots`. + Args: + adapter_weights: The input PyTree, whose structure will be mirrored. + Returns: + A new PyTree with zero Tensors or None values, mirroring the structure + of the input PyTree. 
+ """ + + def create_zero_leaf(leaf): + if leaf is not None: + original_shape = leaf.shape + if not original_shape: # handle scalar case + zero_tensor_shape = (self.total_slots,) + else: + zero_tensor_shape = ( + self.total_slots, + ) + original_shape # Prepend a new dimension + + return jnp.zeros(zero_tensor_shape, dtype=leaf.dtype) + else: + return None # Maintain None structure for None leaves + + self.adapters_scale_factor = jnp.ones(self.total_slots, dtype=jnp.bfloat16) + return jax.tree_util.tree_map(create_zero_leaf, adapter_weights) + + def insert_adapter_in_cache(self, adapter_id: str, slot_id: int): + """ + Insert the specific adapter tensors into a slot in the + decoding_adapters_cache. + Args: + adapter_id: The id of the adapter, whose tensors will be inserted + slot_id: The id of slot, which represents the index in the + decoding_adapters_cache where the adapter tensors will be inserted. + """ + + def insert_leaf(dest_leaf, source_leaf): + if dest_leaf is not None and source_leaf is not None: + return dest_leaf.at[slot_id].set( + source_leaf + ) # Insert at the specific index + elif dest_leaf is not None: + return dest_leaf # If source_leaf is None, keep the zero_leaf as is + elif ( + source_leaf is not None + ): # In this case the adapters have different target modules + original_shape = source_leaf.shape + if not original_shape: # Handle scalar case + zero_tensor_shape = (self.total_slots,) + else: + zero_tensor_shape = (self.total_slots,) + original_shape + new_dest_leaf = jnp.zeros(zero_tensor_shape, dtype=source_leaf.dtype) + return new_dest_leaf.at[slot_id].set(source_leaf) + else: + return None # If both are None, return None + + if adapter_id == "": + logging.info( + "Empty adapter id. 
No LoRA tensors added to adapter_tensorstore cache" + ) + return + + asyncio.run(self.load_adapter(adapter_id, None, True)) + + adapter_weights = self.loaded_adapters_hbm[adapter_id] + + if not self.decoding_adapters_cache: + self.decoding_adapters_cache = self._initialize_decoding_adapters_cache( + adapter_weights + ) + + adapter_scale_factor = jnp.bfloat16( + self._get_adapter_scale_factor(adapter_id) + ) + self.adapters_scale_factor = self.adapters_scale_factor.at[slot_id].set( + adapter_scale_factor + ) + self.decoding_adapters_cache = jax.tree_util.tree_map( + insert_leaf, self.decoding_adapters_cache, adapter_weights + ) + # --- Public Methods (Acquire lock, then call unsafe methods) --- async def register_adapter( diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 3f0cb459..7b774e7a 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -1018,6 +1018,10 @@ def _insert_if_possible( # Check if there are any free my_slots. We don't want to block here since # we can still generate if we can't insert. We do this in a while loop to # insert as many sequences as possible. 
+ adapter_tensorstore = None + if self._generate_adapterstore and idx < len(self._generate_adapterstore): + adapter_tensorstore = self._generate_adapterstore[idx] + while True: my_slots_size = my_slots.qsize() @@ -1086,8 +1090,13 @@ def _insert_if_possible( new_request.prefill_result, decode_state, slot=slot, - # request_id=new_request.request_id, ) + + if adapter_tensorstore: + adapter_tensorstore.insert_adapter_in_cache( + new_request.adapter_id, slot + ) + ThreadDebugLog( thread_name, f"Generate slice {idx} filled slot {slot} at step " @@ -1227,6 +1236,10 @@ def _generate_thread(self, idx: int): my_generate_backlog = self._generate_backlogs[idx] my_detokenize_backlog = self._detokenize_backlogs[idx] + adapter_tensorstore = None + if self._generate_adapterstore and idx < len(self._generate_adapterstore): + adapter_tensorstore = self._generate_adapterstore[idx] + # Keep track of what step tokens were generated at. generate_timestep = 0 # State to store things like running kv cache in. @@ -1292,6 +1305,24 @@ def _generate_thread(self, idx: int): my_slots.qsize() < max_concurrent_decodes ), "At this point we must have some requests inserted into the slots." + if adapter_tensorstore: + decoding_adapters_params = adapter_tensorstore.decoding_adapters_cache + adapters_scale_factor = adapter_tensorstore.adapters_scale_factor + b = adapters_scale_factor.shape[0] + + # Reshaped the scale_factors array to 4-D to align with shape of + # the vectors `(batch, hidden_size, num_heads, head_dim)`. + reshaped_scale_factors = adapters_scale_factor.reshape((b, 1, 1, 1)) + + lora_state = {} + lora_state["scale_factor"] = reshaped_scale_factors + lora_state["lora_params"] = decoding_adapters_params + + if isinstance(decode_state, dict): + decode_state["lora_state"] = lora_state + else: # flax.struct.dataclass + decode_state = decode_state.replace(lora_state=lora_state) + # Now we actually take a generate step on requests in the slots. 
decode_state, sampled_tokens = generate_engine.generate( generate_params, decode_state diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 733fbba9..28e7c16a 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -174,6 +174,7 @@ def create_driver( shared_adapterstore = [] if lora_input_adapters_path: + # TODO: Make hbm_memory_budget and cpu_memory_budget configurable for pe in engines.prefill_engines: prefill_adapterstore.append( adapterstore.AdapterTensorStore( @@ -181,9 +182,10 @@ def create_driver( adapters_dir_path=lora_input_adapters_path, hbm_memory_budget=20 * (1024**3), # 20 GB HBM cpu_memory_budget=100 * (1024**3), # 100 GB RAM + total_slots=pe.max_concurrent_decodes, ) ) - # TODO: Make hbm_memory_budget and cpu_memory_budget configurable + for ge in engines.generate_engines: generate_adapterstore.append( adapterstore.AdapterTensorStore( @@ -191,6 +193,7 @@ def create_driver( adapters_dir_path=lora_input_adapters_path, hbm_memory_budget=20 * (1024**3), # 20 GB HBM cpu_memory_budget=100 * (1024**3), # 100 GB RAM + total_slots=ge.max_concurrent_decodes, ) ) @@ -201,6 +204,7 @@ def create_driver( adapters_dir_path=lora_input_adapters_path, hbm_memory_budget=20 * (1024**3), # 20 GB HBM cpu_memory_budget=100 * (1024**3), # 100 GB RAM + total_slots=ie.max_concurrent_decodes, ) ) @@ -315,6 +319,9 @@ def run( "Not starting Prometheus server: --prometheus_port flag not set" ) + if multi_sampling and lora_input_adapters_path: + raise ValueError("LoRA adapters is not enabled for multi_sampling mode.") + driver = create_driver( config, devices, diff --git a/jetstream/engine/mock_engine.py b/jetstream/engine/mock_engine.py index fb37f293..7466659f 100644 --- a/jetstream/engine/mock_engine.py +++ b/jetstream/engine/mock_engine.py @@ -32,7 +32,7 @@ import functools from dataclasses import asdict -from typing import Any, Callable, Optional, Tuple +from typing import Any, Dict, Callable, Optional, Tuple import jax import 
jax.numpy as jnp @@ -71,6 +71,7 @@ class DecodeState: generate_cache_index: int generate_lengths: jax.Array generate_tokens: jax.Array + lora_state: Optional[Dict[str, Any]] = None class TestEngine(engine_api.Engine): @@ -509,6 +510,7 @@ def init_decode_state(self) -> DecodeState: generate_tokens=jnp.zeros( (self.generate_cache_batch, 1), dtype=jnp.int32 ), + lora_state={}, ) @property diff --git a/jetstream/tests/core/lora/test_adapter_tensorstore.py b/jetstream/tests/core/lora/test_adapter_tensorstore.py index b48a8360..2115d77a 100644 --- a/jetstream/tests/core/lora/test_adapter_tensorstore.py +++ b/jetstream/tests/core/lora/test_adapter_tensorstore.py @@ -145,6 +145,7 @@ async def asyncSetUp(self): adapters_dir_path=self.adapters_dir_path, hbm_memory_budget=self.hbm_budget, cpu_memory_budget=self.cpu_budget, + total_slots=8, ) # Pre-register adapters for most tests to simplify setup diff --git a/jetstream/tests/core/test_orchestrator.py b/jetstream/tests/core/test_orchestrator.py index 742fc300..7e2fd47e 100644 --- a/jetstream/tests/core/test_orchestrator.py +++ b/jetstream/tests/core/test_orchestrator.py @@ -123,6 +123,7 @@ async def _setup_driver_with_adapterstore( adapters_dir_path="/tmp/", hbm_memory_budget=20 * (1024**3), # 20 GB HBM cpu_memory_budget=100 * (1024**3), # 100 GB RAM + total_slots=8, ) generate_adapterstore = adapterstore.AdapterTensorStore( @@ -130,6 +131,7 @@ async def _setup_driver_with_adapterstore( adapters_dir_path="/tmp/", hbm_memory_budget=20 * (1024**3), # 20 GB HBM cpu_memory_budget=100 * (1024**3), # 100 GB RAM + total_slots=8, ) await prefill_adapterstore.register_adapter(