Skip to content

Commit fe1511c

Browse files
committed
Changes to resolve comments on the PR.
1 parent 6ce324c commit fe1511c

3 files changed

Lines changed: 38 additions & 36 deletions

File tree

jetstream/core/lora/adapter_tensorstore.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from typing import Dict, Optional, Any
2828
import numpy as np
2929
from jetstream.engine import engine_api
30+
from enum import Enum
3031

3132

3233
def _get_size_of_pytree(params):
@@ -54,12 +55,18 @@ def convert_if_np(leaf):
5455

5556
return jax.tree_util.tree_map(convert_if_np, params)
5657

58+
class AdapterStatus(str, Enum):
59+
UNLOADED = "unloaded"
60+
LOADING = "loading"
61+
LOADED_HBM = "loaded_hbm"
62+
LOADED_CPU = "loaded_cpu"
63+
5764

5865
@dataclasses.dataclass
5966
class AdapterMetadata:
6067
adapter_id: str
6168
adapter_path: str
62-
status: str = "unloaded" # "loaded_hbm", "loaded_cpu", "loading", "unloading"
69+
status: AdapterStatus = AdapterStatus.UNLOADED
6370
size_hbm: int = 0 # Size in HBM (bytes)
6471
size_cpu: int = 0 # Size in CPU RAM (bytes)
6572
last_accessed: float = 0.0 # timestamp
@@ -155,7 +162,7 @@ async def _transfer_to_hbm(self, adapter_id: str):
155162
async with self.lock: #Acquire lock
156163
metadata = self.adapter_registry[adapter_id]
157164

158-
if metadata.status == "loaded_hbm":
165+
if metadata.status == AdapterStatus.LOADED_HBM:
159166
return
160167

161168
# Check if we have enough space in HBM; evict if necessary
@@ -172,7 +179,7 @@ async def _transfer_to_hbm(self, adapter_id: str):
172179
self.current_cpu_usage -= metadata.size_cpu
173180
self.current_hbm_usage += metadata.size_hbm
174181

175-
metadata.status = "loaded_hbm"
182+
metadata.status = AdapterStatus.LOADED_HBM
176183
metadata.last_accessed = time.time()
177184

178185

@@ -185,7 +192,7 @@ async def _transfer_to_cpu(self, adapter_id: str):
185192
async with self.lock:
186193
metadata = self. adapter_registry[adapter_id]
187194

188-
if metadata.status == "loaded_cpu":
195+
if metadata.status == AdapterStatus.LOADED_CPU:
189196
return
190197

191198
# Check if we have enough space in CPU; evict if necessary.
@@ -200,7 +207,7 @@ async def _transfer_to_cpu(self, adapter_id: str):
200207
self.current_hbm_usage -= metadata.size_hbm
201208
self.current_cpu_usage += metadata.size_cpu
202209

203-
metadata.status = "loaded_cpu"
210+
metadata.status = AdapterStatus.LOADED_CPU
204211
metadata.last_accessed = time.time()
205212

206213

@@ -211,7 +218,7 @@ async def get_hbm_loaded_adapters(self):
211218

212219
async with self.lock:
213220
for adapter_id, metadata in self.adapter_registry.items():
214-
if metadata.status == "loaded_hbm":
221+
if metadata.status == AdapterStatus.LOADED_HBM:
215222
hbm_loaded_adapters.append(adapter_id)
216223

217224
return ", ".join(hbm_loaded_adapters)
@@ -250,33 +257,33 @@ async def load_adapter(
250257
metadata = self.adapter_registry[adapter_id]
251258

252259
async with self.lock: # Acquire lock for thread safety
253-
if metadata.status in ("loaded_hbm", "loaded_cpu"):
260+
if metadata.status in (AdapterStatus.LOADED_HBM, AdapterStatus.LOADED_CPU):
254261
metadata.last_accessed = time.time()
255262

256263
# if already loaded in HBM and we want HBM, or
257264
# already loaded in CPU and we want CPU, we're done.
258-
if ((to_hbm and metadata.status == "loaded_hbm") or
259-
not to_hbm and metadata.status == "loaded_cpu"):
265+
if ((to_hbm and metadata.status == AdapterStatus.LOADED_HBM) or
266+
not to_hbm and metadata.status == AdapterStatus.LOADED_CPU):
260267
return
261-
elif to_hbm and metadata.status == "loaded_cpu":
268+
elif to_hbm and metadata.status == AdapterStatus.LOADED_CPU:
262269
# Transfer from cpu to hbm
263270
self._transfer_to_hbm(adapter_id)
264271
return
265-
elif not to_hbm and metadata.status == "loaded_hbm":
272+
elif not to_hbm and metadata.status == AdapterStatus.LOADED_HBM:
266273
# Transfer from hbm to cpu
267274
self._transfer_to_cpu(adapter_id)
268275
return
269276

270-
if metadata.status == "loading":
277+
if metadata.status == AdapterStatus.LOADING:
271278
# Wait untill loading is done.
272-
while metadata.status == "loading":
279+
while metadata.status == AdapterStatus.LOADING:
273280
await asyncio.sleep(0.1) # Short sleep to avoid busy-waiting
274281

275282
# Make recursive call to load_adapter to copy to device
276283
await self.load_adapter(adapter_id, adapter_weights, to_hbm)
277284
return
278285

279-
metadata.status = "loading"
286+
metadata.status = AdapterStatus.LOADING
280287
self.running_requests += 1
281288

282289
# Load the adapter (asynchronous)
@@ -319,18 +326,18 @@ async def load_adapter(
319326
if to_hbm:
320327
self.loaded_adapters_hbm[adapter_id] = adapter_weights_as_jnp_array # Convert the PyTree to Jax Array
321328
self.current_hbm_usage += adapter_size_hbm
322-
metadata.status = "loaded_hbm"
329+
metadata.status = AdapterStatus.LOADED_HBM
323330

324331
else: #to cpu
325332
self.loaded_adapters_cpu[adapter_id] = adapter_weights_as_np_array # Convert the PyTree to NumPy Array
326333
self.current_cpu_usage += adapter_size_cpu
327-
metadata.status = "loaded_cpu"
334+
metadata.status = AdapterStatus.LOADED_CPU
328335

329336
metadata.last_accessed = time.time()
330337

331338
except Exception as e:
332339
async with self.lock:
333-
metadata.status = "unloaded" # Mark as unloaded on error
340+
metadata.status = AdapterStatus.UNLOADED # Mark as unloaded on error
334341
raise e # Re-Raise the exception
335342
finally:
336343
async with self.lock:
@@ -368,11 +375,11 @@ def get_lora_weights(self,
368375
if metadata is None:
369376
raise ValueError(f"LoRA adapter with id={adapter_id} is not loaded.")
370377

371-
if metadata.status != "loaded_hbm" and metadata.status != "loaded_cpu":
378+
if metadata.status != AdapterStatus.LOADED_HBM and metadata.status != AdapterStatus.LOADED_CPU:
372379
asyncio.run(self.load_adapter(adapter_id, None, to_hbm)) # Start loading (async)
373-
elif to_hbm and metadata.status == "loaded_cpu":
380+
elif to_hbm and metadata.status == AdapterStatus.LOADED_CPU:
374381
asyncio.run(self._transfer_to_hbm(adapter_id))
375-
elif not to_hbm and metadata.status == "loaded_hbm":
382+
elif not to_hbm and metadata.status == AdapterStatus.LOADED_HBM:
376383
asyncio.run(self._transfer_to_cpu(adapter_id))
377384

378385
# Wait till all the running requests are completed
@@ -397,21 +404,21 @@ async def unload_adapter(self, adapter_id: str):
397404
metadata = self.adapter_registry[adapter_id]
398405

399406
async with self.lock:
400-
if metadata.status == "unloaded":
407+
if metadata.status == AdapterStatus.UNLOADED:
401408
return # Already unloaded
402-
if metadata.status == "loading":
409+
if metadata.status == AdapterStatus.LOADING:
403410
# Wait for the loading to get complete.
404-
while metadata.status == "loading":
411+
while metadata.status == AdapterStatus.LOADING:
405412
await asyncio.sleep(0.1)
406413

407-
if metadata.status == "loaded_hbm":
414+
if metadata.status == AdapterStatus.LOADED_HBM:
408415
del self.loaded_adapters_hbm[adapter_id]
409416
self.current_hbm_usage -= metadata.size_hbm
410-
metadata.status = "unloaded"
411-
elif metadata.status == "loaded_cpu":
417+
metadata.status = AdapterStatus.UNLOADED
418+
elif metadata.status == AdapterStatus.LOADED_CPU:
412419
del self.loaded_adapters_cpu[adapter_id]
413420
self.current_cpu_usage -= metadata.size_cpu
414-
metadata.status = "unloaded"
421+
metadata.status = AdapterStatus.UNLOADED
415422

416423
metadata.last_accessed = time.time() # Unload time
417424
metadata.size_hbm = 0
@@ -431,7 +438,7 @@ def _evict(self, from_hbm: bool = True) -> bool:
431438
lru_time = float('inf')
432439

433440
for adapter_id, metadata in self.adapter_registry.items():
434-
if metadata.status == "loaded_hbm" if from_hbm else metadata.status == "loaded_cpu":
441+
if metadata.status == AdapterStatus.LOADED_HBM if from_hbm else metadata.status == AdapterStatus.LOADED_CPU:
435442
if metadata.last_accessed < lru_time:
436443
lru_time = metadata.last_accessed
437444
lru_adapter_id = adapter_id

jetstream/core/orchestrator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,8 +1370,8 @@ def _detokenize_thread(self, idx: int):
13701370

13711371
def load_adapter_to_tensorstore(
13721372
self,
1373-
adapter_id,
1374-
adapter_path):
1373+
adapter_id: str,
1374+
adapter_path: str):
13751375
"""Load the adapter to adapter_tensorstore for each engine."""
13761376
logger.info("Loading adapter_id=%s from %s.",
13771377
adapter_id, adapter_path)
@@ -1429,7 +1429,7 @@ def load_adapter_to_tensorstore(
14291429

14301430
def unload_adapter_from_tensorstore(
14311431
self,
1432-
adapter_id):
1432+
adapter_id: str):
14331433
"""Unload the adapter from adapter_tensorstore of each engine."""
14341434
logger.info("Unloading adapter_id=%s", adapter_id)
14351435

jetstream/core/server_lib.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -266,11 +266,6 @@ def run(
266266
Returns:
267267
JetStreamServer that wraps the grpc server and orchestrator driver.
268268
"""
269-
# TODO: Deleting the lora_input_adapters_path for now.
270-
# Planning to use it in next big PR. Currently accomodating it
271-
# to fix the params mismatch between maxText and JetStream
272-
del lora_input_adapters_path
273-
274269
server_start_time = time.time()
275270
logging.info("Kicking off gRPC server.")
276271
# Setup Prometheus server

0 commit comments

Comments
 (0)