|
20 | 20 |
|
21 | 21 | from maxtext.configs import pyconfig |
22 | 22 | from maxtext.configs.pyconfig import resolve_config_path, _CONFIG_FILE_MAPPING, _module_from_path |
| 23 | +from maxtext.input_pipeline import data_processing_utils |
23 | 24 | from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_PKG_DIR |
24 | 25 | from tests.utils.test_helpers import get_test_config_path, get_post_train_test_config_path |
25 | 26 |
|
@@ -173,6 +174,70 @@ def test_identical_override_allowed(self): |
173 | 174 | ) |
174 | 175 | self.assertEqual(config.tokenizer_type, "huggingface") |
175 | 176 |
|
def test_list_config_coercion(self):
  """Verifies that string/tuple inputs for list[str] config fields are coerced to lists.

  Every case initializes a fresh config with a different `train_data_columns`
  override and compares the resulting attribute with the expected list. The
  cases run under `self.subTest`, so a failure for one input form is reported
  with its label and the remaining forms are still exercised.
  """
  base_argv = [os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()]

  # (label, raw override value, expected value of config.train_data_columns)
  cases = (
      ("plain string becomes a single-item list", "messages", ["messages"]),
      ("stringified list literal", "['col1', 'col2']", ["col1", "col2"]),
      ("stringified list literal with whitespace", "[ 'col1' , 'col2' ]", ["col1", "col2"]),
      ("stringified tuple literal", "('col1', 'col2')", ["col1", "col2"]),
      ("real tuple value passed via kwargs", ("col1", "col2"), ["col1", "col2"]),
      # A malformed literal is expected to be kept verbatim as the only list item.
      ("malformed stringified list", "[malformed, list", ["[malformed, list"]),
  )

  for label, raw_value, expected in cases:
    with self.subTest(case=label, train_data_columns=raw_value):
      config = pyconfig.initialize(
          # Hand each call its own argv list, as the original per-case calls did.
          list(base_argv),
          skip_jax_distributed_system=True,
          train_data_columns=raw_value,
      )
      self.assertEqual(config.train_data_columns, expected)
| 226 | + |
def test_coerced_list_is_validated_successfully(self):
  """Checks that a column list coerced by pyconfig is accepted by the SFT column validator."""
  # Mirrors a user supplying `train_data_columns=messages` on the command line.
  train_script = os.path.join(MAXTEXT_PKG_DIR, "train.py")
  cli_args = [train_script, get_test_config_path()]
  config = pyconfig.initialize(
      cli_args,
      skip_jax_distributed_system=True,
      train_data_columns="messages",
  )

  # The plain string must have been turned into a one-element list.
  self.assertEqual(config.train_data_columns, ["messages"])

  # Handing the coerced list to the validator must complete without raising (Scenario A).
  data_processing_utils.validate_and_configure_sft_columns(config.train_data_columns, None)
| 240 | + |
176 | 241 |
|
177 | 242 | if __name__ == "__main__": |
178 | 243 | unittest.main() |
0 commit comments