Skip to content

Commit f7115c8

Browse files
committed
1) Fixed linter errors.
2) Removed hard-coded scale factor for batch processing. Also added functionality to support different scale_factor values for different adapters.
1 parent 3ff6383 commit f7115c8

5 files changed

Lines changed: 81 additions & 26 deletions

File tree

jetstream/core/lora/adapter_tensorstore.py

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,26 @@ def __init__(
121121
0 # Number of async tasks which are in "loading" state
122122
)
123123
self.decoding_adapters_cache: Dict[str, Any] = {}
124+
125+
# TODO: Make dtype configurable for the scale factor array
126+
self.adapters_scale_factor = jnp.empty(1, dtype=jnp.bfloat16)
127+
124128
self.total_slots = total_slots
125129
self.lock = asyncio.Lock() # asyncio Lock: serializes coroutines on one event loop (not thread-safe)
126130

131+
def _get_adapter_scale_factor(self, adapter_id: str) -> float:
  """Internal: Get the LoRA scale_factor using the adapter_id.

  The scale factor is `lora_alpha / r` when the adapter config contains both
  the "r" and "lora_alpha" keys; otherwise it defaults to 1.0.

  Args:
    adapter_id: Key used to look up the adapter in `self.adapter_registry`.

  Returns:
    The LoRA scale factor as a Python float.

  Raises:
    ValueError: If the configured rank "r" is not a positive integer. This
      replaces the bare ZeroDivisionError the division would otherwise raise.
  """
  # A missing adapter_id propagates the registry's own lookup error.
  adapter_config = self.adapter_registry[adapter_id].config

  # Without both keys there is nothing to compute; use the neutral scale.
  if "r" not in adapter_config or "lora_alpha" not in adapter_config:
    return float(1)

  lora_rank = int(adapter_config["r"])
  if lora_rank <= 0:
    # Fail with a message naming the adapter instead of dividing by zero
    # (or silently producing a negative scale for a negative rank).
    raise ValueError(
        f"Adapter '{adapter_id}' has invalid LoRA rank r={lora_rank}; "
        "expected a positive integer."
    )

  return float(adapter_config["lora_alpha"]) / lora_rank
143+
127144
# --- Unsafe Internal methods which assumes that lock is held ---
128145
def _unsafe_transfer_to_hbm(self, adapter_id: str):
129146
"""
@@ -212,67 +229,87 @@ def _unsafe_unload_adapter(self, adapter_id: str):
212229

213230
def _initialize_decoding_adapters_cache(self, adapter_weights):
214231
"""
215-
Create a new PyTree with zero tensors at the paths corresponding to non-None leaves
216-
in the input PyTree. The zero tensors have an added dimension of size `self.totol_slots`.
232+
Create a new PyTree with zero tensors at the paths corresponding to
233+
non-None leaves in the input PyTree. The zero tensors have an added
234+
dimension of size `self.total_slots`.
217235
Args:
218236
adapter_weights: The input PyTree, whose structure will be mirrored.
219237
Returns:
220-
A new PyTree with zero Tensors or None values, mirroring the structure of the input PyTree.
238+
A new PyTree with zero Tensors or None values, mirroring the structure
239+
of the input PyTree.
221240
"""
241+
222242
def create_zero_leaf(leaf):
223243
if leaf is not None:
224244
original_shape = leaf.shape
225-
if not original_shape: # handle scalar case
245+
if not original_shape: # handle scalar case
226246
zero_tensor_shape = (self.total_slots,)
227247
else:
228-
zero_tensor_shape = (self.total_slots,) + original_shape # Prepend a new dimension
248+
zero_tensor_shape = (
249+
self.total_slots,
250+
) + original_shape # Prepend a new dimension
229251

230252
return jnp.zeros(zero_tensor_shape, dtype=leaf.dtype)
231253
else:
232-
return None # Maintain None structure for None leaves
254+
return None # Maintain None structure for None leaves
233255

256+
self.adapters_scale_factor = jnp.ones(self.total_slots, dtype=jnp.bfloat16)
234257
return jax.tree_util.tree_map(create_zero_leaf, adapter_weights)
235258

236-
237259
def insert_adapter_in_cache(self, adapter_id: str, slot_id: int):
238260
"""
239-
Insert the specific adapter tensors into a slot in the serving_adapters_cache.
261+
Insert the specific adapter tensors into a slot in the
262+
decoding_adapters_cache.
240263
Args:
241264
adapter_id: The id of the adapter, whose tensors will be inserted
242-
slot_id: The id of slot, which represents the index in the serving_adapter_cache
243-
where the adapter tensors will be inserted.
265+
slot_id: The id of slot, which represents the index in the
266+
decoding_adapters_cache where the adapter tensors will be inserted.
244267
"""
245268

246269
def insert_leaf(dest_leaf, source_leaf):
247270
if dest_leaf is not None and source_leaf is not None:
248-
return dest_leaf.at[slot_id].set(source_leaf) # Insert at the specific index
271+
return dest_leaf.at[slot_id].set(
272+
source_leaf
273+
) # Insert at the specific index
249274
elif dest_leaf is not None:
250-
return dest_leaf # If source_leaf is None, keep the zero_leaf as is
251-
elif source_leaf is not None: # In this case the adapters have different target modules
275+
return dest_leaf # If source_leaf is None, keep the zero_leaf as is
276+
elif (
277+
source_leaf is not None
278+
): # In this case the adapters have different target modules
252279
original_shape = source_leaf.shape
253-
if not original_shape: # Handle scalar case
280+
if not original_shape: # Handle scalar case
254281
zero_tensor_shape = (self.total_slots,)
255282
else:
256283
zero_tensor_shape = (self.total_slots,) + original_shape
257284
new_dest_leaf = jnp.zeros(zero_tensor_shape, dtype=source_leaf.dtype)
258285
return new_dest_leaf.at[slot_id].set(source_leaf)
259286
else:
260-
return None # If both are None, return None
287+
return None # If both are None, return None
261288

262289
if adapter_id == "":
263-
logging.info("Empty adapter id. So no LoRA tensors inserted into the cache in adapter_tensorStore.")
290+
logging.info(
291+
"Empty adapter id. No LoRA tensors added to adapter_tensorstore cache"
292+
)
264293
return
265294

266295
asyncio.run(self.load_adapter(adapter_id, None, True))
267296

268297
adapter_weights = self.loaded_adapters_hbm[adapter_id]
269298

270299
if not self.decoding_adapters_cache:
271-
self.decoding_adapters_cache = self._initialize_decoding_adapters_cache(adapter_weights)
300+
self.decoding_adapters_cache = self._initialize_decoding_adapters_cache(
301+
adapter_weights
302+
)
272303

273-
self.decoding_adapters_cache = jax.tree_util.tree_map(insert_leaf,
274-
self.decoding_adapters_cache,
275-
adapter_weights)
304+
adapter_scale_factor = jnp.bfloat16(
305+
self._get_adapter_scale_factor(adapter_id)
306+
)
307+
self.adapters_scale_factor = self.adapters_scale_factor.at[slot_id].set(
308+
adapter_scale_factor
309+
)
310+
self.decoding_adapters_cache = jax.tree_util.tree_map(
311+
insert_leaf, self.decoding_adapters_cache, adapter_weights
312+
)
276313

277314
# --- Public Methods (Acquire lock, then call unsafe methods) ---
278315

jetstream/core/orchestrator.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,7 +1106,8 @@ def _insert_if_possible(
11061106

11071107
if adapter_tensorstore:
11081108
adapter_tensorstore.insert_adapter_in_cache(
1109-
new_request.adapter_id, slot)
1109+
new_request.adapter_id, slot
1110+
)
11101111

11111112
ThreadDebugLog(
11121113
thread_name,
@@ -1316,11 +1317,23 @@ def _generate_thread(self, idx: int):
13161317
my_slots.qsize() < max_concurrent_decodes
13171318
), "At this point we must have some requests inserted into the slots."
13181319

1319-
decoding_adapters_cache = None
1320-
13211320
if adapter_tensorstore:
1322-
decoding_adapters_cache = adapter_tensorstore.decoding_adapters_cache
1323-
#decode_state["lora_adapter_cache"] = decoding_adapters_cache
1321+
decoding_adapters_params = adapter_tensorstore.decoding_adapters_cache
1322+
adapters_scale_factor = adapter_tensorstore.adapters_scale_factor
1323+
b = adapters_scale_factor.shape[0]
1324+
1325+
# Reshaped the scale_factors array to 4-D to align with shape of
1326+
# the vectors `(batch, hidden_size, num_heads, head_dim)`.
1327+
reshaped_scale_factors = adapters_scale_factor.reshape((b, 1, 1, 1))
1328+
1329+
lora_state = {}
1330+
lora_state["scale_factor"] = reshaped_scale_factors
1331+
lora_state["lora_params"] = decoding_adapters_params
1332+
1333+
if isinstance(decode_state, dict): # For dict-based decode states
1334+
decode_state["lora_state"] = lora_state
1335+
else: # For dataclass-style states exposing .replace() (e.g. flax.struct.dataclass)
1336+
decode_state = decode_state.replace(lora_state=lora_state)
13241337

13251338
# Now we actually take a generate step on requests in the slots.
13261339
decode_state, sampled_tokens = generate_engine.generate(

jetstream/engine/mock_engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
import functools
3434
from dataclasses import asdict
35-
from typing import Any, Callable, Optional, Tuple
35+
from typing import Any, Dict, Callable, Optional, Tuple
3636

3737
import jax
3838
import jax.numpy as jnp
@@ -71,6 +71,7 @@ class DecodeState:
7171
generate_cache_index: int
7272
generate_lengths: jax.Array
7373
generate_tokens: jax.Array
74+
lora_state: Optional[Dict[str, Any]] = None
7475

7576

7677
class TestEngine(engine_api.Engine):
@@ -509,6 +510,7 @@ def init_decode_state(self) -> DecodeState:
509510
generate_tokens=jnp.zeros(
510511
(self.generate_cache_batch, 1), dtype=jnp.int32
511512
),
513+
lora_state={},
512514
)
513515

514516
@property

jetstream/tests/core/lora/test_adapter_tensorstore.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ async def asyncSetUp(self):
145145
adapters_dir_path=self.adapters_dir_path,
146146
hbm_memory_budget=self.hbm_budget,
147147
cpu_memory_budget=self.cpu_budget,
148+
total_slots=8,
148149
)
149150

150151
# Pre-register adapters for most tests to simplify setup

jetstream/tests/core/test_orchestrator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,15 @@ async def _setup_driver_with_adapterstore(
123123
adapters_dir_path="/tmp/",
124124
hbm_memory_budget=20 * (1024**3), # 20 GB HBM
125125
cpu_memory_budget=100 * (1024**3), # 100 GB RAM
126+
total_slots=8,
126127
)
127128

128129
generate_adapterstore = adapterstore.AdapterTensorStore(
129130
engine=generate_engine,
130131
adapters_dir_path="/tmp/",
131132
hbm_memory_budget=20 * (1024**3), # 20 GB HBM
132133
cpu_memory_budget=100 * (1024**3), # 100 GB RAM
134+
total_slots=8,
133135
)
134136

135137
await prefill_adapterstore.register_adapter(

0 commit comments

Comments
 (0)