Skip to content

Commit 1059978

Browse files
committed
Adding more doc strings.
2 parents e4d22bf + bd67171 commit 1059978

3 files changed

Lines changed: 82 additions & 10 deletions

File tree

jetstream/core/lora/adapter_tensorstore.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,31 @@ class AdapterMetadata:
6767

6868
class AdapterTensorStore:
6969
def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int, total_slots: int):
70+
"""
71+
Manages the storage and retrieval of LoRA adapter weights, handling
72+
placement in either HBM (High Bandwidth Memory, on the TPU/GPU) or CPU RAM.
73+
74+
This class implements an LRU (Least Recently Used) eviction policy
75+
to manage memory usage. It supports asynchronous loading and unloading
76+
of adapters to avoid blocking the main inference thread.
77+
78+
This class also creates a unified_lora_weights of all the adapters which is being
79+
used at any time for decoding purposes. These unified weights allow the backend
80+
model to serve multiple different LoRA adapters in a single batch.
81+
82+
Args:
83+
hbm_memory_budget (int): The maximum amount of HBM (in bytes) to use for
84+
storing LoRA adapter weights.
85+
cpu_memory_budget (int): The maximum amount of CPU RAM (in bytes) to use
86+
for storing LoRA adapter weights.
87+
total_slots (int): Number of generate slots. This is also equal to max_concurrent_decodes.
88+
"""
89+
7090
self.hbm_memory_budget = hbm_memory_budget
7191
self.cpu_memory_budget = cpu_memory_budget
7292
self.adapter_registry: Dict[str, AdapterMetadata] = {} # All known adapters
73-
self.loaded_adapters_hbm: Dict[str, jnp.ndarray] = {} # adapter_id -> Unified LoRA params (in HBM)
74-
self.loaded_adapters_cpu: Dict[str, np.ndarray] = {} # adapter_id -> Unified LoRA params (in CPU RAM)
93+
self.loaded_adapters_hbm: Dict[str, jnp.ndarray] = {} # adapter_id -> LoRA params (in HBM)
94+
self.loaded_adapters_cpu: Dict[str, np.ndarray] = {} # adapter_id -> LoRA params (in CPU RAM)
7595
self.current_hbm_usage: int = 0
7696
self.current_cpu_usage: int = 0
7797
self.running_requests: int = 0 # Number of async tasks which are in "loading" state
@@ -82,6 +102,18 @@ def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int, total_slots:
82102

83103
def register_adapter(self, adapter_id: str, adapter_path: str, config: Dict[str, Any]):
84104
"""Registers a new LoRA adapter."""
105+
"""
106+
Registers a LoRA adapter with the TensorStore. This does *not* load
107+
the adapter; it simply adds metadata about the adapter to the registry.
108+
109+
Args:
110+
adapter_id (str): A unique identifier for the adapter.
111+
adapter_path (str): The path to the adapter weights (file or directory).
112+
config (dict): Config of the LoRA adapter.
113+
114+
Raises:
115+
ValueError: If an adapter with the same ID is already registered.
116+
"""
85117
if adapter_id in self.adapter_registry:
86118
raise ValueError(f"Adapter with ID '{adapter_id}' already registered.")
87119
self.adapter_registry[adapter_id] = AdapterMetadata(
@@ -234,17 +266,36 @@ async def load_adapter(
234266
self,
235267
adapter_id: str,
236268
adapter_weights = None,
237-
to_hbm: bool = True,
238-
force_load: bool = False):
239-
"""Loads a LoRA adapter's weights, managing HBM and CPU memory."""
269+
to_hbm: bool = True):
270+
"""
271+
Loads a LoRA adapter's weights into memory (either HBM or CPU RAM).
272+
273+
This method is asynchronous to avoid blocking the main thread during
274+
potentially slow I/O operations. It handles:
275+
- Checking if the adapter is already loaded.
276+
- Checking if there's enough memory (and evicting if necessary).
277+
- Loading the weights (in a separate thread).
278+
- Updating the adapter's status and metadata.
279+
280+
Args:
281+
adapter_id (str): The ID of the adapter to load.
282+
adapter_weights: The adapter weights, in the form of a PyTree. Defaults to None.
283+
to_hbm (bool): Whether to load the adapter into HBM (True) or
284+
CPU RAM (False). Defaults to True (HBM).
285+
286+
Raises:
287+
ValueError: If the adapter ID is not registered.
288+
RuntimeError: If there is not enough memory to load the adapter,
289+
and eviction fails to free up enough space.
290+
"""
240291

241292
if adapter_id not in self.adapter_registry:
242293
raise ValueError(f"Adapter with ID '{adapter_id}' not registered.")
243294

244295
metadata = self.adapter_registry[adapter_id]
245296

246297
async with self.lock: # Acquire lock for thread safety
247-
if not force_load and metadata.status in ("loaded_hbm", "loaded_cpu"):
298+
if metadata.status in ("loaded_hbm", "loaded_cpu"):
248299
metadata.last_accessed = time.time()
249300

250301
# if already loaded in HBM and we want HBM, or
@@ -267,7 +318,7 @@ async def load_adapter(
267318
await asyncio.sleep(0.1) # Short sleep to avoid busy-waiting
268319

269320
# Make recursive call to load_adapter to copy to device
270-
await self.load_adapter(adapter_id, adapter_weights, to_hbm, force_load)
321+
await self.load_adapter(adapter_id, adapter_weights, to_hbm)
271322
return
272323

273324
metadata.status = "loading"

jetstream/core/lora/multi_lora_inference_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232

3333
class MultiLoraManager(multi_lora_decoding_pb2_grpc.v1Servicer):
34-
"""Manages the parameters of multiple lora requests and their lifelines."""
34+
"""Manages the parameters of multiple lora requests and their status/lifetimes."""
3535

3636
_driver: orchestrator.Driver
3737

jetstream/tools/multi_adapter_service_client.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""A test request."""
15+
"""A gRPC client to interact with JetStream Server."""
1616

1717
from typing import Sequence
1818

@@ -55,7 +55,28 @@
5555

5656

5757
def main(argv: Sequence[str]) -> None:
58-
del argv
58+
"""
59+
Main function for a gRPC client that interacts with a JetStream server.
60+
61+
This client can:
62+
- Load a LoRA adapter.
63+
- Unload a LoRA adapter.
64+
- List loaded adapters and their metadata.
65+
- Generate text completions (using LoRA adapters if specified).
66+
67+
The client uses command-line flags to specify the server address, port,
68+
text input, maximum number of tokens, adapter ID, adapter path, and the
69+
API to call. It uses insecure gRPC channels (suitable for local testing).
70+
71+
Args:
72+
argv: Command-line arguments (not used directly, flags are used instead).
73+
74+
Raises:
75+
ValueError: For invalid configurations, like missing required parameters
76+
for specific API calls.
77+
"""
78+
79+
del argv # Unused
5980
# Note: Uses insecure_channel only for local testing. Please add grpc
6081
# credentials for Production.
6182
address = f"{_SERVER.value}:{_PORT.value}"

0 commit comments

Comments
 (0)