 # This was taken from litellm

+import json
 from enum import Enum
 from typing import Any, Optional
-import json
-from jinja2.sandbox import ImmutableSandboxedEnvironment

 import requests
+from jinja2.sandbox import ImmutableSandboxedEnvironment


 def default_pt(messages):
@@ -225,6 +225,7 @@ class AnthropicConstants(Enum):
     prompt += f"{AnthropicConstants.AI_PROMPT.value}"  # prompt must end with "\n\nAssistant: " turn
     return prompt

+
 def anthropic_pt(
     messages: list,):  # format - https://docs.anthropic.com/claude/reference/complete_post
     """
@@ -378,110 +379,102 @@ def gemini_text_image_pt(messages: list):
     return content


-def hf_chat_template(model: str, messages: list, hf_token: str, chat_template: Optional[Any] = None):
-    ## get the tokenizer config from huggingface
-    bos_token = ""
-    eos_token = ""
-    if chat_template is None:
-        def _get_tokenizer_config(hf_model_name):
-            headers = {
-                "Authorization": f"Bearer {hf_token}"
-            }
-            url = (
-                f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
-            )
-            # Make a GET request to fetch the JSON data
-            response = requests.get(url, headers=headers)
-            # print(response)
-            if response.status_code == 200:
-                # Parse the JSON data
-                tokenizer_config = json.loads(response.content)
-                return {"status": "success", "tokenizer": tokenizer_config}
-            else:
-                return {"status": "failure"}
-
-        tokenizer_config = _get_tokenizer_config(model)
-        if (
-            tokenizer_config["status"] == "failure"
-            or "chat_template" not in tokenizer_config["tokenizer"]
-        ):
-            raise Exception("No chat template found")
-        ## read the bos token, eos token and chat template from the json
-        tokenizer_config = tokenizer_config["tokenizer"]
-        bos_token = tokenizer_config["bos_token"]
-        eos_token = tokenizer_config["eos_token"]
-        chat_template = tokenizer_config["chat_template"]
-
-    def raise_exception(message):
-        raise Exception(f"Error message - {message}")
-
-    # Create a template object from the template text
-    env = ImmutableSandboxedEnvironment()
-    env.globals["raise_exception"] = raise_exception
-    try:
-        template = env.from_string(chat_template)
-    except Exception as e:
-        raise e
-
-    def _is_system_in_template():
-        try:
-            # Try rendering the template with a system message
-            response = template.render(
-                messages=[{"role": "system", "content": "test"}],
-                eos_token="<eos>",
-                bos_token="<bos>",
-            )
-            return True
-
-        # This will be raised if Jinja attempts to render the system message and it can't
-        except:
-            return False
+def hf_chat_template(model: str,
+                     messages: list,
+                     hf_token: str,
+                     chat_template: Optional[Any] = None):
+    ## get the tokenizer config from huggingface
+    bos_token = ""
+    eos_token = ""
+    if chat_template is None:
+
+        def _get_tokenizer_config(hf_model_name):
+            headers = {"Authorization": f"Bearer {hf_token}"}
+            url = (f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json")
+            # Make a GET request to fetch the JSON data
+            response = requests.get(url, headers=headers)
+            # print(response)
+            if response.status_code == 200:
+                # Parse the JSON data
+                tokenizer_config = json.loads(response.content)
+                return {"status": "success", "tokenizer": tokenizer_config}
+            else:
+                return {"status": "failure"}
+
+        tokenizer_config = _get_tokenizer_config(model)
+        if (tokenizer_config["status"] == "failure" or
+                "chat_template" not in tokenizer_config["tokenizer"]):
+            raise Exception("No chat template found")
+        ## read the bos token, eos token and chat template from the json
+        tokenizer_config = tokenizer_config["tokenizer"]
+        bos_token = tokenizer_config["bos_token"]
+        eos_token = tokenizer_config["eos_token"]
+        chat_template = tokenizer_config["chat_template"]
+
+    def raise_exception(message):
+        raise Exception(f"Error message - {message}")
+
+    # Create a template object from the template text
+    env = ImmutableSandboxedEnvironment()
+    env.globals["raise_exception"] = raise_exception
+    try:
+        template = env.from_string(chat_template)
+    except Exception as e:
+        raise e

+    def _is_system_in_template():
         try:
-            # Render the template with the provided values
-            if _is_system_in_template():
-                rendered_text = template.render(
-                    bos_token=bos_token, eos_token=eos_token, messages=messages
-                )
-            else:
-                # treat a system message as a user message, if system not in template
-                try:
-                    reformatted_messages = []
-                    for message in messages:
-                        if message["role"] == "system":
-                            reformatted_messages.append(
-                                {"role": "user", "content": message["content"]}
-                            )
-                        else:
-                            reformatted_messages.append(message)
-                    rendered_text = template.render(
-                        bos_token=bos_token,
-                        eos_token=eos_token,
-                        messages=reformatted_messages,
-                    )
-                except Exception as e:
-                    if "Conversation roles must alternate user/assistant" in str(e):
-                        # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
-                        new_messages = []
-                        for i in range(len(reformatted_messages) - 1):
-                            new_messages.append(reformatted_messages[i])
-                            if (
-                                reformatted_messages[i]["role"]
-                                == reformatted_messages[i + 1]["role"]
-                            ):
-                                if reformatted_messages[i]["role"] == "user":
-                                    new_messages.append(
-                                        {"role": "assistant", "content": ""}
-                                    )
-                                else:
-                                    new_messages.append({"role": "user", "content": ""})
-                        new_messages.append(reformatted_messages[-1])
-                        rendered_text = template.render(
-                            bos_token=bos_token, eos_token=eos_token, messages=new_messages
-                        )
-        return rendered_text
-    except Exception as e:
-        raise Exception(f"Error rendering template - {str(e)}")
+            # Try rendering the template with a system message
+            template.render(
+                messages=[{
+                    "role": "system",
+                    "content": "test"
+                }],
+                eos_token="<eos>",
+                bos_token="<bos>",
+            )
+            return True
+
+        # This will be raised if Jinja attempts to render the system message and it can't
+        except Exception:
+            return False
+
+    try:
+        # Render the template with the provided values
+        if _is_system_in_template():
+            rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=messages)
+        else:
+            # treat a system message as a user message, if system not in template
+            try:
+                reformatted_messages = []
+                for message in messages:
+                    if message["role"] == "system":
+                        reformatted_messages.append({"role": "user", "content": message["content"]})
+                    else:
+                        reformatted_messages.append(message)
+                rendered_text = template.render(
+                    bos_token=bos_token,
+                    eos_token=eos_token,
+                    messages=reformatted_messages,
+                )
+            except Exception as e:
+                if "Conversation roles must alternate user/assistant" in str(e):
+                    # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
+                    new_messages = []
+                    for i in range(len(reformatted_messages) - 1):
+                        new_messages.append(reformatted_messages[i])
+                        if (reformatted_messages[i]["role"] == reformatted_messages[i + 1]["role"]):
+                            if reformatted_messages[i]["role"] == "user":
+                                new_messages.append({"role": "assistant", "content": ""})
+                            else:
+                                new_messages.append({"role": "user", "content": ""})
+                    new_messages.append(reformatted_messages[-1])
+                    rendered_text = template.render(
+                        bos_token=bos_token, eos_token=eos_token, messages=new_messages)
+        return rendered_text
+    except Exception as e:
+        raise Exception(f"Error rendering template - {str(e)}")
+

 # Function call template
 def function_call_prompt(messages: list, functions: list):
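A usage sketch for hf_chat_template as rewritten above: it fetches the model's tokenizer_config.json from huggingface.co, renders the Jinja chat template in a sandboxed environment, and rewrites the message list when the template rejects system messages or non-alternating roles. The token and model name below are placeholders, and network access to huggingface.co is assumed:

    # Hypothetical call; "hf_..." and the model name are placeholders.
    messages = [
        {"role": "system", "content": "You are concise."},
        {"role": "user", "content": "First question"},
        {"role": "user", "content": "Second question"},  # two consecutive user turns
    ]
    prompt = hf_chat_template(
        model="mistralai/Mistral-7B-Instruct-v0.1",
        messages=messages,
        hf_token="hf_...",
    )
    # If the template lacks a system role, the system message is resent with
    # role "user"; if the template then complains that roles must alternate,
    # a blank {"role": "assistant", "content": ""} turn is inserted between
    # the two consecutive user messages before rendering.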
@@ -500,6 +493,7 @@ def function_call_prompt(messages: list, functions: list):

     return messages

+
 # Custom prompt template
 def custom_prompt(
     role_dict: dict,
@@ -533,62 +527,62 @@ def custom_prompt(
     prompt += final_prompt_value
     return prompt

+
 def prompt_factory(
     model: str,
     messages: list,
     custom_llm_provider: Optional[str] = None,
     hf_token: Optional[str] = None,
 ):
-    original_model_name = model
-    model = model.lower()
-    if custom_llm_provider == "ollama":
-        return ollama_pt(model=model, messages=messages)
-    elif custom_llm_provider == "anthropic":
-        if any(_ in model for _ in ["claude-2.1", "claude-v2:1"]):
-            return claude_2_1_pt(messages=messages)
-        else:
-            return anthropic_pt(messages=messages)
-    elif custom_llm_provider == "gemini":
-        if model == "gemini-pro-vision":
-            return _gemini_vision_convert_messages(messages=messages)
-        else:
-            return gemini_text_image_pt(messages=messages)
-    try:
-        if "meta-llama/llama-2" in model and "chat" in model:
-            return llama_2_chat_pt(messages=messages)
-        elif "llama3" in model and "instruct" in model:
-            return hf_chat_template(
-                model="meta-llama/Meta-Llama-3-8B-Instruct",
-                messages=messages,
-            )
-        elif (
-            "tiiuae/falcon" in model
-        ):  # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
-            if model == "tiiuae/falcon-180B-chat":
-                return falcon_chat_pt(messages=messages)
-            elif "instruct" in model:
-                return falcon_instruct_pt(messages=messages)
-        elif "mosaicml/mpt" in model:
-            if "chat" in model:
-                return mpt_chat_pt(messages=messages)
-        elif "codellama/codellama" in model:
-            if "instruct" in model:
-                return llama_2_chat_pt(
-                    messages=messages
-                )  # https://huggingface.co/blog/codellama#conversational-instructions
-        elif "wizardlm/wizardcoder" in model:
-            return wizardcoder_pt(messages=messages)
-        elif "phind/phind-codellama" in model:
-            return phind_codellama_pt(messages=messages)
-        elif model in [
-            "gryphe/mythomax-l2-13b",
-            "gryphe/mythomix-l2-13b",
-            "gryphe/mythologic-l2-13b",
-        ]:
-            return alpaca_pt(messages=messages)
-        else:
-            return hf_chat_template(original_model_name, messages, hf_token=hf_token)
-    except Exception as e:
-        return default_pt(
-            messages=messages
-        )  # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
+    original_model_name = model
+    model = model.lower()
+    if custom_llm_provider == "ollama":
+        return ollama_pt(model=model, messages=messages)
+    elif custom_llm_provider == "anthropic":
+        if any(_ in model for _ in ["claude-2.1", "claude-v2:1"]):
+            return claude_2_1_pt(messages=messages)
+        else:
+            return anthropic_pt(messages=messages)
+    elif custom_llm_provider == "gemini":
+        if model == "gemini-pro-vision":
+            return _gemini_vision_convert_messages(messages=messages)
+        else:
+            return gemini_text_image_pt(messages=messages)
+    try:
+        if "meta-llama/llama-2" in model and "chat" in model:
+            return llama_2_chat_pt(messages=messages)
+        elif "llama3" in model and "instruct" in model:
+            return hf_chat_template(
+                model="meta-llama/Meta-Llama-3-8B-Instruct",
+                messages=messages,
+            )
+        elif (
+                "tiiuae/falcon" in
+                model):  # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
+            if model == "tiiuae/falcon-180B-chat":
+                return falcon_chat_pt(messages=messages)
+            elif "instruct" in model:
+                return falcon_instruct_pt(messages=messages)
+        elif "mosaicml/mpt" in model:
+            if "chat" in model:
+                return mpt_chat_pt(messages=messages)
+        elif "codellama/codellama" in model:
+            if "instruct" in model:
+                return llama_2_chat_pt(
+                    messages=messages)  # https://huggingface.co/blog/codellama#conversational-instructions
+        elif "wizardlm/wizardcoder" in model:
+            return wizardcoder_pt(messages=messages)
+        elif "phind/phind-codellama" in model:
+            return phind_codellama_pt(messages=messages)
+        elif model in [
+                "gryphe/mythomax-l2-13b",
+                "gryphe/mythomix-l2-13b",
+                "gryphe/mythologic-l2-13b",
+        ]:
+            return alpaca_pt(messages=messages)
+        else:
+            return hf_chat_template(original_model_name, messages, hf_token=hf_token)
+    except Exception:
+        return default_pt(
+            messages=messages
+        )  # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
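Finally, a sketch of how prompt_factory dispatches after this change: provider-specific branches run first, name-based Hugging Face matches next, any unmatched model falls through to hf_chat_template, and any exception lands in default_pt. The model names below are examples; the last one is a hypothetical repo used only to show the fallback path:

    messages = [{"role": "user", "content": "Hello"}]

    # Provider branch: claude-2.1 under the anthropic provider -> claude_2_1_pt.
    prompt = prompt_factory(model="claude-2.1", messages=messages,
                            custom_llm_provider="anthropic")

    # Name match: "meta-llama/llama-2" plus "chat" in the lowercased name -> llama_2_chat_pt.
    prompt = prompt_factory(model="meta-llama/Llama-2-7b-chat-hf", messages=messages)

    # No explicit match: tries hf_chat_template("some-org/some-model", ...);
    # on any exception, falls back to default_pt.
    prompt = prompt_factory(model="some-org/some-model", messages=messages)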