Commit d850795

lint fix
1 parent 2aa3fa3 commit d850795

2 files changed: 216 additions & 196 deletions

File tree:
  clarifai_datautils/text/prompt_factory.py

clarifai_datautils/text/prompt_factory.py: 151 additions & 157 deletions
@@ -1,11 +1,11 @@
 # This was taken from litellm
 
+import json
 from enum import Enum
 from typing import Any, Optional
-import json
-from jinja2.sandbox import ImmutableSandboxedEnvironment
 
 import requests
+from jinja2.sandbox import ImmutableSandboxedEnvironment
 
 
 def default_pt(messages):
@@ -225,6 +225,7 @@ class AnthropicConstants(Enum):
   prompt += f"{AnthropicConstants.AI_PROMPT.value}"  # prompt must end with \"\n\nAssistant: " turn
   return prompt
 
+
 def anthropic_pt(
     messages: list,):  # format - https://docs.anthropic.com/claude/reference/complete_post
   """
@@ -378,110 +379,102 @@ def gemini_text_image_pt(messages: list):
   return content
 
 
-def hf_chat_template(model: str, messages: list, hf_token: str, chat_template: Optional[Any] = None):
-    ## get the tokenizer config from huggingface
-    bos_token = ""
-    eos_token = ""
-    if chat_template is None:
-        def _get_tokenizer_config(hf_model_name):
-            headers = {
-                "Authorization": f"Bearer {hf_token}"
-            }
-            url = (
-                f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
-            )
-            # Make a GET request to fetch the JSON data
-            response = requests.get(url, headers=headers)
-            # print(response)
-            if response.status_code == 200:
-                # Parse the JSON data
-                tokenizer_config = json.loads(response.content)
-                return {"status": "success", "tokenizer": tokenizer_config}
-            else:
-                return {"status": "failure"}
-
-        tokenizer_config = _get_tokenizer_config(model)
-        if (
-            tokenizer_config["status"] == "failure"
-            or "chat_template" not in tokenizer_config["tokenizer"]
-        ):
-            raise Exception("No chat template found")
-        ## read the bos token, eos token and chat template from the json
-        tokenizer_config = tokenizer_config["tokenizer"]
-        bos_token = tokenizer_config["bos_token"]
-        eos_token = tokenizer_config["eos_token"]
-        chat_template = tokenizer_config["chat_template"]
-
-    def raise_exception(message):
-        raise Exception(f"Error message - {message}")
-
-    # Create a template object from the template text
-    env = ImmutableSandboxedEnvironment()
-    env.globals["raise_exception"] = raise_exception
-    try:
-        template = env.from_string(chat_template)
-    except Exception as e:
-        raise e
-
-    def _is_system_in_template():
-        try:
-            # Try rendering the template with a system message
-            response = template.render(
-                messages=[{"role": "system", "content": "test"}],
-                eos_token="<eos>",
-                bos_token="<bos>",
-            )
-            return True
-
-        # This will be raised if Jinja attempts to render the system message and it can't
-        except:
-            return False
+def hf_chat_template(model: str,
+                     messages: list,
+                     hf_token: str,
+                     chat_template: Optional[Any] = None):
+  ## get the tokenizer config from huggingface
+  bos_token = ""
+  eos_token = ""
+  if chat_template is None:
+
+    def _get_tokenizer_config(hf_model_name):
+      headers = {"Authorization": f"Bearer {hf_token}"}
+      url = (f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json")
+      # Make a GET request to fetch the JSON data
+      response = requests.get(url, headers=headers)
+      # print(response)
+      if response.status_code == 200:
+        # Parse the JSON data
+        tokenizer_config = json.loads(response.content)
+        return {"status": "success", "tokenizer": tokenizer_config}
+      else:
+        return {"status": "failure"}
+
+    tokenizer_config = _get_tokenizer_config(model)
+    if (tokenizer_config["status"] == "failure" or
+        "chat_template" not in tokenizer_config["tokenizer"]):
+      raise Exception("No chat template found")
+    ## read the bos token, eos token and chat template from the json
+    tokenizer_config = tokenizer_config["tokenizer"]
+    bos_token = tokenizer_config["bos_token"]
+    eos_token = tokenizer_config["eos_token"]
+    chat_template = tokenizer_config["chat_template"]
+
+  def raise_exception(message):
+    raise Exception(f"Error message - {message}")
+
+  # Create a template object from the template text
+  env = ImmutableSandboxedEnvironment()
+  env.globals["raise_exception"] = raise_exception
+  try:
+    template = env.from_string(chat_template)
+  except Exception as e:
+    raise e
 
+  def _is_system_in_template():
     try:
-        # Render the template with the provided values
-        if _is_system_in_template():
-            rendered_text = template.render(
-                bos_token=bos_token, eos_token=eos_token, messages=messages
-            )
-        else:
-            # treat a system message as a user message, if system not in template
-            try:
-                reformatted_messages = []
-                for message in messages:
-                    if message["role"] == "system":
-                        reformatted_messages.append(
-                            {"role": "user", "content": message["content"]}
-                        )
-                    else:
-                        reformatted_messages.append(message)
-                rendered_text = template.render(
-                    bos_token=bos_token,
-                    eos_token=eos_token,
-                    messages=reformatted_messages,
-                )
-            except Exception as e:
-                if "Conversation roles must alternate user/assistant" in str(e):
-                    # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
-                    new_messages = []
-                    for i in range(len(reformatted_messages) - 1):
-                        new_messages.append(reformatted_messages[i])
-                        if (
-                            reformatted_messages[i]["role"]
-                            == reformatted_messages[i + 1]["role"]
-                        ):
-                            if reformatted_messages[i]["role"] == "user":
-                                new_messages.append(
-                                    {"role": "assistant", "content": ""}
-                                )
-                            else:
-                                new_messages.append({"role": "user", "content": ""})
-                    new_messages.append(reformatted_messages[-1])
-                    rendered_text = template.render(
-                        bos_token=bos_token, eos_token=eos_token, messages=new_messages
-                    )
-        return rendered_text
-    except Exception as e:
-        raise Exception(f"Error rendering template - {str(e)}")
+      # Try rendering the template with a system message
+      template.render(
+          messages=[{
+              "role": "system",
+              "content": "test"
+          }],
+          eos_token="<eos>",
+          bos_token="<bos>",
+      )
+      return True
+
+    # This will be raised if Jinja attempts to render the system message and it can't
+    except Exception:
+      return False
+
+  try:
+    # Render the template with the provided values
+    if _is_system_in_template():
+      rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=messages)
+    else:
+      # treat a system message as a user message, if system not in template
+      try:
+        reformatted_messages = []
+        for message in messages:
+          if message["role"] == "system":
+            reformatted_messages.append({"role": "user", "content": message["content"]})
+          else:
+            reformatted_messages.append(message)
+        rendered_text = template.render(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            messages=reformatted_messages,
+        )
+      except Exception as e:
+        if "Conversation roles must alternate user/assistant" in str(e):
+          # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
+          new_messages = []
+          for i in range(len(reformatted_messages) - 1):
+            new_messages.append(reformatted_messages[i])
+            if (reformatted_messages[i]["role"] == reformatted_messages[i + 1]["role"]):
+              if reformatted_messages[i]["role"] == "user":
+                new_messages.append({"role": "assistant", "content": ""})
+              else:
+                new_messages.append({"role": "user", "content": ""})
+          new_messages.append(reformatted_messages[-1])
+          rendered_text = template.render(
+              bos_token=bos_token, eos_token=eos_token, messages=new_messages)
+    return rendered_text
+  except Exception as e:
+    raise Exception(f"Error rendering template - {str(e)}")
+
 
 # Function call template
 def function_call_prompt(messages: list, functions: list):
@@ -500,6 +493,7 @@ def function_call_prompt(messages: list, functions: list):
 
   return messages
 
+
 # Custom prompt template
 def custom_prompt(
     role_dict: dict,
@@ -533,62 +527,62 @@ def custom_prompt(
   prompt += final_prompt_value
   return prompt
 
+
 def prompt_factory(
     model: str,
     messages: list,
     custom_llm_provider: Optional[str] = None,
     hf_token: Optional[str] = None,
 ):
-    original_model_name = model
-    model = model.lower()
-    if custom_llm_provider == "ollama":
-        return ollama_pt(model=model, messages=messages)
-    elif custom_llm_provider == "anthropic":
-        if any(_ in model for _ in ["claude-2.1", "claude-v2:1"]):
-            return claude_2_1_pt(messages=messages)
-        else:
-            return anthropic_pt(messages=messages)
-    elif custom_llm_provider == "gemini":
-        if model == "gemini-pro-vision":
-            return _gemini_vision_convert_messages(messages=messages)
-        else:
-            return gemini_text_image_pt(messages=messages)
-    try:
-        if "meta-llama/llama-2" in model and "chat" in model:
-            return llama_2_chat_pt(messages=messages)
-        elif "llama3" in model and "instruct" in model:
-            return hf_chat_template(
-                model="meta-llama/Meta-Llama-3-8B-Instruct",
-                messages=messages,
-            )
-        elif (
-            "tiiuae/falcon" in model
-        ):  # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
-            if model == "tiiuae/falcon-180B-chat":
-                return falcon_chat_pt(messages=messages)
-            elif "instruct" in model:
-                return falcon_instruct_pt(messages=messages)
-        elif "mosaicml/mpt" in model:
-            if "chat" in model:
-                return mpt_chat_pt(messages=messages)
-        elif "codellama/codellama" in model:
-            if "instruct" in model:
-                return llama_2_chat_pt(
-                    messages=messages
-                )  # https://huggingface.co/blog/codellama#conversational-instructions
-        elif "wizardlm/wizardcoder" in model:
-            return wizardcoder_pt(messages=messages)
-        elif "phind/phind-codellama" in model:
-            return phind_codellama_pt(messages=messages)
-        elif model in [
-            "gryphe/mythomax-l2-13b",
-            "gryphe/mythomix-l2-13b",
-            "gryphe/mythologic-l2-13b",
-        ]:
-            return alpaca_pt(messages=messages)
-        else:
-            return hf_chat_template(original_model_name, messages, hf_token=hf_token)
-    except Exception as e:
-        return default_pt(
-            messages=messages
-        )  # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
+  original_model_name = model
+  model = model.lower()
+  if custom_llm_provider == "ollama":
+    return ollama_pt(model=model, messages=messages)
+  elif custom_llm_provider == "anthropic":
+    if any(_ in model for _ in ["claude-2.1", "claude-v2:1"]):
+      return claude_2_1_pt(messages=messages)
+    else:
+      return anthropic_pt(messages=messages)
+  elif custom_llm_provider == "gemini":
+    if model == "gemini-pro-vision":
+      return _gemini_vision_convert_messages(messages=messages)
+    else:
+      return gemini_text_image_pt(messages=messages)
+  try:
+    if "meta-llama/llama-2" in model and "chat" in model:
+      return llama_2_chat_pt(messages=messages)
+    elif "llama3" in model and "instruct" in model:
+      return hf_chat_template(
+          model="meta-llama/Meta-Llama-3-8B-Instruct",
+          messages=messages,
+      )
+    elif (
+        "tiiuae/falcon" in
+        model):  # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
+      if model == "tiiuae/falcon-180B-chat":
+        return falcon_chat_pt(messages=messages)
+      elif "instruct" in model:
+        return falcon_instruct_pt(messages=messages)
+    elif "mosaicml/mpt" in model:
+      if "chat" in model:
+        return mpt_chat_pt(messages=messages)
+    elif "codellama/codellama" in model:
+      if "instruct" in model:
+        return llama_2_chat_pt(
+            messages=messages)  # https://huggingface.co/blog/codellama#conversational-instructions
+    elif "wizardlm/wizardcoder" in model:
+      return wizardcoder_pt(messages=messages)
+    elif "phind/phind-codellama" in model:
+      return phind_codellama_pt(messages=messages)
+    elif model in [
+        "gryphe/mythomax-l2-13b",
+        "gryphe/mythomix-l2-13b",
+        "gryphe/mythologic-l2-13b",
+    ]:
+      return alpaca_pt(messages=messages)
+    else:
+      return hf_chat_template(original_model_name, messages, hf_token=hf_token)
+  except Exception:
+    return default_pt(
+        messages=messages
+    )  # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
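For context, a minimal usage sketch of the prompt_factory entry point touched by this commit. The import path is an assumption based on the file tree above, and the routing described in the comments mirrors the branches visible in the diff; the sketch is illustrative, not part of the commit:

# Hypothetical usage; the module path is assumed from the file tree above.
from clarifai_datautils.text.prompt_factory import prompt_factory

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello."},
]

# Per the branches in prompt_factory: provider "anthropic" plus "claude-2.1"
# in the model name routes to claude_2_1_pt(); unrecognized Hugging Face
# models fall through to hf_chat_template() and, on failure, default_pt().
prompt = prompt_factory(
    model="claude-2.1", messages=messages, custom_llm_provider="anthropic")
print(prompt)  # a single formatted prompt string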
