Skip to content

Commit 5c0de6d

Browse files
committed
Hacky way to test AOT (ahead-of-time) compilation in JetStream
1 parent 9cb7785 commit 5c0de6d

1 file changed

Lines changed: 74 additions & 11 deletions

File tree

jetstream/core/orchestrator.py

Lines changed: 74 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,11 @@
102102
from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
103103
import numpy as np
104104

105-
log_level = os.getenv("LOG_LEVEL", "WARNING").upper()
105+
from jax.experimental import layout as jax_layout
106+
DLL = jax_layout.DeviceLocalLayout
107+
Layout = jax_layout.Layout
108+
109+
log_level = os.getenv("LOG_LEVEL", "DEBUG").upper()
106110

107111
logger = logging.getLogger("JetstreamLogger")
108112
logger.propagate = False
@@ -405,6 +409,63 @@ def __init__(
405409

406410
self._jax_padding = jax_padding
407411

412+
##### HACK: AOT-compile with auto layouts for the interleaved engine
413+
self.engine = self._generate_engines[0]
414+
self.params = self._generate_params[0]
415+
logger.debug("Compiling generate function")
416+
self._generate_executable, self.params, self._decode_state_executable = self.engine.aot_compile(
417+
self.params, pass_rng_shape=False
418+
)
419+
self.decode_state = self._decode_state_executable(None)
420+
421+
# prefill
422+
interesting_buckets = [
423+
64,
424+
128,
425+
256,
426+
512,
427+
1024,
428+
]
429+
430+
self._cached_prefill = {}
431+
self._cached_insert = {}
432+
for length in interesting_buckets:
433+
i32_scalar = jax.ShapeDtypeStruct((), int)
434+
logger.debug("Compiling prefill: %d", length)
435+
input_data = jax.ShapeDtypeStruct((length,), jax.numpy.dtype("int32"))
436+
437+
self._cached_prefill[length] = (
438+
jax.jit(
439+
self.engine.prefill_aot,
440+
in_shardings=(self.engine.param_layouts, None, None),
441+
out_shardings=(Layout(DLL.AUTO), Layout(DLL.AUTO)),
442+
).lower(self.params, input_data, i32_scalar)
443+
).compile(compiler_options=None)
444+
445+
logger.debug("Generate dummy prefix: %d", length)
446+
dummy_tokens = jax.numpy.ones(shape=(length,), dtype=jax.numpy.dtype("int32"))
447+
prefix_shapes = jax.eval_shape(self.engine.prefill_aot, self.params, dummy_tokens, 1)
448+
449+
logger.debug("Compiling insert: %d", length)
450+
prefill_output_layout, _ = self._cached_prefill[length].output_layouts
451+
logger.debug("Prefill output layout: {}".format(prefill_output_layout))
452+
logger.debug("Prefix shapes: {}".format(prefix_shapes))
453+
i32_scalar = jax.ShapeDtypeStruct((), int)
454+
self._cached_insert[length] = (
455+
jax.jit(
456+
self.engine.insert,
457+
in_shardings=(prefill_output_layout, self.engine.decode_state_layouts, None),
458+
out_shardings=(self.engine.decode_state_layouts),
459+
donate_argnames=("decode_state"),
460+
).lower(prefix_shapes[0], self.engine.decode_state_shapes, i32_scalar)
461+
).compile(compiler_options=None)
462+
463+
self._prefill_engines[0] = self.engine
464+
self._generate_engines[0] = self.engine
465+
self._prefill_params[0] = self.params
466+
self._generate_params[0] = self.params
467+
468+
408469
# Create all threads
409470
self._prefill_threads = [
410471
JetThread(
@@ -759,10 +820,11 @@ def _prefill_thread(self, idx: int):
759820
)
760821
else:
761822
# Compute new kv cache for the prefill_content.
762-
prefill_result, first_token = prefill_engine.prefill(
763-
params=final_prefill_params,
764-
padded_tokens=padded_tokens,
765-
true_length=true_length,
823+
assert padded_tokens.shape[0] in self._cached_prefill
824+
prefill_result, first_token = self._cached_prefill[padded_tokens.shape[0]](
825+
final_prefill_params,
826+
padded_tokens,
827+
true_length,
766828
)
767829

768830
request.complete = np.zeros(
@@ -967,10 +1029,11 @@ def _insert_if_possible(
9671029
else:
9681030
break
9691031

970-
decode_state = generate_engine.insert(
1032+
length = new_request.prefill_result['cache']['decoder']['layers_0']['self_attention']['KVCache_0']['cache_prefill_segment_id'].value.shape[1]
1033+
decode_state = self._cached_insert[length](
9711034
new_request.prefill_result,
9721035
decode_state,
973-
slot=slot,
1036+
slot,
9741037
# request_id=new_request.request_id,
9751038
)
9761039
ThreadDebugLog(
@@ -1115,9 +1178,9 @@ def _generate_thread(self, idx: int):
11151178
# Keep track of what step tokens were generated at.
11161179
generate_timestep = 0
11171180
# State to store things like running kv cache in.
1118-
decode_state = generate_engine.init_decode_state()
1119-
1181+
decode_state = self.decode_state
11201182
generate_params = self._generate_params[idx]
1183+
11211184
thread_name = f"Generate thread {idx}"
11221185
ThreadDebugLog(thread_name, f"Generate params {idx} loaded.")
11231186
time_of_last_generate = time.time()
@@ -1178,8 +1241,8 @@ def _generate_thread(self, idx: int):
11781241
), "At this point we must have some requests inserted into the slots."
11791242

11801243
# Now we actually take a generate step on requests in the slots.
1181-
decode_state, sampled_tokens = generate_engine.generate(
1182-
generate_params, decode_state
1244+
decode_state, sampled_tokens = self._generate_executable(
1245+
generate_params, decode_state, None
11831246
)
11841247
sampled_tokens.copy_to_host_async()
11851248
# Respond to detokenization backpressure.

0 commit comments

Comments
 (0)