@@ -304,15 +304,6 @@ def __init__(
304304 self ._metrics_collector = metrics_collector
305305 self ._multi_sampling = multi_sampling
306306
307- total_slots = 0
308- for engine in self ._generate_engines :
309- total_slots += engine .max_concurrent_decodes
310-
311- self ._adapter_tensorstore = adapter_tensorstore .AdapterTensorStore (
312- hbm_memory_budget = (20 * (1024 ** 3 )), # 20 GB HBM
313- cpu_memory_budget = (100 * (1024 ** 3 )), # 100 GB RAM
314- total_slots = total_slots )
315-
316307 # Stages 1-4 represent the life cycle of a request.
317308 # Stage 1
318309 # At first, a request is placed here in order to get prefilled.
@@ -562,7 +553,8 @@ def _export_lora_request_info(self):
562553 if self ._metrics_collector :
563554 for idx , engine in enumerate (self ._generate_engines ):
564555 max_loras += engine .max_concurrent_decodes
565- if self ._generate_adapterstore and idx < len (self ._generate_adapterstore ):
556+ if (self ._generate_adapterstore and
557+ idx < len (self ._generate_adapterstore )):
566558 adapters_list_str += asyncio .run (
567559 self ._generate_adapterstore [idx ].get_hbm_loaded_adapters ())
568560
@@ -908,6 +900,10 @@ def _insert_if_possible(
908900 # Check if there are any free my_slots. We don't want to block here since
909901 # we can still generate if we can't insert. We do this in a while loop to
910902 # insert as many sequences as possible.
903+ adapter_tensorstore = None
904+ if self ._generate_adapterstore :
905+ adapter_tensorstore = self ._generate_adapterstore [idx ]
906+
911907 while True :
912908 my_slots_size = my_slots .qsize ()
913909
@@ -979,7 +975,9 @@ def _insert_if_possible(
979975 #request_id=new_request.request_id,
980976 )
981977
982- self ._adapter_tensorstore .insert_adapter_in_cache (new_request .adapter_id , slot )
978+ if adapter_tensorstore :
979+ adapter_tensorstore .insert_adapter_in_cache (
980+ new_request .adapter_id , slot )
983981
984982 ThreadDebugLog (
985983 thread_name ,
@@ -1120,6 +1118,10 @@ def _generate_thread(self, idx: int):
11201118 my_generate_backlog = self ._generate_backlogs [idx ]
11211119 my_detokenize_backlog = self ._detokenize_backlogs [idx ]
11221120
1121+ adapter_tensorstore = None
1122+ if self ._generate_adapterstore and idx < len (self ._generate_adapterstore ):
1123+ adapter_tensorstore = self ._generate_adapterstore [idx ]
1124+
11231125 # Keep track of what step tokens were generated at.
11241126 generate_timestep = 0
11251127 # State to store things like running kv cache in.
@@ -1185,9 +1187,14 @@ def _generate_thread(self, idx: int):
11851187 my_slots .qsize () < max_concurrent_decodes
11861188 ), "At this point we must have some requests inserted into the slots."
11871189
1190+ decoding_adapters_cache = None
1191+
1192+ if adapter_tensorstore :
1193+ decoding_adapters_cache = adapter_tensorstore .decoding_adapters_cache
1194+
11881195 # Now we actually take a generate step on requests in the slots.
11891196 decode_state , sampled_tokens = generate_engine .generate (
1190- generate_params , decode_state , self . _adapter_tensorstore . decoding_adapters_cache ,
1197+ generate_params , decode_state , decoding_adapters_cache
11911198 )
11921199 sampled_tokens .copy_to_host_async ()
11931200 # Respond to detokenization backpressure.
0 commit comments