Skip to content

Commit 3ff6383

Browse files
committed
After merging with main, some of the code was overwritten. This commit adds back the code that supports serving multiple LoRA adapters in the same batch.
1 parent a38b686 commit 3ff6383

4 files changed

Lines changed: 78 additions & 5 deletions

File tree

jetstream/core/lora/adapter_tensorstore.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __init__(
9999
adapters_dir_path: str,
100100
hbm_memory_budget: int,
101101
cpu_memory_budget: int,
102+
total_slots: int,
102103
):
103104
"""Initializes the AdapterTensorStore."""
104105
self.engine = engine # Possibly MaxEngine object
@@ -119,6 +120,8 @@ def __init__(
119120
self.running_requests: int = (
120121
0 # Number of async tasks which are in "loading" state
121122
)
123+
self.decoding_adapters_cache: Dict[str, Any] = {}
124+
self.total_slots = total_slots
122125
self.lock = asyncio.Lock() # Use an asyncio Lock for thread safety
123126

124127
# --- Unsafe Internal methods which assumes that lock is held ---
@@ -207,6 +210,70 @@ def _unsafe_unload_adapter(self, adapter_id: str):
207210
metadata.size_hbm = 0
208211
metadata.size_cpu = 0
209212

213+
def _initialize_decoding_adapters_cache(self, adapter_weights):
214+
"""
215+
Create a new PyTree with zero tensors at the paths corresponding to non-None leaves
216+
in the input PyTree. The zero tensors have an added dimension of size `self.totol_slots`.
217+
Args:
218+
adatper_weights: The input PyTree, whose structure will be mirrored.
219+
Returns:
220+
A new PyTree with zero Tensors or None values, mirroring the structure of the input PyTree.
221+
"""
222+
def create_zero_leaf(leaf):
223+
if leaf is not None:
224+
original_shape = leaf.shape
225+
if not original_shape: # handle scalar case
226+
zero_tensor_shape = (self.total_slots,)
227+
else:
228+
zero_tensor_shape = (self.total_slots,) + original_shape # Prepend a new dimension
229+
230+
return jnp.zeros(zero_tensor_shape, dtype=leaf.dtype)
231+
else:
232+
return None # Maintain None structure for None leaves
233+
234+
return jax.tree_util.tree_map(create_zero_leaf, adapter_weights)
235+
236+
237+
def insert_adapter_in_cache(self, adapter_id: str, slot_id: int):
238+
"""
239+
Insert the specific adapter tensors into a slot in the serving_adapters_cache.
240+
Args:
241+
adapter_id: The id of the adapter, whose tensors will be inserted
242+
slot_id: The id of slot, which represents the index in the serving_adapter_cache
243+
where the adapter tensors will be inserted.
244+
"""
245+
246+
def insert_leaf(dest_leaf, source_leaf):
247+
if dest_leaf is not None and source_leaf is not None:
248+
return dest_leaf.at[slot_id].set(source_leaf) # Insert at the specific index
249+
elif dest_leaf is not None:
250+
return dest_leaf # If source_leaf is None, keep the zero_leaf as is
251+
elif source_leaf is not None: # In this case the adapters have different target modules
252+
original_shape = source_leaf.shape
253+
if not original_shape: # Handle scalar case
254+
zero_tensor_shape = (self.total_slots,)
255+
else:
256+
zero_tensor_shape = (self.total_slots,) + original_shape
257+
new_dest_leaf = jnp.zeros(zero_tensor_shape, dtype=source_leaf.dtype)
258+
return new_dest_leaf.at[slot_id].set(source_leaf)
259+
else:
260+
return None # If both are None, return None
261+
262+
if adapter_id == "":
263+
logging.info("Empty adapter id. So no LoRA tensors inserted into the cache in adapter_tensorStore.")
264+
return
265+
266+
asyncio.run(self.load_adapter(adapter_id, None, True))
267+
268+
adapter_weights = self.loaded_adapters_hbm[adapter_id]
269+
270+
if not self.decoding_adapters_cache:
271+
self.decoding_adapters_cache = self._initialize_decoding_adapters_cache(adapter_weights)
272+
273+
self.decoding_adapters_cache = jax.tree_util.tree_map(insert_leaf,
274+
self.decoding_adapters_cache,
275+
adapter_weights)
276+
210277
# --- Public Methods (Acquire lock, then call unsafe methods) ---
211278

212279
async def register_adapter(

jetstream/core/orchestrator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,7 +1031,7 @@ def _insert_if_possible(
10311031
# we can still generate if we can't insert. We do this in a while loop to
10321032
# insert as many sequences as possible.
10331033
adapter_tensorstore = None
1034-
if self._generate_adapterstore:
1034+
if self._generate_adapterstore and idx < len(self._generate_adapterstore):
10351035
adapter_tensorstore = self._generate_adapterstore[idx]
10361036

10371037
while True:
@@ -1102,7 +1102,6 @@ def _insert_if_possible(
11021102
new_request.prefill_result,
11031103
decode_state,
11041104
slot=slot,
1105-
# request_id=new_request.request_id,
11061105
)
11071106

11081107
if adapter_tensorstore:
@@ -1321,10 +1320,11 @@ def _generate_thread(self, idx: int):
13211320

13221321
if adapter_tensorstore:
13231322
decoding_adapters_cache = adapter_tensorstore.decoding_adapters_cache
1323+
#decode_state["lora_adapter_cache"] = decoding_adapters_cache
13241324

13251325
# Now we actually take a generate step on requests in the slots.
13261326
decode_state, sampled_tokens = generate_engine.generate(
1327-
generate_params, decode_state, decoding_adapters_cache
1327+
generate_params, decode_state
13281328
)
13291329
sampled_tokens.copy_to_host_async()
13301330
# Respond to detokenization backpressure.

jetstream/core/server_lib.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,23 +174,26 @@ def create_driver(
174174
shared_adapterstore = []
175175

176176
if lora_input_adapters_path:
177+
# TODO: Make hbm_memory_budget and cpu_memory_budget configurable
177178
for pe in engines.prefill_engines:
178179
prefill_adapterstore.append(
179180
adapterstore.AdapterTensorStore(
180181
engine=pe,
181182
adapters_dir_path=lora_input_adapters_path,
182183
hbm_memory_budget=20 * (1024**3), # 20 GB HBM
183184
cpu_memory_budget=100 * (1024**3), # 100 GB RAM
185+
total_slots=pe.max_concurrent_decodes,
184186
)
185187
)
186-
# TODO: Make hbm_memory_budget and cpu_memory_budget configurable
188+
187189
for ge in engines.generate_engines:
188190
generate_adapterstore.append(
189191
adapterstore.AdapterTensorStore(
190192
engine=ge,
191193
adapters_dir_path=lora_input_adapters_path,
192194
hbm_memory_budget=20 * (1024**3), # 20 GB HBM
193195
cpu_memory_budget=100 * (1024**3), # 100 GB RAM
196+
total_slots=ge.max_concurrent_decodes,
194197
)
195198
)
196199

@@ -201,6 +204,7 @@ def create_driver(
201204
adapters_dir_path=lora_input_adapters_path,
202205
hbm_memory_budget=20 * (1024**3), # 20 GB HBM
203206
cpu_memory_budget=100 * (1024**3), # 100 GB RAM
207+
total_slots=ie.max_concurrent_decodes,
204208
)
205209
)
206210

@@ -315,6 +319,9 @@ def run(
315319
"Not starting Prometheus server: --prometheus_port flag not set"
316320
)
317321

322+
if multi_sampling and lora_input_adapters_path:
323+
raise ValueError("LoRA adapters is not enabled for multi_sampling mode.")
324+
318325
driver = create_driver(
319326
config,
320327
devices,

jetstream/engine/engine_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,6 @@ def generate(
211211
params: Params,
212212
decode_state: DecodeState,
213213
sampler: Optional[Callable[[Any], Any]] = None,
214-
lora_params: Params = None,
215214
) -> Tuple[DecodeState, ResultTokens]:
216215
"""Generates tokens for each sequence being decoded in parallel.
217216

0 commit comments

Comments
 (0)