Skip to content

Commit e4d22bf

Browse files
committed
JetStream changes for Jax based implementation of unified_lora_params for decoding batch of multiple different lora adapters.
1 parent a6a5cd1 commit e4d22bf

3 files changed

Lines changed: 87 additions & 6 deletions

File tree

jetstream/core/lora/adapter_tensorstore.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class AdapterMetadata:
6666

6767

6868
class AdapterTensorStore:
69-
def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int):
69+
def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int, total_slots: int):
7070
self.hbm_memory_budget = hbm_memory_budget
7171
self.cpu_memory_budget = cpu_memory_budget
7272
self.adapter_registry: Dict[str, AdapterMetadata] = {} # All known adapters
@@ -75,6 +75,8 @@ def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int):
7575
self.current_hbm_usage: int = 0
7676
self.current_cpu_usage: int = 0
7777
self.running_requests: int = 0 # Number of async tasks which are in "loading" state
78+
self.decoding_adapters_cache: Dict[str, Any] = {}
79+
self.total_slots = total_slots
7880
self.lock = asyncio.Lock() # Use an asyncio Lock for thread safety
7981

8082

@@ -145,6 +147,76 @@ async def _transfer_to_cpu(self, adapter_id: str):
145147
metadata.last_accessed = time.time()
146148

147149

150+
def _initialize_decoding_adapters_cache(self, adapter_weights):
151+
"""
152+
Create a new PyTree with zero tensors at the paths corresponding to non-None leaves
153+
in the input PyTree. The zero tensors have an added dimension of size `self.totol_slots`.
154+
155+
Args:
156+
adatper_weights: The input PyTree, whose structure will be mirrored.
157+
158+
Returns:
159+
A new PyTree with zero Tensors or None values, mirroring the structure of the input PyTree.
160+
"""
161+
def create_zero_leaf(leaf):
162+
if leaf is not None:
163+
original_shape = leaf.shape
164+
if not original_shape: # handle scalar case
165+
zero_tensor_shape = (self.total_slots,)
166+
else:
167+
zero_tensor_shape = (self.total_slots,) + original_shape # Prepend a new dimension
168+
169+
return jnp.zeros(zero_tensor_shape, dtype=leaf.dtype)
170+
else:
171+
return None # Maintain None structure for None leaves
172+
173+
return jax.tree_util.tree_map(create_zero_leaf, adapter_weights)
174+
175+
176+
def insert_adapter_in_cache(self, adapter_id: str, slot_id: int):
177+
"""
178+
Insert the specific adapter tensors into a slot in the serving_adapters_cache.
179+
180+
Args:
181+
adapter_id: The id of the adapter, whose tensors will be inserted
182+
slot_id: The id of slot, which represents the index in the serving_adapter_cache
183+
where the adapter tensors will be inserted.
184+
"""
185+
186+
def insert_leaf(dest_leaf, source_leaf):
187+
if dest_leaf is not None and source_leaf is not None:
188+
return dest_leaf.at[slot_id].set(source_leaf) # Insert at the specific index
189+
elif dest_leaf is not None:
190+
return dest_leaf # If source_leaf is None, keep the zero_leaf as is
191+
elif source_leaf is not None: # In this case the adapters have different target modules
192+
original_shape = source_leaf.shape
193+
if not original_shape: # Handle scalar case
194+
zero_tensor_shape = (self.total_slots,)
195+
else:
196+
zero_tensor_shape = (self.total_slots,) + original_shape
197+
new_dest_leaf = jnp.zeros(zero_tensor_shape, dtype=source_leaf.dtype)
198+
return new_dest_leaf.at[slot_id].set(source_leaf)
199+
else:
200+
return None # If both are None, return None
201+
202+
if adapter_id == "":
203+
logging.info("Empty adapter id. So no LoRA tensors inserted into the cache in adapter_tensorStore.")
204+
return
205+
206+
metadata = self.adapter_registry[adapter_id]
207+
208+
asyncio.run(self.load_adapter(adapter_id, True))
209+
210+
adapter_weights = self.loaded_adapters_hbm[adapter_id]
211+
212+
if not self.decoding_adapters_cache:
213+
self.decoding_adapters_cache = self._initialize_decoding_adapters_cache(adapter_weights)
214+
215+
self.decoding_adapters_cache = jax.tree_util.tree_map(insert_leaf,
216+
self.decoding_adapters_cache,
217+
adapter_weights)
218+
219+
148220
async def get_hbm_loaded_adapters(self):
149221
"""Returns a comma separated list of adapters loaded into HBM."""
150222

jetstream/core/orchestrator.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -281,10 +281,6 @@ def __init__(
281281
if generate_params is None:
282282
raise ValueError("No generate parameter provided.")
283283

284-
self._adapter_tensorstore = adapter_tensorstore.AdapterTensorStore(
285-
hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM
286-
cpu_memory_budget=(100 * (1024 ** 3))) # 100 GB RAM
287-
288284
logger.info(
289285
"Initializing the driver with %d prefill engines and %d "
290286
"generate engines in %s mode",
@@ -301,6 +297,15 @@ def __init__(
301297
self._metrics_collector = metrics_collector
302298
self._multi_sampling = multi_sampling
303299

300+
total_slots = 0
301+
for engine in self._generate_engines:
302+
total_slots += engine.max_concurrent_decodes
303+
304+
self._adapter_tensorstore = adapter_tensorstore.AdapterTensorStore(
305+
hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM
306+
cpu_memory_budget=(100 * (1024 ** 3)), # 100 GB RAM
307+
total_slots=total_slots)
308+
304309
# Stages 1-4 represent the life cycle of a request.
305310
# Stage 1
306311
# At first, a request is placed here in order to get prefilled.
@@ -930,6 +935,9 @@ def _insert_if_possible(
930935
slot=slot,
931936
#request_id=new_request.request_id,
932937
)
938+
939+
self._adapter_tensorstore.insert_adapter_in_cache(new_request.adapter_id, slot)
940+
933941
ThreadDebugLog(
934942
thread_name,
935943
f"Generate slice {idx} filled slot {slot} at step "
@@ -1136,7 +1144,7 @@ def _generate_thread(self, idx: int):
11361144

11371145
# Now we actually take a generate step on requests in the slots.
11381146
decode_state, sampled_tokens = generate_engine.generate(
1139-
generate_params, decode_state
1147+
generate_params, decode_state, self._adapter_tensorstore.decoding_adapters_cache,
11401148
)
11411149
sampled_tokens.copy_to_host_async()
11421150
# Respond to detokenization backpressure.

jetstream/engine/engine_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def generate(
197197
params: Params,
198198
decode_state: DecodeState,
199199
sampler: Optional[Callable[[Any], Any]] = None,
200+
lora_params: Params = None,
200201
) -> Tuple[DecodeState, ResultTokens]:
201202
"""Generates tokens for each sequence being decoded in parallel.
202203

0 commit comments

Comments
 (0)