diff --git a/tests/artifacts/predefined_data_configs/__init__.py b/tests/artifacts/predefined_data_configs/__init__.py index 9033c7f366..7026890845 100644 --- a/tests/artifacts/predefined_data_configs/__init__.py +++ b/tests/artifacts/predefined_data_configs/__init__.py @@ -43,6 +43,18 @@ DATA_CONFIG_MULTITURN_CHAT_TOKENIZE_AND_MASKING_DATA_HANDLER = os.path.join( PREDEFINED_DATA_CONFIGS, "mt_data_granite_3_1B_tokenize_and_mask_handler.yaml" ) +DATA_CONFIG_VALID_BASE64_CHAT_TEMPLATE = os.path.join( + PREDEFINED_DATA_CONFIGS, + "granite_3_1b_valid_base64_data_handler.yaml", +) +DATA_CONFIG_INVALID_BASE64_CHAT_TEMPLATE = os.path.join( + PREDEFINED_DATA_CONFIGS, + "granite_3_1b_invalid_base64_data_handler.yaml", +) +GRANITE_3_1_B_CHAT_TEMPLATE = os.path.join( + PREDEFINED_DATA_CONFIGS, + "granite_3_1b_chat_template.txt", +) DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT = os.path.join( PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking_streaming.yaml" ) diff --git a/tests/artifacts/predefined_data_configs/granite_3_1b_chat_template.txt b/tests/artifacts/predefined_data_configs/granite_3_1b_chat_template.txt new file mode 100644 index 0000000000..487920c59e --- /dev/null +++ b/tests/artifacts/predefined_data_configs/granite_3_1b_chat_template.txt @@ -0,0 +1,49 @@ +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content'] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set system_message = "Knowledge Cutoff Date: April 2024.\nToday's Date: " + strftime_now('%B %d, %Y') + ".\nYou are Granite, developed by IBM." %} + {%- if tools and documents %} + {%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.\n\nWrite the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %} + {%- elif tools %} + {%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %} + {%- elif documents %} + {%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %} + {%- else %} + {%- set system_message = system_message + " You are a helpful AI assistant." %} + {%- endif %} + {%- if 'citations' in controls and documents %} + {%- set system_message = system_message + '\n\nIn your response, use the symbols and to indicate when a fact comes from a document in the search result, e.g 0 for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list.' %} + {%- endif %} + {%- if 'hallucinations' in controls and documents %} + {%- set system_message = system_message + '\n\nFinally, after the response is written, include a numbered list of sentences from the response that are potentially hallucinated and not based in the documents.' %} + {%- endif %} + {%- set loop_messages = messages %} +{%- endif %} +{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>\n' }} +{%- if tools %} + {{- '<|start_of_role|>tools<|end_of_role|>' }} + {{- tools | tojson(indent=4) }} + {{- '<|end_of_text|>\n' }} +{%- endif %} +{%- if documents %} + {{- '<|start_of_role|>documents<|end_of_role|>' }} + {%- for document in documents %} + {{- 'Document ' + loop.index0 | string + '\n' }} + {{- document['text'] }} + {%- if not loop.last %} + {{- '\n\n'}} + {%- endif%} + {%- endfor %} + {{- '<|end_of_text|>\n' }} +{%- endif %} +{%- for message in loop_messages %} + {{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }} + {%- if loop.last and add_generation_prompt %} + {{- '<|start_of_role|>assistant' }} + {%- if controls %} + {{- ' ' + controls | tojson()}} + {%- endif %} + {{- '<|end_of_role|>' }} + {%- endif %} +{%- endfor %} diff --git a/tests/artifacts/predefined_data_configs/granite_3_1b_invalid_base64_data_handler.yaml b/tests/artifacts/predefined_data_configs/granite_3_1b_invalid_base64_data_handler.yaml new file mode 100644 index 0000000000..e47578cbff --- /dev/null +++ b/tests/artifacts/predefined_data_configs/granite_3_1b_invalid_base64_data_handler.yaml @@ -0,0 +1,34 @@ +dataprocessor: + type: default + chat_template_base64: 'dsaeyUtIGlmIG1lc3NhZ2VzWzBdWydyb2xlJ10gPT0gJ3N5c3RlbScgJX0KICAgIHslLSBzZXQgc3lzdGVtX21lc3NhZ2UgPSBtZXNzYWdlc1swXVsnY29udGVudCddICV9CiAgICB7JS0gc2V0IGxvb3BfbWVzc2FnZXMgPSBtZXNzYWdlc1sxOl0gJX0KeyUtIGVsc2UgJX0KICAgIHslLSBzZXQgc3lzdGVtX21lc3NhZ2UgPSAiS25vd2xlZGdlIEN1dG9mZiBEYXRlOiBBcHJpbCAyMDI0LlxuVG9kYXkncyBEYXRlOiAiICsgc3RyZnRpbWVfbm93KCclQiAlZCwgJVknKSArICIuXG5Zb3UgYXJlIEdyYW5pdGUsIGRldmVsb3BlZCBieSBJQk0uIiAlfQogICAgeyUtIGlmIHRvb2xzIGFuZCBkb2N1bWVudHMgJX0KICAgICAgICB7JS0gc2V0IHN5c3RlbV9tZXNzYWdlID0gc3lzdGVtX21lc3NhZ2UgKyAiIFlvdSBhcmUgYSBoZWxwZnVsIEFJIGFzc2lzdGFudCB3aXRoIGFjY2VzcyB0byB0aGUgZm9sbG93aW5nIHRvb2xzLiBXaGVuIGEgdG9vbCBpcyByZXF1aXJlZCB0byBhbnN3ZXIgdGhlIHVzZXIncyBxdWVyeSwgcmVzcG9uZCB3aXRoIDx8dG9vbF9jYWxsfD4gZm9sbG93ZWQgYnkgYSBKU09OIGxpc3Qgb2YgdG9vbHMgdXNlZC4gSWYgYSB0b29sIGRvZXMgbm90IGV4aXN0IGluIHRoZSBwcm92aWRlZCBsaXN0IG9mIHRvb2xzLCBub3RpZnkgdGhlIHVzZXIgdGhhdCB5b3UgZG8gbm90IGhhdmUgdGhlIGFiaWxpdHkgdG8gZnVsZmlsbCB0aGUgcmVxdWVzdC5cblxuV3JpdGUgdGhlIHJlc3BvbnNlIHRvIHRoZSB1c2VyJ3MgaW5wdXQgYnkgc3RyaWN0bHkgYWxpZ25pbmcgd2l0aCB0aGUgZmFjdHMgaW4gdGhlIHByb3ZpZGVkIGRvY3VtZW50cy4gSWYgdGhlIGluZm9ybWF0aW9uIG5lZWRlZCB0byBhbnN3ZXIgdGhlIHF1ZXN0aW9uIGlzIG5vdCBhdmFpbGFibGUgaW4gdGhlIGRvY3VtZW50cywgaW5mb3JtIHRoZSB1c2VyIHRoYXQgdGhlIHF1ZXN0aW9uIGNhbm5vdCBiZSBhbnN3ZXJlZCBiYXNlZCBvbiB0aGUgYXZhaWxhYmxlIGRhdGEuIiAlfQogICAgeyUtIGVsaWYgdG9vbHMgJX0KICAgICAgICB7JS0gc2V0IHN5c3RlbV9tZXNzYWdlID0gc3lzdGVtX21lc3NhZ2UgKyAiIFlvdSBhcmUgYSBoZWxwZnVsIEFJIGFzc2lzdGFudCB3aXRoIGFjY2VzcyB0byB0aGUgZm9sbG93aW5nIHRvb2xzLiBXaGVuIGEgdG9vbCBpcyByZXF1aXJlZCB0byBhbnN3ZXIgdGhlIHVzZXIncyBxdWVyeSwgcmVzcG9uZCB3aXRoIDx8dG9vbF9jYWxsfD4gZm9sbG93ZWQgYnkgYSBKU09OIGxpc3Qgb2YgdG9vbHMgdXNlZC4gSWYgYSB0b29sIGRvZXMgbm90IGV4aXN0IGluIHRoZSBwcm92aWRlZCBsaXN0IG9mIHRvb2xzLCBub3RpZnkgdGhlIHVzZXIgdGhhdCB5b3UgZG8gbm90IGhhdmUgdGhlIGFiaWxpdHkgdG8gZnVsZmlsbCB0aGUgcmVxdWVzdC4iICV9CiAgICB7JS0gZWxpZiBkb2N1bWVudHMgJX0KICAgICAgICB7JS0gc2V0IHN5c3RlbV9tZXNzYWdlID0gc3lzdGVtX21lc3NhZ2UgKyAiIFdyaXRlIHRoZSByZXNwb25zZSB0byB0aGUgdXNlcidzIGlucHV0IGJ5IHN0cmljdGx5IGFsaWduaW5nIHdpdGggdGhlIGZhY3RzIGluIHRoZSBwcm92aWRlZCBkb2N1bWVudHMuIElmIHRoZSBpbmZvcm1hdGlvbiBuZWVkZWQgdG8gYW5zd2VyIHRoZSBxdWVzdGlvbiBpcyBub3QgYXZhaWxhYmxlIGluIHRoZSBkb2N1bWVudHMsIGluZm9ybSB0aGUgdXNlciB0aGF0IHRoZSBxdWVzdGlvbiBjYW5ub3QgYmUgYW5zd2VyZWQgYmFzZWQgb24gdGhlIGF2YWlsYWJsZSBkYXRhLiIgJX0KICAgIHslLSBlbHNlICV9CiAgICAgICAgeyUtIHNldCBzeXN0ZW1fbWVzc2FnZSA9IHN5c3RlbV9tZXNzYWdlICsgIiBZb3UgYXJlIGEgaGVscGZ1bCBBSSBhc3Npc3RhbnQuIiAlfSAgICAKICAgIHslLSBlbmRpZiAlfQogICAgeyUtIGlmICdjaXRhdGlvbnMnIGluIGNvbnRyb2xzIGFuZCBkb2N1bWVudHMgJX0KICAgICAgICB7JS0gc2V0IHN5c3RlbV9tZXNzYWdlID0gc3lzdGVtX21lc3NhZ2UgKyAnXG5cbkluIHlvdXIgcmVzcG9uc2UsIHVzZSB0aGUgc3ltYm9scyA8Y28+IGFuZCA8L2NvPiB0byBpbmRpY2F0ZSB3aGVuIGEgZmFjdCBjb21lcyBmcm9tIGEgZG9jdW1lbnQgaW4gdGhlIHNlYXJjaCByZXN1bHQsIGUuZyA8Y28+MDwvY28+IGZvciBhIGZhY3QgZnJvbSBkb2N1bWVudCAwLiBBZnRlcndhcmRzLCBsaXN0IGFsbCB0aGUgY2l0YXRpb25zIHdpdGggdGhlaXIgY29ycmVzcG9uZGluZyBkb2N1bWVudHMgaW4gYW4gb3JkZXJlZCBsaXN0LicgJX0KICAgIHslLSBlbmRpZiAlfQogICAgeyUtIGlmICdoYWxsdWNpbmF0aW9ucycgaW4gY29udHJvbHMgYW5kIGRvY3VtZW50cyAlfQogICAgICAgIHslLSBzZXQgc3lzdGVtX21lc3NhZ2UgPSBzeXN0ZW1fbWVzc2FnZSArICdcblxuRmluYWxseSwgYWZ0ZXIgdGhlIHJlc3BvbnNlIGlzIHdyaXR0ZW4sIGluY2x1ZGUgYSBudW1iZXJlZCBsaXN0IG9mIHNlbnRlbmNlcyBmcm9tIHRoZSByZXNwb25zZSB0aGF0IGFyZSBwb3RlbnRpYWxseSBoYWxsdWNpbmF0ZWQgYW5kIG5vdCBiYXNlZCBpbiB0aGUgZG9jdW1lbnRzLicgJX0KICAgIHslLSBlbmRpZiAlfQogICAgeyUtIHNldCBsb29wX21lc3NhZ2VzID0gbWVzc2FnZXMgJX0KeyUtIGVuZGlmICV9Cnt7LSAnPHxzdGFydF9vZl9yb2xlfD5zeXN0ZW08fGVuZF9vZl9yb2xlfD4nICsgc3lzdGVtX21lc3NhZ2UgKyAnPHxlbmRfb2ZfdGV4dHw+XG4nIH19CnslLSBpZiB0b29scyAlfQogICAge3stICc8fHN0YXJ0X29mX3JvbGV8PnRvb2xzPHxlbmRfb2Zfcm9sZXw+JyB9fQogICAge3stIHRvb2xzIHwgdG9qc29uKGluZGVudD00KSB9fQogICAge3stICc8fGVuZF9vZl90ZXh0fD5cbicgfX0KeyUtIGVuZGlmICV9CnslLSBpZiBkb2N1bWVudHMgJX0KICAgIHt7LSAnPHxzdGFydF9vZl9yb2xlfD5kb2N1bWVudHM8fGVuZF9vZl9yb2xlfD4nIH19CiAgICB7JS0gZm9yIGRvY3VtZW50IGluIGRvY3VtZW50cyAlfQogICAgICAgIHt7LSAnRG9jdW1lbnQgJyArIGxvb3AuaW5kZXgwIHwgc3RyaW5nICsgJ1xuJyB9fQogICAgICAgIHt7LSBkb2N1bWVudFsndGV4dCddIH19CiAgICAgICAgeyUtIGlmIG5vdCBsb29wLmxhc3QgJX0KICAgICAgICAgICAge3stICdcblxuJ319CiAgICAgICAgeyUtIGVuZGlmJX0KICAgIHslLSBlbmRmb3IgJX0KICAgIHt7LSAnPHxlbmRfb2ZfdGV4dHw+XG4nIH19CnslLSBlbmRpZiAlfQp7JS0gZm9yIG1lc3NhZ2UgaW4gbG9vcF9tZXNzYWdlcyAlfQogICAge3stICc8fHN0YXJ0X29mX3JvbGV8PicgKyBtZXNzYWdlWydyb2xlJ10gKyAnPHxlbmRfb2Zfcm9sZXw+JyArIG1lc3NhZ2VbJ2NvbnRlbnQnXSArICc8fGVuZF9vZl90ZXh0fD5cbicgfX0KICAgIHslLSBpZiBsb29wLmxhc3QgYW5kIGFkZF9nZW5lcmF0aW9uX3Byb21wdCAlfQogICAgICAgIHt7LSAnPHxzdGFydF9vZl9yb2xlfD5hc3Npc3RhbnQnIH19CiAgICAgICAgICAgIHslLSBpZiBjb250cm9scyAlfQogICAgICAgICAgICAgICAge3stICcgJyArIGNvbnRyb2xzIHwgdG9qc29uKCl9fQogICAgICAgICAgICB7JS0gZW5kaWYgJX0KICAgICAgICB7ey0gJzx8ZW5kX29mX3JvbGV8PicgfX0KICAgIHslLSBlbmRpZiAlfQp7JS0gZW5kZm9yICV9Cg==' +datasets: + - name: dataset_1 + data_paths: + - "FILE_PATH" + data_handlers: + - name: tokenize_and_apply_chat_template_with_masking + arguments: + remove_columns: all + fn_kwargs: + max_seq_length: 1024 + conversation_column: "messages" + - name: dataset_2 + data_paths: + - "FILE_PATH" + data_handlers: + - name: tokenize_and_apply_chat_template_with_masking + arguments: + remove_columns: all + fn_kwargs: + max_seq_length: 1024 + conversation_column: "messages" + - name: dataset_3 + data_paths: + - "FILE_PATH" + data_handlers: + - name: tokenize_and_apply_chat_template_with_masking + arguments: + remove_columns: all + fn_kwargs: + max_seq_length: 1024 + conversation_column: "messages" \ No newline at end of file diff --git a/tests/artifacts/predefined_data_configs/granite_3_1b_valid_base64_data_handler.yaml b/tests/artifacts/predefined_data_configs/granite_3_1b_valid_base64_data_handler.yaml new file mode 100644 index 0000000000..6268a4650d --- /dev/null +++ b/tests/artifacts/predefined_data_configs/granite_3_1b_valid_base64_data_handler.yaml @@ -0,0 +1,34 @@ +dataprocessor: + type: default + chat_template_base64: 'eyUtIGlmIG1lc3NhZ2VzWzBdWydyb2xlJ10gPT0gJ3N5c3RlbScgJX0KICAgIHslLSBzZXQgc3lzdGVtX21lc3NhZ2UgPSBtZXNzYWdlc1swXVsnY29udGVudCddICV9CiAgICB7JS0gc2V0IGxvb3BfbWVzc2FnZXMgPSBtZXNzYWdlc1sxOl0gJX0KeyUtIGVsc2UgJX0KICAgIHslLSBzZXQgc3lzdGVtX21lc3NhZ2UgPSAiS25vd2xlZGdlIEN1dG9mZiBEYXRlOiBBcHJpbCAyMDI0LlxuVG9kYXkncyBEYXRlOiAiICsgc3RyZnRpbWVfbm93KCclQiAlZCwgJVknKSArICIuXG5Zb3UgYXJlIEdyYW5pdGUsIGRldmVsb3BlZCBieSBJQk0uIiAlfQogICAgeyUtIGlmIHRvb2xzIGFuZCBkb2N1bWVudHMgJX0KICAgICAgICB7JS0gc2V0IHN5c3RlbV9tZXNzYWdlID0gc3lzdGVtX21lc3NhZ2UgKyAiIFlvdSBhcmUgYSBoZWxwZnVsIEFJIGFzc2lzdGFudCB3aXRoIGFjY2VzcyB0byB0aGUgZm9sbG93aW5nIHRvb2xzLiBXaGVuIGEgdG9vbCBpcyByZXF1aXJlZCB0byBhbnN3ZXIgdGhlIHVzZXIncyBxdWVyeSwgcmVzcG9uZCB3aXRoIDx8dG9vbF9jYWxsfD4gZm9sbG93ZWQgYnkgYSBKU09OIGxpc3Qgb2YgdG9vbHMgdXNlZC4gSWYgYSB0b29sIGRvZXMgbm90IGV4aXN0IGluIHRoZSBwcm92aWRlZCBsaXN0IG9mIHRvb2xzLCBub3RpZnkgdGhlIHVzZXIgdGhhdCB5b3UgZG8gbm90IGhhdmUgdGhlIGFiaWxpdHkgdG8gZnVsZmlsbCB0aGUgcmVxdWVzdC5cblxuV3JpdGUgdGhlIHJlc3BvbnNlIHRvIHRoZSB1c2VyJ3MgaW5wdXQgYnkgc3RyaWN0bHkgYWxpZ25pbmcgd2l0aCB0aGUgZmFjdHMgaW4gdGhlIHByb3ZpZGVkIGRvY3VtZW50cy4gSWYgdGhlIGluZm9ybWF0aW9uIG5lZWRlZCB0byBhbnN3ZXIgdGhlIHF1ZXN0aW9uIGlzIG5vdCBhdmFpbGFibGUgaW4gdGhlIGRvY3VtZW50cywgaW5mb3JtIHRoZSB1c2VyIHRoYXQgdGhlIHF1ZXN0aW9uIGNhbm5vdCBiZSBhbnN3ZXJlZCBiYXNlZCBvbiB0aGUgYXZhaWxhYmxlIGRhdGEuIiAlfQogICAgeyUtIGVsaWYgdG9vbHMgJX0KICAgICAgICB7JS0gc2V0IHN5c3RlbV9tZXNzYWdlID0gc3lzdGVtX21lc3NhZ2UgKyAiIFlvdSBhcmUgYSBoZWxwZnVsIEFJIGFzc2lzdGFudCB3aXRoIGFjY2VzcyB0byB0aGUgZm9sbG93aW5nIHRvb2xzLiBXaGVuIGEgdG9vbCBpcyByZXF1aXJlZCB0byBhbnN3ZXIgdGhlIHVzZXIncyBxdWVyeSwgcmVzcG9uZCB3aXRoIDx8dG9vbF9jYWxsfD4gZm9sbG93ZWQgYnkgYSBKU09OIGxpc3Qgb2YgdG9vbHMgdXNlZC4gSWYgYSB0b29sIGRvZXMgbm90IGV4aXN0IGluIHRoZSBwcm92aWRlZCBsaXN0IG9mIHRvb2xzLCBub3RpZnkgdGhlIHVzZXIgdGhhdCB5b3UgZG8gbm90IGhhdmUgdGhlIGFiaWxpdHkgdG8gZnVsZmlsbCB0aGUgcmVxdWVzdC4iICV9CiAgICB7JS0gZWxpZiBkb2N1bWVudHMgJX0KICAgICAgICB7JS0gc2V0IHN5c3RlbV9tZXNzYWdlID0gc3lzdGVtX21lc3NhZ2UgKyAiIFdyaXRlIHRoZSByZXNwb25zZSB0byB0aGUgdXNlcidzIGlucHV0IGJ5IHN0cmljdGx5IGFsaWduaW5nIHdpdGggdGhlIGZhY3RzIGluIHRoZSBwcm92aWRlZCBkb2N1bWVudHMuIElmIHRoZSBpbmZvcm1hdGlvbiBuZWVkZWQgdG8gYW5zd2VyIHRoZSBxdWVzdGlvbiBpcyBub3QgYXZhaWxhYmxlIGluIHRoZSBkb2N1bWVudHMsIGluZm9ybSB0aGUgdXNlciB0aGF0IHRoZSBxdWVzdGlvbiBjYW5ub3QgYmUgYW5zd2VyZWQgYmFzZWQgb24gdGhlIGF2YWlsYWJsZSBkYXRhLiIgJX0KICAgIHslLSBlbHNlICV9CiAgICAgICAgeyUtIHNldCBzeXN0ZW1fbWVzc2FnZSA9IHN5c3RlbV9tZXNzYWdlICsgIiBZb3UgYXJlIGEgaGVscGZ1bCBBSSBhc3Npc3RhbnQuIiAlfSAgICAKICAgIHslLSBlbmRpZiAlfQogICAgeyUtIGlmICdjaXRhdGlvbnMnIGluIGNvbnRyb2xzIGFuZCBkb2N1bWVudHMgJX0KICAgICAgICB7JS0gc2V0IHN5c3RlbV9tZXNzYWdlID0gc3lzdGVtX21lc3NhZ2UgKyAnXG5cbkluIHlvdXIgcmVzcG9uc2UsIHVzZSB0aGUgc3ltYm9scyA8Y28+IGFuZCA8L2NvPiB0byBpbmRpY2F0ZSB3aGVuIGEgZmFjdCBjb21lcyBmcm9tIGEgZG9jdW1lbnQgaW4gdGhlIHNlYXJjaCByZXN1bHQsIGUuZyA8Y28+MDwvY28+IGZvciBhIGZhY3QgZnJvbSBkb2N1bWVudCAwLiBBZnRlcndhcmRzLCBsaXN0IGFsbCB0aGUgY2l0YXRpb25zIHdpdGggdGhlaXIgY29ycmVzcG9uZGluZyBkb2N1bWVudHMgaW4gYW4gb3JkZXJlZCBsaXN0LicgJX0KICAgIHslLSBlbmRpZiAlfQogICAgeyUtIGlmICdoYWxsdWNpbmF0aW9ucycgaW4gY29udHJvbHMgYW5kIGRvY3VtZW50cyAlfQogICAgICAgIHslLSBzZXQgc3lzdGVtX21lc3NhZ2UgPSBzeXN0ZW1fbWVzc2FnZSArICdcblxuRmluYWxseSwgYWZ0ZXIgdGhlIHJlc3BvbnNlIGlzIHdyaXR0ZW4sIGluY2x1ZGUgYSBudW1iZXJlZCBsaXN0IG9mIHNlbnRlbmNlcyBmcm9tIHRoZSByZXNwb25zZSB0aGF0IGFyZSBwb3RlbnRpYWxseSBoYWxsdWNpbmF0ZWQgYW5kIG5vdCBiYXNlZCBpbiB0aGUgZG9jdW1lbnRzLicgJX0KICAgIHslLSBlbmRpZiAlfQogICAgeyUtIHNldCBsb29wX21lc3NhZ2VzID0gbWVzc2FnZXMgJX0KeyUtIGVuZGlmICV9Cnt7LSAnPHxzdGFydF9vZl9yb2xlfD5zeXN0ZW08fGVuZF9vZl9yb2xlfD4nICsgc3lzdGVtX21lc3NhZ2UgKyAnPHxlbmRfb2ZfdGV4dHw+XG4nIH19CnslLSBpZiB0b29scyAlfQogICAge3stICc8fHN0YXJ0X29mX3JvbGV8PnRvb2xzPHxlbmRfb2Zfcm9sZXw+JyB9fQogICAge3stIHRvb2xzIHwgdG9qc29uKGluZGVudD00KSB9fQogICAge3stICc8fGVuZF9vZl90ZXh0fD5cbicgfX0KeyUtIGVuZGlmICV9CnslLSBpZiBkb2N1bWVudHMgJX0KICAgIHt7LSAnPHxzdGFydF9vZl9yb2xlfD5kb2N1bWVudHM8fGVuZF9vZl9yb2xlfD4nIH19CiAgICB7JS0gZm9yIGRvY3VtZW50IGluIGRvY3VtZW50cyAlfQogICAgICAgIHt7LSAnRG9jdW1lbnQgJyArIGxvb3AuaW5kZXgwIHwgc3RyaW5nICsgJ1xuJyB9fQogICAgICAgIHt7LSBkb2N1bWVudFsndGV4dCddIH19CiAgICAgICAgeyUtIGlmIG5vdCBsb29wLmxhc3QgJX0KICAgICAgICAgICAge3stICdcblxuJ319CiAgICAgICAgeyUtIGVuZGlmJX0KICAgIHslLSBlbmRmb3IgJX0KICAgIHt7LSAnPHxlbmRfb2ZfdGV4dHw+XG4nIH19CnslLSBlbmRpZiAlfQp7JS0gZm9yIG1lc3NhZ2UgaW4gbG9vcF9tZXNzYWdlcyAlfQogICAge3stICc8fHN0YXJ0X29mX3JvbGV8PicgKyBtZXNzYWdlWydyb2xlJ10gKyAnPHxlbmRfb2Zfcm9sZXw+JyArIG1lc3NhZ2VbJ2NvbnRlbnQnXSArICc8fGVuZF9vZl90ZXh0fD5cbicgfX0KICAgIHslLSBpZiBsb29wLmxhc3QgYW5kIGFkZF9nZW5lcmF0aW9uX3Byb21wdCAlfQogICAgICAgIHt7LSAnPHxzdGFydF9vZl9yb2xlfD5hc3Npc3RhbnQnIH19CiAgICAgICAgICAgIHslLSBpZiBjb250cm9scyAlfQogICAgICAgICAgICAgICAge3stICcgJyArIGNvbnRyb2xzIHwgdG9qc29uKCl9fQogICAgICAgICAgICB7JS0gZW5kaWYgJX0KICAgICAgICB7ey0gJzx8ZW5kX29mX3JvbGV8PicgfX0KICAgIHslLSBlbmRpZiAlfQp7JS0gZW5kZm9yICV9Cg==' +datasets: + - name: dataset_1 + data_paths: + - "FILE_PATH" + data_handlers: + - name: tokenize_and_apply_chat_template_with_masking + arguments: + remove_columns: all + fn_kwargs: + max_seq_length: 1024 + conversation_column: "messages" + - name: dataset_2 + data_paths: + - "FILE_PATH" + data_handlers: + - name: tokenize_and_apply_chat_template_with_masking + arguments: + remove_columns: all + fn_kwargs: + max_seq_length: 1024 + conversation_column: "messages" + - name: dataset_3 + data_paths: + - "FILE_PATH" + data_handlers: + - name: tokenize_and_apply_chat_template_with_masking + arguments: + remove_columns: all + fn_kwargs: + max_seq_length: 1024 + conversation_column: "messages" \ No newline at end of file diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 1dfc553b43..27d61bad8f 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -38,6 +38,7 @@ from scripts.run_inference import TunedCausalLM from tests.artifacts.predefined_data_configs import ( DATA_CONFIG_DUPLICATE_COLUMNS, + DATA_CONFIG_INVALID_BASE64_CHAT_TEMPLATE, DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, DATA_CONFIG_MULTITURN_CHAT_TOKENIZE_AND_MASKING_DATA_HANDLER, DATA_CONFIG_MULTITURN_DATA_YAML, @@ -46,8 +47,10 @@ DATA_CONFIG_SKIP_LARGE_TEXT_HANDLER, DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER, + DATA_CONFIG_VALID_BASE64_CHAT_TEMPLATE, DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT, DATA_CONFIG_YAML_STREAMING_PRETOKENIZED, + GRANITE_3_1_B_CHAT_TEMPLATE, ) from tests.artifacts.testdata import ( CHAT_DATA_MULTI_TURN, @@ -84,6 +87,7 @@ DataHandlerConfig, DataPreProcessorConfig, DataSetConfig, + load_and_validate_data_config, ) from tuning.data.data_handlers import ( DataHandler, @@ -1391,6 +1395,25 @@ def test_run_chat_style_ft_using_dataconfig_for_chat_template( assert 'Provide two rhyming words for the word "love"' in output_inference +def test_data_config_chat_template_as_base64(): + """Check that the chat_template specified as base64 is parsed correctly.""" + expected_chat_template_path = GRANITE_3_1_B_CHAT_TEMPLATE + with open(expected_chat_template_path, "r", encoding="utf-8") as f: + expected_chat_template = f.read() + data_config_path = DATA_CONFIG_VALID_BASE64_CHAT_TEMPLATE + assert os.path.isfile(data_config_path) + data_config = load_and_validate_data_config(data_config_path) + parsed_chat_template = data_config.dataprocessor.chat_template + assert parsed_chat_template is not None, "the chat_template wasn't parsed correctly" + assert ( + data_config.dataprocessor.chat_template == expected_chat_template + ), "the chat_template wasn't parsed correctly" + # -------------------------------------------- + with pytest.raises(ValueError): + data_config_path = DATA_CONFIG_INVALID_BASE64_CHAT_TEMPLATE + data_config = load_and_validate_data_config(data_config_path) + + @pytest.mark.parametrize( "data_args", [ diff --git a/tuning/data/data_config.py b/tuning/data/data_config.py index f132e611c7..3371f5c2f5 100644 --- a/tuning/data/data_config.py +++ b/tuning/data/data_config.py @@ -13,6 +13,7 @@ # limitations under the License. # Standard +from base64 import b64decode from dataclasses import dataclass from typing import Dict, List, Optional import logging @@ -153,6 +154,24 @@ def _validate_dataprocessor_config(dataprocessor_config) -> DataPreProcessorConf chat_template = kwargs["chat_template"] assert isinstance(chat_template, str), "chat_template should be a string" c.chat_template = chat_template + elif "chat_template_base64" in kwargs: + chat_template_base64 = kwargs["chat_template_base64"] + assert isinstance( + chat_template_base64, str + ), "chat_template_base64 should be a string" + logger.warning( + "You are using the 'chat_template_base64' field. " + + "Please use the 'chat_template' field instead for better readability." + ) + try: + chat_template_bytes = b64decode(chat_template_base64) + chat_template = chat_template_bytes.decode("utf-8") + c.chat_template = chat_template + except Exception as e: + raise ValueError( + "You passed the 'chat_template_base64' field which failed during decoding." + + "Please check it or use a decoded chat template with the 'chat_template' field." + ) from e return c