|
26 | 26 | from flax import nnx |
27 | 27 | from flax.linen import partitioning as nn_partitioning |
28 | 28 | from maxtext.configs import pyconfig |
29 | | -from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL |
30 | | -from maxtext.layers import quantizations |
| 29 | +from maxtext.common.common_types import MODEL_MODE_PREFILL |
31 | 30 |
|
32 | 31 | pytest.importorskip("jetstream", reason="jetstream not installed") |
33 | 32 | from maxtext.inference.maxengine import maxengine |
34 | | -from maxtext.models import models |
35 | 33 | from maxtext.utils import maxtext_utils |
36 | 34 | from maxtext.utils import model_creation_utils |
37 | 35 | from tests.utils.test_helpers import get_test_config_path |
@@ -71,17 +69,6 @@ def init_pyconfig(self, **kwargs): |
71 | 69 | ) |
72 | 70 | return config |
73 | 71 |
|
74 | | - def get_data(self): |
75 | | - s = (self.cfg.global_batch_size_to_train_on, self.cfg.max_target_length) |
76 | | - ids = jax.random.randint(self.rng, s, 0, self.cfg.vocab_size) |
77 | | - |
78 | | - decoder_segment_ids = jax.numpy.zeros(s) + DECODING_ACTIVE_SEQUENCE_INDICATOR |
79 | | - decoder_positions = jnp.stack( |
80 | | - [jnp.arange(self.cfg.max_target_length, dtype=jnp.int32) for _ in range(self.cfg.global_batch_size_to_train_on)] |
81 | | - ) |
82 | | - |
83 | | - return ids, decoder_segment_ids, decoder_positions |
84 | | - |
85 | 72 | def test_stack_and_unstack_prefill_cache(self): |
86 | 73 | config = pyconfig.initialize( |
87 | 74 | [None, get_test_config_path()], |
@@ -111,60 +98,8 @@ def test_stack_and_unstack_prefill_cache(self): |
111 | 98 | got_unstacked = engine._maybe_unstack_prefill_result_cache(got_stacked) |
112 | 99 | jax.tree.map(np.testing.assert_array_equal, got_unstacked, input_d) |
113 | 100 |
|
114 | | - def test_basic_prefill(self): |
115 | | - devices_array = maxtext_utils.create_device_mesh(self.cfg) |
116 | | - mesh = Mesh(devices_array, self.cfg.mesh_axes) |
117 | | - quant = quantizations.configure_quantization(self.cfg) |
118 | | - model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) |
119 | | - ids, decoder_segment_ids, decoder_positions = self.get_data() |
120 | | - |
121 | | - transformer_vars = model.init( |
122 | | - {"params": self.rng, "aqt": self.rng, "dropout": self.rng}, |
123 | | - ids, |
124 | | - decoder_positions, |
125 | | - decoder_segment_ids, |
126 | | - enable_dropout=False, |
127 | | - ) |
128 | | - input_tokens = jnp.array([1, 306, 5360, 304, 0, 0, 0, 0]) |
129 | | - true_length = 4 |
130 | | - engine = maxengine.MaxEngine(self.cfg, jax.devices()) |
131 | | - prefill_result, first_token = engine.prefill( |
132 | | - params=transformer_vars, padded_tokens=input_tokens, true_length=true_length |
133 | | - ) |
134 | | - |
135 | | - self.assertEqual(prefill_result["generated_tokens"], jnp.array([0])) |
136 | | - # test default strategy is greedy, which chooses only one next token |
137 | | - self.assertEqual(prefill_result["tokens"].size, 1) |
138 | | - self.assertNotEqual(prefill_result["tokens"], jnp.array([0])) |
139 | | - self.assertTrue(jnp.array_equal(first_token.data.size, 3)) |
140 | | - self.assertEqual(first_token.log_prob.shape, (1, 1)) |
141 | | - |
142 | | - def test_basic_decode(self): |
143 | | - devices_array = maxtext_utils.create_device_mesh(self.cfg) |
144 | | - mesh = Mesh(devices_array, self.cfg.mesh_axes) |
145 | | - quant = quantizations.configure_quantization(self.cfg) |
146 | | - model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) |
147 | | - ids, decoder_segment_ids, decoder_positions = self.get_data() |
148 | | - |
149 | | - transformer_vars = model.init( |
150 | | - {"params": self.rng, "aqt": self.rng, "dropout": self.rng}, |
151 | | - ids, |
152 | | - decoder_positions, |
153 | | - decoder_segment_ids, |
154 | | - enable_dropout=False, |
155 | | - ) |
156 | | - input_tokens = jnp.array([1, 306, 5360, 304]) |
157 | | - engine = maxengine.MaxEngine(self.cfg, jax.devices()) |
158 | | - params = engine.load_params(params=transformer_vars) |
159 | | - decode_state = engine.init_decode_state() |
160 | | - prefill_result, _ = engine.prefill(params=params, padded_tokens=input_tokens, true_length=4) |
161 | | - decode_state = engine.insert(prefill_result, decode_state, slot=0) |
162 | | - decode_state, result_token = engine.generate(params=params, decode_state=decode_state) |
163 | | - |
164 | | - self.assertEqual(result_token.log_prob.ndim, 2) |
165 | | - self.assertEqual(result_token.log_prob.shape[1], 1) |
166 | | - self.assertEqual(result_token.data.ndim, 2) |
167 | | - self.assertEqual(result_token.data.shape[1], 3) |
| 101 | + # The Linen-path basic prefill/decode tests were removed when NNX became the |
| 102 | + # default. test_basic_prefill_nnx / test_basic_decode_nnx below cover the NNX path. |
168 | 103 |
|
169 | 104 | def _init_nnx_pyconfig(self, **kwargs): |
170 | 105 | """init_pyconfig with NNX flags on.""" |
|
0 commit comments