|
38 | 38 | from scripts.run_inference import TunedCausalLM |
39 | 39 | from tests.artifacts.predefined_data_configs import ( |
40 | 40 | DATA_CONFIG_DUPLICATE_COLUMNS, |
| 41 | + DATA_CONFIG_INVALID_BASE64_CHAT_TEMPLATE, |
41 | 42 | DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, |
42 | 43 | DATA_CONFIG_MULTITURN_CHAT_TOKENIZE_AND_MASKING_DATA_HANDLER, |
43 | 44 | DATA_CONFIG_MULTITURN_DATA_YAML, |
|
46 | 47 | DATA_CONFIG_SKIP_LARGE_TEXT_HANDLER, |
47 | 48 | DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, |
48 | 49 | DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER, |
| 50 | + DATA_CONFIG_VALID_BASE64_CHAT_TEMPLATE, |
49 | 51 | DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT, |
50 | 52 | DATA_CONFIG_YAML_STREAMING_PRETOKENIZED, |
| 53 | + GRANITE_3_1_B_CHAT_TEMPLATE, |
51 | 54 | ) |
52 | 55 | from tests.artifacts.testdata import ( |
53 | 56 | CHAT_DATA_MULTI_TURN, |
|
84 | 87 | DataHandlerConfig, |
85 | 88 | DataPreProcessorConfig, |
86 | 89 | DataSetConfig, |
| 90 | + load_and_validate_data_config, |
87 | 91 | ) |
88 | 92 | from tuning.data.data_handlers import ( |
89 | 93 | DataHandler, |
@@ -1391,6 +1395,25 @@ def test_run_chat_style_ft_using_dataconfig_for_chat_template( |
1391 | 1395 | assert 'Provide two rhyming words for the word "love"' in output_inference |
1392 | 1396 |
|
1393 | 1397 |
|
| 1398 | +def test_data_config_chat_template_as_base64(): |
| 1399 | + """Check that the chat_template specified as base64 is parsed correctly.""" |
| 1400 | + expected_chat_template_path = GRANITE_3_1_B_CHAT_TEMPLATE |
| 1401 | + with open(expected_chat_template_path, "r", encoding="utf-8") as f: |
| 1402 | + expected_chat_template = f.read() |
| 1403 | + data_config_path = DATA_CONFIG_VALID_BASE64_CHAT_TEMPLATE |
| 1404 | + assert os.path.isfile(data_config_path) |
| 1405 | + data_config = load_and_validate_data_config(data_config_path) |
| 1406 | + parsed_chat_template = data_config.dataprocessor.chat_template |
| 1407 | + assert parsed_chat_template is not None, "the chat_template wasn't parsed correctly" |
| 1408 | + assert ( |
| 1409 | + data_config.dataprocessor.chat_template == expected_chat_template |
| 1410 | + ), "the chat_template wasn't parsed correctly" |
| 1411 | + # -------------------------------------------- |
| 1412 | + with pytest.raises(ValueError): |
| 1413 | + data_config_path = DATA_CONFIG_INVALID_BASE64_CHAT_TEMPLATE |
| 1414 | + data_config = load_and_validate_data_config(data_config_path) |
| 1415 | + |
| 1416 | + |
1394 | 1417 | @pytest.mark.parametrize( |
1395 | 1418 | "data_args", |
1396 | 1419 | [ |
|
0 commit comments