Skip to content

Commit 2c2850b

Browse files
committed
Merge amangu-lora to amangu-lora-3 branch
2 parents 1059978 + f0da2b9 commit 2c2850b

15 files changed

Lines changed: 388 additions & 459 deletions

experimental/jax/inference/config/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
class ModelId:
2121
llama_2_7b_chat_hf = "meta-llama/Llama-2-7b-chat-hf"
22+
llama_2_70b_chat_hf = "meta-llama/Llama-2-70b-chat-hf"
2223

2324

2425
@dataclasses.dataclass
@@ -43,6 +44,15 @@ class Config:
4344
page_size=128,
4445
hbm_utilization=0.875,
4546
),
47+
ModelId.llama_2_70b_chat_hf: InferenceParams(
48+
model_id=ModelId.llama_2_70b_chat_hf,
49+
batch_size=100,
50+
max_seq_length=2048,
51+
max_input_length=1024,
52+
prefill_chunk_sizes=[128, 256, 512, 1024],
53+
page_size=128,
54+
hbm_utilization=0.875,
55+
),
4656
}
4757

4858
@classmethod

jetstream/core/lora/adapter_tensorstore.py

Lines changed: 96 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import functools
2727
from typing import Dict, Optional, Any
2828
import numpy as np
29+
from jetstream.engine import engine_api
30+
from enum import Enum
2931

3032

3133
def _get_size_of_pytree(params):
@@ -53,20 +55,25 @@ def convert_if_np(leaf):
5355

5456
return jax.tree_util.tree_map(convert_if_np, params)
5557

58+
class AdapterStatus(str, Enum):
59+
UNLOADED = "unloaded"
60+
LOADING = "loading"
61+
LOADED_HBM = "loaded_hbm"
62+
LOADED_CPU = "loaded_cpu"
63+
5664

5765
@dataclasses.dataclass
5866
class AdapterMetadata:
5967
adapter_id: str
6068
adapter_path: str
61-
status: str = "unloaded" # "loaded_hbm", "loaded_cpu", "loading", "unloading"
69+
status: AdapterStatus = AdapterStatus.UNLOADED
6270
size_hbm: int = 0 # Size in HBM (bytes)
6371
size_cpu: int = 0 # Size in CPU RAM (bytes)
6472
last_accessed: float = 0.0 # timestamp
6573
config: Dict[str, Any] = dataclasses.field(default_factory=dict)
6674

6775

6876
class AdapterTensorStore:
69-
def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int, total_slots: int):
7077
"""
7178
Manages the storage and retrieval of LoRA adapter weights, handling
7279
placement in either HBM (High Bandwidth Memory, on the TPU/GPU) or CPU RAM.
@@ -87,6 +94,14 @@ def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int, total_slots:
8794
total_slots: Number of generate slots. This is also equals to max_concurrent_decodes.
8895
"""
8996

97+
def __init__(self,
98+
engine: engine_api.Engine,
99+
adapters_dir_path: str,
100+
hbm_memory_budget: int,
101+
cpu_memory_budget: int):
102+
"""Initializes the AdapterTensorStore."""
103+
self.engine = engine # Possibly MaxEngine object
104+
self.adapters_dir_path = adapters_dir_path.rstrip("/") # All Adapters path without trailing `/`
90105
self.hbm_memory_budget = hbm_memory_budget
91106
self.cpu_memory_budget = cpu_memory_budget
92107
self.adapter_registry: Dict[str, AdapterMetadata] = {} # All known adapters
@@ -100,26 +115,49 @@ def __init__(self, hbm_memory_budget: int, cpu_memory_budget: int, total_slots:
100115
self.lock = asyncio.Lock() # Use an asyncio Lock for thread safety
101116

102117

103-
def register_adapter(self, adapter_id: str, adapter_path: str, config: Dict[str, Any]):
118+
def register_adapter(self,
119+
adapter_id: str,
120+
adapter_path: str | None = None,
121+
adapter_config: Dict[str, Any] | None = None):
104122
"""Registers a new LoRA adatper."""
105123
"""
106-
Registers a LoRA adapter with the TensorStore. This does *not* load
107-
the adapter; it simply adds metadata about the adapter to the registry.
124+
Registers a LoRA adapter with the TensorStore. This also loads the adapter;
125+
IF called without adapter_config. Because in this case, it needs
126+
to get adapter_config from the engine's load_single_adapter() call, which
127+
also provides the adapter_params. So in that case it is beneficial to load
128+
the adapter to HBM. This call path is expected only from the direct inference
129+
request.
130+
OTHERWISE, it simply adds metadata about the adapter to the registry.
108131
109132
Args:
110133
adapter_id (str): A unique identifier for the adapter.
111134
adapter_path (str): The path to the adapter weights (file or directory).
112-
config (dict): Config of the loRA adapter.
135+
adapter_config (dict): Config of the loRA adapter.
113136
114137
Raises:
115138
ValueError: If an adapter with the same ID is already registered.
116139
"""
117140
if adapter_id in self.adapter_registry:
118-
raise ValueError(f"Adapter with ID '{adapter_id}' already registered.")
141+
logging.warning(f"Adapter with ID '{adapter_id}' already registered.")
142+
return
143+
144+
if adapter_path is None:
145+
adapter_path = f"{self.adapters_dir_path}/{adapter_id}"
146+
147+
adapter_params = None
148+
if adapter_config is None:
149+
adapter_params, adapter_config = self.engine.load_single_adapter(adapter_path)
150+
151+
if adapter_config is None:
152+
raise ValueError(f"Failed to read adapter_config from {adapter_path}")
153+
119154
self.adapter_registry[adapter_id] = AdapterMetadata(
120155
adapter_id=adapter_id,
121156
adapter_path=adapter_path,
122-
config=config)
157+
config=adapter_config)
158+
159+
if adapter_params is not None:
160+
asyncio.run(self.load_adapter(adapter_id, adapter_params, True))
123161

124162

125163
async def _transfer_to_hbm(self, adapter_id: str):
@@ -130,7 +168,7 @@ async def _transfer_to_hbm(self, adapter_id: str):
130168
async with self.lock: #Acquire lock
131169
metadata = self.adapter_registry[adapter_id]
132170

133-
if metadata.status == "loaded_hbm":
171+
if metadata.status == AdapterStatus.LOADED_HBM:
134172
return
135173

136174
# Check if we have enough space in HBM; evict if necessary
@@ -147,7 +185,7 @@ async def _transfer_to_hbm(self, adapter_id: str):
147185
self.current_cpu_usage -= metadata.size_cpu
148186
self.current_hbm_usage += metadata.size_hbm
149187

150-
metadata.status = "loaded_hbm"
188+
metadata.status = AdapterStatus.LOADED_HBM
151189
metadata.last_accessed = time.time()
152190

153191

@@ -160,7 +198,7 @@ async def _transfer_to_cpu(self, adapter_id: str):
160198
async with self.lock:
161199
metadata = self. adapter_registry[adapter_id]
162200

163-
if metadata.status == "loaded_cpu":
201+
if metadata.status == AdapterStatus.LOADED_CPU:
164202
return
165203

166204
# Check if we have enough space in CPU; evict if necessary.
@@ -175,7 +213,7 @@ async def _transfer_to_cpu(self, adapter_id: str):
175213
self.current_hbm_usage -= metadata.size_hbm
176214
self.current_cpu_usage += metadata.size_cpu
177215

178-
metadata.status = "loaded_cpu"
216+
metadata.status = AdapterStatus.LOADED_CPU
179217
metadata.last_accessed = time.time()
180218

181219

@@ -256,7 +294,7 @@ async def get_hbm_loaded_adapters(self):
256294

257295
async with self.lock:
258296
for adapter_id, metadata in self.adapter_registry.items():
259-
if metadata.status == "loaded_hbm":
297+
if metadata.status == AdapterStatus.LOADED_HBM:
260298
hbm_loaded_adapters.append(adapter_id)
261299

262300
return ", ".join(hbm_loaded_adapters)
@@ -295,41 +333,45 @@ async def load_adapter(
295333
metadata = self.adapter_registry[adapter_id]
296334

297335
async with self.lock: # Acquire lock for thread safety
298-
if metadata.status in ("loaded_hbm", "loaded_cpu"):
336+
if metadata.status in (AdapterStatus.LOADED_HBM, AdapterStatus.LOADED_CPU):
299337
metadata.last_accessed = time.time()
300338

301339
# if already loaded in HBM and we want HBM, or
302340
# already loaded in CPU and we want CPU, we're done.
303-
if ((to_hbm and metadata.status == "loaded_hbm") or
304-
not to_hbm and metadata.status == "loaded_cpu"):
341+
if ((to_hbm and metadata.status == AdapterStatus.LOADED_HBM) or
342+
not to_hbm and metadata.status == AdapterStatus.LOADED_CPU):
305343
return
306-
elif to_hbm and metadata.status == "loaded_cpu":
344+
elif to_hbm and metadata.status == AdapterStatus.LOADED_CPU:
307345
# Transfer from cpu to hbm
308346
self._transfer_to_hbm(adapter_id)
309347
return
310-
elif not to_hbm and metadata.status == "loaded_hbm":
348+
elif not to_hbm and metadata.status == AdapterStatus.LOADED_HBM:
311349
# Transfer from hbm to cpu
312350
self._transfer_to_cpu(adapter_id)
313351
return
314352

315-
if metadata.status == "loading":
353+
if metadata.status == AdapterStatus.LOADING:
316354
# Wait untill loading is done.
317-
while metadata.status == "loading":
355+
while metadata.status == AdapterStatus.LOADING:
318356
await asyncio.sleep(0.1) # Short sleep to avoid busy-waiting
319357

320358
# Make recursive call to load_adapter to copy to device
321359
await self.load_adapter(adapter_id, adapter_weights, to_hbm)
322360
return
323361

324-
metadata.status = "loading"
362+
metadata.status = AdapterStatus.LOADING
325363
self.running_requests += 1
326364

327365
# Load the adapter (asynchronous)
328366
loop = asyncio.get_running_loop()
329367

330368
try:
331369
if adapter_weights is None:
332-
raise ValueError("Adapter weights for adapter_id={adapter_id} is None.")
370+
adapter_path = f"{self.adapters_dir_path}/{adapter_id}"
371+
adapter_weights, adapter_config = self.engine.load_single_adapter(adapter_path)
372+
373+
if adapter_weights is None:
374+
raise ValueError("Failed to load adapter_weights from {adapter_path}.")
333375

334376
async with self.lock: # Critical section for memory management
335377
adapter_weights_as_jnp_array = _as_jnp_array(adapter_weights)
@@ -360,45 +402,60 @@ async def load_adapter(
360402
if to_hbm:
361403
self.loaded_adapters_hbm[adapter_id] = adapter_weights_as_jnp_array # Convert the PyTree to Jax Array
362404
self.current_hbm_usage += adapter_size_hbm
363-
metadata.status = "loaded_hbm"
405+
metadata.status = AdapterStatus.LOADED_HBM
364406

365407
else: #to cpu
366408
self.loaded_adapters_cpu[adapter_id] = adapter_weights_as_np_array # Convert the PyTree to NumPy Array
367409
self.current_cpu_usage += adapter_size_cpu
368-
metadata.status = "loaded_cpu"
410+
metadata.status = AdapterStatus.LOADED_CPU
369411

370412
metadata.last_accessed = time.time()
371413

372414
except Exception as e:
373415
async with self.lock:
374-
metadata.status = "unloaded" # Mark as unloaded on error
416+
metadata.status = AdapterStatus.UNLOADED # Mark as unloaded on error
375417
raise e # Re-Raise the exception
376418
finally:
377419
async with self.lock:
378420
self.running_requests -= 1
379421

380422

381-
def get_lora_config(self, adapter_id):
423+
def get_lora_config(self, adapter_id: str, load_if_not_loaded: bool = False):
382424
"""Getter for the LoRA adapter config."""
383425
metadata = self.adapter_registry.get(adapter_id)
426+
427+
if load_if_not_loaded and metadata is None:
428+
self.register_adapter(adapter_id)
429+
metadata = self.adapter_registry.get(adapter_id)
430+
431+
if metadata is None:
432+
raise ValueError(f"LoRA adapter with id={adapter_id} is not loaded.")
433+
384434
return metadata.config
385435

386436

387-
def get_lora_weights(self, adapter_id, to_hbm: bool = True):
437+
def get_lora_weights(self,
438+
adapter_id,
439+
to_hbm: bool = True,
440+
load_if_not_loaded: bool = False):
388441
"""Retrieves the unified LoRA parameters for the given adapter IDs.
389442
Handles HBM/CPU placement.
390443
"""
391444

392445
metadata = self.adapter_registry.get(adapter_id)
393446

447+
if load_if_not_loaded and metadata is None:
448+
self.register_adapter(adapter_id)
449+
metadata = self.adapter_registry.get(adapter_id)
450+
394451
if metadata is None:
395-
raise ValueError(f"Adapter with ID '{adapter_id}' not registered.")
452+
raise ValueError(f"LoRA adapter with id={adapter_id} is not loaded.")
396453

397-
if metadata.status != "loaded_hbm" and metadata.status != "loaded_cpu":
454+
if metadata.status != AdapterStatus.LOADED_HBM and metadata.status != AdapterStatus.LOADED_CPU:
398455
asyncio.run(self.load_adapter(adapter_id, None, to_hbm)) # Start loading (async)
399-
elif to_hbm and metadata.status == "loaded_cpu":
456+
elif to_hbm and metadata.status == AdapterStatus.LOADED_CPU:
400457
asyncio.run(self._transfer_to_hbm(adapter_id))
401-
elif not to_hbm and metadata.status == "loaded_hbm":
458+
elif not to_hbm and metadata.status == AdapterStatus.LOADED_HBM:
402459
asyncio.run(self._transfer_to_cpu(adapter_id))
403460

404461
# Wait till all the running requests are completed
@@ -423,21 +480,21 @@ async def unload_adapter(self, adapter_id: str):
423480
metadata = self.adapter_registry[adapter_id]
424481

425482
async with self.lock:
426-
if metadata.status == "unloaded":
483+
if metadata.status == AdapterStatus.UNLOADED:
427484
return # Already unloaded
428-
if metadata.status == "loading":
485+
if metadata.status == AdapterStatus.LOADING:
429486
# Wait for the loading to get complete.
430-
while metadata.status == "loading":
487+
while metadata.status == AdapterStatus.LOADING:
431488
await asyncio.sleep(0.1)
432489

433-
if metadata.status == "loaded_hbm":
490+
if metadata.status == AdapterStatus.LOADED_HBM:
434491
del self.loaded_adapters_hbm[adapter_id]
435492
self.current_hbm_usage -= metadata.size_hbm
436-
metadata.status = "unloaded"
437-
elif metadata.status == "loaded_cpu":
493+
metadata.status = AdapterStatus.UNLOADED
494+
elif metadata.status == AdapterStatus.LOADED_CPU:
438495
del self.loaded_adapters_cpu[adapter_id]
439496
self.current_cpu_usage -= metadata.size_cpu
440-
metadata.status = "unloaded"
497+
metadata.status = AdapterStatus.UNLOADED
441498

442499
metadata.last_accessed = time.time() # Unload time
443500
metadata.size_hbm = 0
@@ -457,7 +514,7 @@ def _evict(self, from_hbm: bool = True) -> bool:
457514
lru_time = float('inf')
458515

459516
for adapter_id, metadata in self.adapter_registry.items():
460-
if metadata.status == "loaded_hbm" if from_hbm else metadata.status == "loaded_cpu":
517+
if metadata.status == AdapterStatus.LOADED_HBM if from_hbm else metadata.status == AdapterStatus.LOADED_CPU:
461518
if metadata.last_accessed < lru_time:
462519
lru_time = metadata.last_accessed
463520
lru_adapter_id = adapter_id

0 commit comments

Comments
 (0)