|
102 | 102 | from jetstream.core.metrics.prometheus import JetstreamMetricsCollector |
103 | 103 | import numpy as np |
104 | 104 |
|
105 | | -log_level = os.getenv("LOG_LEVEL", "WARNING").upper() |
| 105 | +from jax.experimental import layout as jax_layout |
| 106 | +DLL = jax_layout.DeviceLocalLayout |
| 107 | +Layout = jax_layout.Layout |
| 108 | + |
| 109 | +log_level = os.getenv("LOG_LEVEL", "DEBUG").upper() |
106 | 110 |
|
107 | 111 | logger = logging.getLogger("JetstreamLogger") |
108 | 112 | logger.propagate = False |
@@ -405,6 +409,63 @@ def __init__( |
405 | 409 |
|
406 | 410 | self._jax_padding = jax_padding |
407 | 411 |
|
| 412 | + ##### hacky code using auto layout for interleaved engine |
| 413 | + self.engine = self._generate_engines[0] |
| 414 | + self.params = self._generate_params[0] |
| 415 | + logger.debug("Compiling generate function") |
| 416 | + self._generate_executable, self.params, self._decode_state_executable = self.engine.aot_compile( |
| 417 | + self.params, pass_rng_shape=False |
| 418 | + ) |
| 419 | + self.decode_state = self._decode_state_executable(None) |
| 420 | + |
| 421 | + # prefill |
| 422 | + interesting_buckets = [ |
| 423 | + 64, |
| 424 | + 128, |
| 425 | + 256, |
| 426 | + 512, |
| 427 | + 1024, |
| 428 | + ] |
| 429 | + |
| 430 | + self._cached_prefill = {} |
| 431 | + self._cached_insert = {} |
| 432 | + for length in interesting_buckets: |
| 433 | + i32_scalar = jax.ShapeDtypeStruct((), int) |
| 434 | + logger.debug("Compiling prefill: %d", length) |
| 435 | + input_data = jax.ShapeDtypeStruct((length,), jax.numpy.dtype("int32")) |
| 436 | + |
| 437 | + self._cached_prefill[length] = ( |
| 438 | + jax.jit( |
| 439 | + self.engine.prefill_aot, |
| 440 | + in_shardings=(self.engine.param_layouts, None, None), |
| 441 | + out_shardings=(Layout(DLL.AUTO), Layout(DLL.AUTO)), |
| 442 | + ).lower(self.params, input_data, i32_scalar) |
| 443 | + ).compile(compiler_options=None) |
| 444 | + |
| 445 | + logger.debug("Generate dummy prefix: %d", length) |
| 446 | + dummy_tokens = jax.numpy.ones(shape=(length,), dtype=jax.numpy.dtype("int32")) |
| 447 | + prefix_shapes = jax.eval_shape(self.engine.prefill_aot, self.params, dummy_tokens, 1) |
| 448 | + |
| 449 | + logger.debug("Compiling insert: %d", length) |
| 450 | + prefill_output_layout, _ = self._cached_prefill[length].output_layouts |
| 451 | + logger.debug("Prefill output layout: {}".format(prefill_output_layout)) |
| 452 | + logger.debug("Prefix shapes: {}".format(prefix_shapes)) |
| 453 | + i32_scalar = jax.ShapeDtypeStruct((), int) |
| 454 | + self._cached_insert[length] = ( |
| 455 | + jax.jit( |
| 456 | + self.engine.insert, |
| 457 | + in_shardings=(prefill_output_layout, self.engine.decode_state_layouts, None), |
| 458 | + out_shardings=(self.engine.decode_state_layouts), |
| 459 | + donate_argnames=("decode_state"), |
| 460 | + ).lower(prefix_shapes[0], self.engine.decode_state_shapes, i32_scalar) |
| 461 | + ).compile(compiler_options=None) |
| 462 | + |
| 463 | + self._prefill_engines[0] = self.engine |
| 464 | + self._generate_engines[0] = self.engine |
| 465 | + self._prefill_params[0] = self.params |
| 466 | + self._generate_params[0] = self.params |
| 467 | + |
| 468 | + |
408 | 469 | # Create all threads |
409 | 470 | self._prefill_threads = [ |
410 | 471 | JetThread( |
@@ -759,10 +820,11 @@ def _prefill_thread(self, idx: int): |
759 | 820 | ) |
760 | 821 | else: |
761 | 822 | # Compute new kv cache for the prefill_content. |
762 | | - prefill_result, first_token = prefill_engine.prefill( |
763 | | - params=final_prefill_params, |
764 | | - padded_tokens=padded_tokens, |
765 | | - true_length=true_length, |
| 823 | + assert padded_tokens.shape[0] in self._cached_prefill |
| 824 | + prefill_result, first_token = self._cached_prefill[padded_tokens.shape[0]]( |
| 825 | + final_prefill_params, |
| 826 | + padded_tokens, |
| 827 | + true_length, |
766 | 828 | ) |
767 | 829 |
|
768 | 830 | request.complete = np.zeros( |
@@ -967,10 +1029,11 @@ def _insert_if_possible( |
967 | 1029 | else: |
968 | 1030 | break |
969 | 1031 |
|
970 | | - decode_state = generate_engine.insert( |
| 1032 | + length = new_request.prefill_result['cache']['decoder']['layers_0']['self_attention']['KVCache_0']['cache_prefill_segment_id'].value.shape[1] |
| 1033 | + decode_state = self._cached_insert[length]( |
971 | 1034 | new_request.prefill_result, |
972 | 1035 | decode_state, |
973 | | - slot=slot, |
| 1036 | + slot, |
974 | 1037 | # request_id=new_request.request_id, |
975 | 1038 | ) |
976 | 1039 | ThreadDebugLog( |
@@ -1115,9 +1178,9 @@ def _generate_thread(self, idx: int): |
1115 | 1178 | # Keep track of what step tokens were generated at. |
1116 | 1179 | generate_timestep = 0 |
1117 | 1180 | # State to store things like running kv cache in. |
1118 | | - decode_state = generate_engine.init_decode_state() |
1119 | | - |
| 1181 | + decode_state = self.decode_state |
1120 | 1182 | generate_params = self._generate_params[idx] |
| 1183 | + |
1121 | 1184 | thread_name = f"Generate thread {idx}" |
1122 | 1185 | ThreadDebugLog(thread_name, f"Generate params {idx} loaded.") |
1123 | 1186 | time_of_last_generate = time.time() |
@@ -1178,8 +1241,8 @@ def _generate_thread(self, idx: int): |
1178 | 1241 | ), "At this point we must have some requests inserted into the slots." |
1179 | 1242 |
|
1180 | 1243 | # Now we actually take a generate step on requests in the slots. |
1181 | | - decode_state, sampled_tokens = generate_engine.generate( |
1182 | | - generate_params, decode_state |
| 1244 | + decode_state, sampled_tokens = self._generate_executable( |
| 1245 | + generate_params, decode_state, None |
1183 | 1246 | ) |
1184 | 1247 | sampled_tokens.copy_to_host_async() |
1185 | 1248 | # Respond to detokenization backpressure. |
|
0 commit comments