Skip to content

Commit 453db81

Browse files
committed
Fixing some failures due to merge conflicts.
1 parent 2c2850b commit 453db81

2 files changed

Lines changed: 22 additions & 14 deletions

File tree

jetstream/core/lora/adapter_tensorstore.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,12 @@ class AdapterTensorStore:
8787
model to serve multiple different LoRA adapters in a single batch.
8888
8989
Args:
90+
engine: Engine corresponding to the adapter tensorstore
91+
adapters_dir_path: GCS path storing all the adapters
9092
hbm_memory_budget (int): The maximum amount of HBM (in bytes) to use for
9193
storing LoRA adapter weights.
9294
cpu_memory_budget (int): The maximum amount of CPU RAM (in bytes) to use
9395
for storing LoRA adapter weights.
94-
total_slots: Number of generate slots. This is also equals to max_concurrent_decodes.
9596
"""
9697

9798
def __init__(self,
@@ -111,7 +112,7 @@ def __init__(self,
111112
self.current_cpu_usage: int = 0
112113
self.running_requests: int = 0 # Number of async tasks which are in "loading" state
113114
self.decoding_adapters_cache: Dict[str, Any] = {}
114-
self.total_slots = total_slots
115+
self.total_slots = engine.max_concurrent_decodes
115116
self.lock = asyncio.Lock() # asyncio Lock: serializes coroutines on a single event loop (not thread-safe)
116117

117118

jetstream/core/orchestrator.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -304,15 +304,6 @@ def __init__(
304304
self._metrics_collector = metrics_collector
305305
self._multi_sampling = multi_sampling
306306

307-
total_slots = 0
308-
for engine in self._generate_engines:
309-
total_slots += engine.max_concurrent_decodes
310-
311-
self._adapter_tensorstore = adapter_tensorstore.AdapterTensorStore(
312-
hbm_memory_budget=(20 * (1024 ** 3)), # 20 GB HBM
313-
cpu_memory_budget=(100 * (1024 ** 3)), # 100 GB RAM
314-
total_slots=total_slots)
315-
316307
# Stages 1-4 represent the life cycle of a request.
317308
# Stage 1
318309
# At first, a request is placed here in order to get prefilled.
@@ -562,7 +553,8 @@ def _export_lora_request_info(self):
562553
if self._metrics_collector:
563554
for idx, engine in enumerate(self._generate_engines):
564555
max_loras += engine.max_concurrent_decodes
565-
if self._generate_adapterstore and idx < len(self._generate_adapterstore):
556+
if (self._generate_adapterstore and
557+
idx < len(self._generate_adapterstore)):
566558
adapters_list_str += asyncio.run(
567559
self._generate_adapterstore[idx].get_hbm_loaded_adapters())
568560

@@ -908,6 +900,10 @@ def _insert_if_possible(
908900
# Check if there are any free my_slots. We don't want to block here since
909901
# we can still generate if we can't insert. We do this in a while loop to
910902
# insert as many sequences as possible.
903+
adapter_tensorstore = None
904+
if self._generate_adapterstore:
905+
adapter_tensorstore = self._generate_adapterstore[idx]
906+
911907
while True:
912908
my_slots_size = my_slots.qsize()
913909

@@ -979,7 +975,9 @@ def _insert_if_possible(
979975
#request_id=new_request.request_id,
980976
)
981977

982-
self._adapter_tensorstore.insert_adapter_in_cache(new_request.adapter_id, slot)
978+
if adapter_tensorstore:
979+
adapter_tensorstore.insert_adapter_in_cache(
980+
new_request.adapter_id, slot)
983981

984982
ThreadDebugLog(
985983
thread_name,
@@ -1120,6 +1118,10 @@ def _generate_thread(self, idx: int):
11201118
my_generate_backlog = self._generate_backlogs[idx]
11211119
my_detokenize_backlog = self._detokenize_backlogs[idx]
11221120

1121+
adapter_tensorstore = None
1122+
if self._generate_adapterstore and idx < len(self._generate_adapterstore):
1123+
adapter_tensorstore = self._generate_adapterstore[idx]
1124+
11231125
# Keep track of what step tokens were generated at.
11241126
generate_timestep = 0
11251127
# State to store things like running kv cache in.
@@ -1185,9 +1187,14 @@ def _generate_thread(self, idx: int):
11851187
my_slots.qsize() < max_concurrent_decodes
11861188
), "At this point we must have some requests inserted into the slots."
11871189

1190+
decoding_adapters_cache = None
1191+
1192+
if adapter_tensorstore:
1193+
decoding_adapters_cache = adapter_tensorstore.decoding_adapters_cache
1194+
11881195
# Now we actually take a generate step on requests in the slots.
11891196
decode_state, sampled_tokens = generate_engine.generate(
1190-
generate_params, decode_state, self._adapter_tensorstore.decoding_adapters_cache,
1197+
generate_params, decode_state, decoding_adapters_cache
11911198
)
11921199
sampled_tokens.copy_to_host_async()
11931200
# Respond to detokenization backpressure.

0 commit comments

Comments
 (0)