Skip to content

Commit bd67171

Browse files
committed
Adding documentations.
1 parent a6a5cd1 commit bd67171

3 files changed

Lines changed: 79 additions & 10 deletions

File tree

jetstream/core/lora/adapter_tensorstore.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,29 @@ class AdapterMetadata:
6666

6767

6868
class AdapterTensorStore:
69+
"""
70+
Manages the storage and retrieval of LoRA adapter weights, handling
71+
placement in either HBM (High Bandwidth Memory, on the TPU/GPU) or CPU RAM.
72+
73+
This class implements an LRU (Least Recently Used) eviction policy
74+
to manage memory usage. It supports asynchronous loading and unloading
75+
of adapters to avoid blocking the main inference thread.
76+
77+
Args:
78+
hbm_memory_budget (int): The maximum amount of HBM (in bytes) to use for
79+
storing LoRA adapter weights.
80+
cpu_memory_budget (int): The maximum amount of CPU RAM (in bytes) to use
81+
for storing LoRA adapter weights.
82+
"""
83+
84+
6985
def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int):
86+
"""Initializes the AdapterTensorStore."""
7087
self.hbm_memory_budget = hbm_memory_budget
7188
self.cpu_memory_budget = cpu_memory_budget
7289
self.adapter_registry: Dict[str, AdapterMetadata] = {} # All known adapters
73-
self.loaded_adapters_hbm: Dict[str, jnp.ndarray] = {} # adapter_id -> Unified LoRA params (in HBM)
74-
self.loaded_adapters_cpu: Dict[str, np.ndarray] = {} # adapter_id -> Unified LoRA params (in CPU RAM)
90+
self.loaded_adapters_hbm: Dict[str, jnp.ndarray] = {} # adapter_id -> LoRA params (in HBM)
91+
self.loaded_adapters_cpu: Dict[str, np.ndarray] = {} # adapter_id -> LoRA params (in CPU RAM)
7592
self.current_hbm_usage: int = 0
7693
self.current_cpu_usage: int = 0
7794
self.running_requests: int = 0 # Number of async tasks which are in "loading" state
@@ -80,6 +97,18 @@ def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int):
8097

8198
def register_adapter(self, adapter_id: str, adapter_path: str, config: Dict[str, Any]):
8299
"""Registers a new LoRA adatper."""
100+
"""
101+
Registers a LoRA adapter with the TensorStore. This does *not* load
102+
the adapter; it simply adds metadata about the adapter to the registry.
103+
104+
Args:
105+
adapter_id (str): A unique identifier for the adapter.
106+
adapter_path (str): The path to the adapter weights (file or directory).
107+
config (dict): Config of the loRA adapter.
108+
109+
Raises:
110+
ValueError: If an adapter with the same ID is already registered.
111+
"""
83112
if adapter_id in self.adapter_registry:
84113
raise ValueError(f"Adapter with ID '{adapter_id}' already registered.")
85114
self.adapter_registry[adapter_id] = AdapterMetadata(
@@ -162,17 +191,36 @@ async def load_adapter(
162191
self,
163192
adapter_id: str,
164193
adapter_weights = None,
165-
to_hbm: bool = True,
166-
force_load: bool = False):
167-
"""Loads a LoRA adapter's weights, managing HBM and CPU memory."""
194+
to_hbm: bool = True):
195+
"""
196+
Loads a LoRA adapter's weights into memory (either HBM or CPU RAM).
197+
198+
This method is asynchronous to avoid blocking the main thread during
199+
potentially slow I/O operations. It handles:
200+
- Checking if the adapter is already loaded.
201+
- Checking if there's enough memory (and evicting if necessary).
202+
- Loading the weights (in a separate thread).
203+
- Updating the adapter's status and metadata.
204+
205+
Args:
206+
adapter_id (str): The ID of the adapter to load.
207+
adapter_weights: In the form of a PyTree.
208+
to_hbm (bool): Whether to load the adapter into HBM (True) or
209+
CPU RAM (False). Defaults to True (HBM).
210+
211+
Raises:
212+
ValueError: If the adapter ID is not registered.
213+
RuntimeError: If there is not enough memory to load the adapter,
214+
and eviction fails to free up enough space.
215+
"""
168216

169217
if adapter_id not in self.adapter_registry:
170218
raise ValueError(f"Adapter with ID '{adapter_id}' not registered.")
171219

172220
metadata = self.adapter_registry[adapter_id]
173221

174222
async with self.lock: # Acquire lock for thread safety
175-
if not force_load and metadata.status in ("loaded_hbm", "loaded_cpu"):
223+
if metadata.status in ("loaded_hbm", "loaded_cpu"):
176224
metadata.last_accessed = time.time()
177225

178226
# if already loaded in HBM and we want HBM, or
@@ -195,7 +243,7 @@ async def load_adapter(
195243
await asyncio.sleep(0.1) # Short sleep to avoid busy-waiting
196244

197245
# Make recursive call to load_adapter to copy to device
198-
await self.load_adapter(adapter_id, adapter_weights, to_hbm, force_load)
246+
await self.load_adapter(adapter_id, adapter_weights, to_hbm)
199247
return
200248

201249
metadata.status = "loading"

jetstream/core/lora/multi_lora_inference_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232

3333
class MultiLoraManager(multi_lora_decoding_pb2_grpc.v1Servicer):
34-
"""Manages the parameters of multiple lora requests and their lifelines."""
34+
"""Manages the parameters of multiple lora requests and their status/lifetimes."""
3535

3636
_driver: orchestrator.Driver
3737

jetstream/tools/multi_adapter_service_client.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""A test request."""
15+
"""A gRPC client to interact with JetStream Server."""
1616

1717
from typing import Sequence
1818

@@ -55,7 +55,28 @@
5555

5656

5757
def main(argv: Sequence[str]) -> None:
58-
del argv
58+
"""
59+
Main function for a gRPC client that interacts with a JetStream server.
60+
61+
This client can:
62+
- Load a LoRA adapter.
63+
- Unload a LoRA adapter.
64+
- List loaded adapters and their metadata.
65+
- Generate text completions (using LoRA adapters if specified).
66+
67+
The client uses command-line flags to specify the server address, port,
68+
text input, maximum number of tokens, adapter ID, adapter path, and the
69+
API to call. It uses insecure gRPC channels (suitable for local testing).
70+
71+
Args:
72+
argv: Command-line arguments (not used directly, flags are used instead).
73+
74+
Raises:
75+
ValueError: For invalid configurations, like missing required parameters
76+
for specific API calls.
77+
"""
78+
79+
del argv # Unused
5980
# Note: Uses insecure_channel only for local testing. Please add grpc
6081
# credentials for Production.
6182
address = f"{_SERVER.value}:{_PORT.value}"

0 commit comments

Comments
 (0)