|
42 | 42 | DATA_CONFIG_MULTITURN_DATA_YAML, |
43 | 43 | DATA_CONFIG_RENAME_RETAIN_COLUMNS, |
44 | 44 | DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, |
| 45 | + DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER, |
45 | 46 | DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT, |
46 | 47 | DATA_CONFIG_YAML_STREAMING_PRETOKENIZED, |
47 | 48 | ) |
@@ -996,6 +997,52 @@ def test_run_training_with_pretokenised_dataset_containing_input_ids(): |
996 | 997 | assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference |
997 | 998 |
|
998 | 999 |
|
| 1000 | +def test_run_training_with_data_tokenized_using_tokenizer_handler(): |
| 1001 | + """Ensure that we can train on non tokenized dataset works by tokenizing using |
| 1002 | + tokenizer data handler via data config.""" |
| 1003 | + with tempfile.TemporaryDirectory() as tempdir: |
| 1004 | + |
| 1005 | + data_args = copy.deepcopy(DATA_ARGS) |
| 1006 | + |
| 1007 | + # set training_data_path and response_template to none |
| 1008 | + data_args.response_template = None |
| 1009 | + data_args.training_data_path = None |
| 1010 | + |
| 1011 | + dataconfigfile = DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER |
| 1012 | + datapath = TWITTER_COMPLAINTS_DATA_JSONL |
| 1013 | + |
| 1014 | + # add data_paths in data_config file |
| 1015 | + with tempfile.NamedTemporaryFile( |
| 1016 | + "w", delete=False, suffix=".yaml" |
| 1017 | + ) as temp_yaml_file: |
| 1018 | + with open(dataconfigfile, "r", encoding="utf-8") as f: |
| 1019 | + data = yaml.safe_load(f) |
| 1020 | + datasets = data["datasets"] |
| 1021 | + for _, d in enumerate(datasets): |
| 1022 | + d["data_paths"] = [datapath] |
| 1023 | + yaml.dump(data, temp_yaml_file) |
| 1024 | + data_args.data_config_path = temp_yaml_file.name |
| 1025 | + |
| 1026 | + train_args = copy.deepcopy(TRAIN_ARGS) |
| 1027 | + train_args.output_dir = tempdir |
| 1028 | + |
| 1029 | + sft_trainer.train(MODEL_ARGS, data_args, train_args) |
| 1030 | + |
| 1031 | + # validate full ft configs |
| 1032 | + _validate_training(tempdir) |
| 1033 | + checkpoint_path = _get_checkpoint_path(tempdir) |
| 1034 | + |
| 1035 | + # Load the model |
| 1036 | + loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME) |
| 1037 | + |
| 1038 | + # Run inference on the text |
| 1039 | + output_inference = loaded_model.run( |
| 1040 | + "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50 |
| 1041 | + ) |
| 1042 | + assert len(output_inference) > 0 |
| 1043 | + assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference |
| 1044 | + |
| 1045 | + |
999 | 1046 | @pytest.mark.parametrize( |
1000 | 1047 | "dataset_path", |
1001 | 1048 | [CHAT_DATA_SINGLE_TURN, CHAT_DATA_MULTI_TURN], |
|
0 commit comments