From 707f2809800bf8f1f50c90cada5467bc95f9095b Mon Sep 17 00:00:00 2001
From: Yiiii0
Date: Mon, 9 Mar 2026 00:04:34 -0400
Subject: [PATCH] feat: Add Forge as LLM provider

---
 .../sotopia/generation_utils/generate.py      | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/EPO/Sotopia/sotopia/generation_utils/generate.py b/EPO/Sotopia/sotopia/generation_utils/generate.py
index 70700fe1..e48f173c 100644
--- a/EPO/Sotopia/sotopia/generation_utils/generate.py
+++ b/EPO/Sotopia/sotopia/generation_utils/generate.py
@@ -508,6 +508,24 @@ def obtain_chain(
         )
         chain = chat_prompt_template | chat_openai
         return chain
+    elif model_name.startswith("forge"):
+        model_name = "/".join(model_name.split("/")[1:])
+        human_message_prompt = HumanMessagePromptTemplate(
+            prompt=PromptTemplate(
+                template=template,
+                input_variables=input_variables,
+            )
+        )
+        chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
+        chat_openai = ChatOpenAI(
+            model_name=model_name,
+            temperature=temperature,
+            max_retries=max_retries,
+            openai_api_base=os.environ.get("FORGE_API_BASE", "https://api.forge.tensorblock.co/v1"),
+            openai_api_key=os.environ.get("FORGE_API_KEY"),
+        )
+        chain = chat_prompt_template | chat_openai
+        return chain
     elif model_name.startswith("groq"):
         model_name = "/".join(model_name.split("/")[1:])
         human_message_prompt = HumanMessagePromptTemplate(
@@ -1075,4 +1093,3 @@ async def agenerate_strategy(strategy_model: str, agent: str, history: str) -> s
     )
 
 
-