Skip to content

Commit 5f679a9

Browse files
committed
- Created separate adapter_tensorstore for each engine.
- Implemented unapply LoRA from base_params.
- Fixed some comments from the PR.
1 parent bd67171 commit 5f679a9

11 files changed

Lines changed: 292 additions & 393 deletions

jetstream/core/lora/adapter_tensorstore.py

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import functools
2727
from typing import Dict, Optional, Any
2828
import numpy as np
29+
from jetstream.engine import engine_api
2930

3031

3132
def _get_size_of_pytree(params):
@@ -82,8 +83,14 @@ class AdapterTensorStore:
8283
"""
8384

8485

85-
def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int):
86+
def __init__(self,
87+
engine: engine_api.Engine,
88+
adapters_dir_path: str,
89+
hbm_memory_budget: int,
90+
cpu_memory_budget: int):
8691
"""Initializes the AdapterTensorStore."""
92+
self.engine = engine # Possibly MaxEngine object
93+
self.adapters_dir_path = adapters_dir_path.rstrip("/") # All Adapters path without trailing `/`
8794
self.hbm_memory_budget = hbm_memory_budget
8895
self.cpu_memory_budget = cpu_memory_budget
8996
self.adapter_registry: Dict[str, AdapterMetadata] = {} # All known adapters
@@ -95,26 +102,49 @@ def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int):
95102
self.lock = asyncio.Lock() # Use an asyncio Lock for thread safety
96103

97104

98-
def register_adapter(self, adapter_id: str, adapter_path: str, config: Dict[str, Any]):
105+
def register_adapter(self,
106+
adapter_id: str,
107+
adapter_path: str = None,
108+
adapter_config: Dict[str, Any] = None):
99109
"""Registers a new LoRA adapter."""
100110
"""
101-
Registers a LoRA adapter with the TensorStore. This does *not* load
102-
the adapter; it simply adds metadata about the adapter to the registry.
111+
Registers a LoRA adapter with the TensorStore. If called without
112+
adapter_config, this call also loads the adapter: the config then has to
113+
come from the engine's load_single_adapter() call, which also returns the
114+
adapter_params, so it is beneficial to load the adapter to HBM right away.
115+
This call path is expected only from a direct inference
116+
request.
117+
Otherwise, it simply adds metadata about the adapter to the registry.
103118
104119
Args:
105120
adapter_id (str): A unique identifier for the adapter.
106121
adapter_path (str): The path to the adapter weights (file or directory).
107-
config (dict): Config of the loRA adapter.
122+
adapter_config (dict): Config of the LoRA adapter.
108123
109124
Raises:
110125
ValueError: If an adapter with the same ID is already registered.
111126
"""
112127
if adapter_id in self.adapter_registry:
113-
raise ValueError(f"Adapter with ID '{adapter_id}' already registered.")
128+
logging.warning(f"Adapter with ID '{adapter_id}' already registered.")
129+
return
130+
131+
if adapter_path is None:
132+
adapter_path = f"{self.adapters_dir_path}/{adapter_id}"
133+
134+
adapter_params = None
135+
if adapter_config is None:
136+
adapter_params, adapter_config = self.engine.load_single_adapter(adapter_path)
137+
138+
if adapter_config is None:
139+
raise ValueError(f"Failed to read adapter_config from {adapter_path}")
140+
114141
self.adapter_registry[adapter_id] = AdapterMetadata(
115142
adapter_id=adapter_id,
116143
adapter_path=adapter_path,
117-
config=config)
144+
config=adapter_config)
145+
146+
if adapter_params is not None:
147+
asyncio.run(self.load_adapter(adapter_id, adapter_params, True))
118148

119149

120150
async def _transfer_to_hbm(self, adapter_id: str):
@@ -254,7 +284,10 @@ async def load_adapter(
254284

255285
try:
256286
if adapter_weights is None:
257-
raise ValueError("Adapter weights for adapter_id={adapter_id} is None.")
287+
adapter_weights, adapter_config = self.engine.load_single_adapter(adapter_path)
288+
289+
if adapter_weights is None:
290+
raise ValueError("Failed to load adapter_weights from {adapter_path}.")
258291

259292
async with self.lock: # Critical section for memory management
260293
adapter_weights_as_jnp_array = _as_jnp_array(adapter_weights)
@@ -303,21 +336,36 @@ async def load_adapter(
303336
self.running_requests -= 1
304337

305338

306-
def get_lora_config(self, adapter_id):
339+
def get_lora_config(self, adapter_id: str, load_if_not_loaded: bool = False):
307340
"""Getter for the LoRA adapter config."""
308341
metadata = self.adapter_registry.get(adapter_id)
342+
343+
if load_if_not_loaded and metadata is None:
344+
self.register_adapter(adapter_id)
345+
metadata = self.adapter_registry.get(adapter_id)
346+
347+
if metadata is None:
348+
raise ValueError(f"LoRA adapter with id={adapter_id} is not loaded.")
349+
309350
return metadata.config
310351

311352

312-
def get_lora_weights(self, adapter_id, to_hbm: bool = True):
353+
def get_lora_weights(self,
354+
adapter_id,
355+
to_hbm: bool = True,
356+
load_if_not_loaded: bool = False):
313357
"""Retrieves the unified LoRA parameters for the given adapter IDs.
314358
Handles HBM/CPU placement.
315359
"""
316360

317361
metadata = self.adapter_registry.get(adapter_id)
318362

363+
if load_if_not_loaded and metadata is None:
364+
self.register_adapter(adapter_id)
365+
metadata = self.adapter_registry.get(adapter_id)
366+
319367
if metadata is None:
320-
raise ValueError(f"Adapter with ID '{adapter_id}' not registered.")
368+
raise ValueError(f"LoRA adapter with id={adapter_id} is not loaded.")
321369

322370
if metadata.status != "loaded_hbm" and metadata.status != "loaded_cpu":
323371
asyncio.run(self.load_adapter(adapter_id, None, to_hbm)) # Start loading (async)

jetstream/core/lora/multi_lora_inference_api.py

Lines changed: 1 addition & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,11 @@
1717

1818
import logging
1919
import grpc
20-
import time
21-
import uuid
2220

23-
from typing import Any, AsyncIterator, Optional, Tuple, cast
21+
from typing import Optional
2422
from jetstream.core import orchestrator
25-
from jetstream.core.lora import adapter_tensorstore
2623
from jetstream.core.proto import multi_lora_decoding_pb2_grpc
2724
from jetstream.core.proto import multi_lora_decoding_pb2
28-
from jetstream.core.utils import async_multifuture
29-
from jetstream.core.utils.return_sample import ReturnSample
30-
from jetstream.engine import engine_api, tokenizer_api, token_utils
3125

3226

3327
class MultiLoraManager(multi_lora_decoding_pb2_grpc.v1Servicer):
@@ -105,155 +99,3 @@ def unload_lora_adapter(
10599
logging.info(f"Loading of adapter_id={request.adapter_id} failed with error: {str(e)}")
106100
return multi_lora_decoding_pb2.UnloadAdapterResponse(success=False, error_message=str(e))
107101

108-
109-
def _get_prefill_content(
110-
self, request: multi_lora_decoding_pb2.CompletionRequest
111-
) -> Tuple[str | list[int], bool]:
112-
which_content = request.WhichOneof("content")
113-
content = getattr(request, which_content)
114-
if which_content == "text_content":
115-
return cast(multi_lora_decoding_pb2.CompletionRequest.TextContent, content).text, False
116-
else:
117-
return (
118-
list(
119-
cast(multi_lora_decoding_pb2.CompletionRequest.TokenContent, content).token_ids
120-
),
121-
True,
122-
)
123-
124-
def process_client_side_tokenization_response(self, response: Any):
125-
samples = []
126-
for sample in response:
127-
samples.append(
128-
multi_lora_decoding_pb2.CompletionResponse.StreamContent.Sample(
129-
token_ids=sample.token_ids,
130-
)
131-
)
132-
return multi_lora_decoding_pb2.CompletionResponse(
133-
stream_content=multi_lora_decoding_pb2.CompletionResponse.StreamContent(
134-
samples=samples
135-
)
136-
)
137-
138-
def should_buffer_response(self, response: Any) -> bool:
139-
for item in response:
140-
if item.text and token_utils.is_byte_token(item.text[-1]):
141-
# If any sample ends in bytes, this means we might still need to
142-
# decode more bytes to compose the string.
143-
return True
144-
145-
def process_server_side_tokenization_response(
146-
self, response: Any, buffered_response_list
147-
):
148-
# Flush the buffered responses to each sample of current response.
149-
current_response_with_flushed_buffer = list(
150-
zip(*buffered_response_list, response)
151-
)
152-
# Empty buffer: [[s0_cur], [s1_cur], ...]
153-
# Has buffer:
154-
# [[s0_b0, s0_b1, ..., s0_cur], [s1_b0, s1_b1, ..., s1_cur], ...]
155-
current_response_with_flushed_buffer = cast(
156-
list[list[ReturnSample]], current_response_with_flushed_buffer
157-
)
158-
# Form correct sample(s) and return as StreamContent for this iteration.
159-
samples = []
160-
for sample in current_response_with_flushed_buffer:
161-
text = []
162-
token_ids = []
163-
for resp in sample:
164-
text.extend(resp.text)
165-
token_ids.extend(resp.token_ids)
166-
samples.append(
167-
multi_lora_decoding_pb2.CompletionResponse.StreamContent.Sample(
168-
text=token_utils.text_tokens_to_str(text),
169-
token_ids=token_ids,
170-
)
171-
)
172-
return multi_lora_decoding_pb2.CompletionResponse(
173-
stream_content=multi_lora_decoding_pb2.CompletionResponse.StreamContent(
174-
samples=samples
175-
)
176-
)
177-
178-
async def completions( # pylint: disable=invalid-overridden-method
179-
self,
180-
request: multi_lora_decoding_pb2.CompletionRequest,
181-
context: Optional[grpc.aio.ServicerContext] = None,
182-
) -> AsyncIterator[multi_lora_decoding_pb2.CompletionResponse]:
183-
184-
"""Decode."""
185-
if context is None:
186-
logging.warning(
187-
"LLM orchestrator is being used in offline test mode, and will not"
188-
" respond to gRPC queries - only direct function calls."
189-
)
190-
is_client_side_tokenization = False
191-
return_channel = async_multifuture.AsyncMultifuture()
192-
if context:
193-
context.add_done_callback(return_channel.cancel)
194-
195-
prefill_content, is_client_side_tokenization = self._get_prefill_content(
196-
request
197-
)
198-
199-
# Wrap request as an ActiveRequest.
200-
active_request = orchestrator.ActiveRequest(
201-
request_id=uuid.uuid4(),
202-
max_tokens=request.max_tokens,
203-
prefill_content=prefill_content,
204-
is_client_side_tokenization=is_client_side_tokenization,
205-
return_channel=return_channel,
206-
adapter_id=request.adapter_id,
207-
metadata=orchestrator.ActiveRequestMetadata(
208-
start_time=request.metadata.start_time,
209-
prefill_enqueue_time=time.perf_counter(),
210-
),
211-
)
212-
# The first stage is being prefilled, all other stages are handled
213-
# inside the driver (transfer, generate*N, detokenize).
214-
try:
215-
self._driver.place_request_on_prefill_queue(active_request)
216-
except queue.Full:
217-
# Safely abort the gRPC server thread with a retriable error.
218-
await _abort_or_raise(
219-
context=context,
220-
code=grpc.StatusCode.RESOURCE_EXHAUSTED,
221-
details=(
222-
"The driver prefill queue is full and more requests cannot be"
223-
" handled. You may retry this request."
224-
),
225-
)
226-
logging.info(
227-
"Placed request on the prefill queue.",
228-
)
229-
# When an active request is created a queue is instantiated. New tokens
230-
# are placed there during the decoding loop, we pop from that queue by
231-
# using the .next method on the active request.
232-
# Yielding allows for the response to be a streaming grpc call - which
233-
# can be called via iterating over a for loop on the client side.
234-
# The DecodeResponse stream should consume all generated tokens in
235-
# return_channel when complete signal is received (AsyncMultifuture
236-
# promises this).
237-
buffered_response_list = []
238-
async for response in active_request.return_channel:
239-
response = cast(list[ReturnSample], response)
240-
if is_client_side_tokenization:
241-
# If is_client_side_tokenization, the client should request with token
242-
# ids, and the JetStream server will return token ids as response.
243-
# The client should take care of tokenization and detokenization.
244-
yield self.process_client_side_tokenization_response(response)
245-
else:
246-
# Buffer response mechanism is used to handle streaming
247-
# detokenization with special character (For some edge cases with
248-
# SentencePiece tokenizer, it requires to decode a complete sequence
249-
# instead of a single token).
250-
if self.should_buffer_response(response):
251-
buffered_response_list.append(response)
252-
continue
253-
yield self.process_server_side_tokenization_response(
254-
response, buffered_response_list
255-
)
256-
# Reset buffer after flushed.
257-
buffered_response_list = []
258-
259-

0 commit comments

Comments
 (0)