Skip to content

Commit aa9e3df

Browse files
authored
Merge pull request #163 from majiayu000/feat/ollama-support
feat: add Ollama provider support
2 parents 5caa76c + 000dd37 commit aa9e3df

2 files changed

Lines changed: 135 additions & 0 deletions

File tree

gui_agents/s3/core/mllm.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,69 @@ def __init__(self, engine_params=None, system_prompt=None, engine=None):
3535
self.engine = LMMEngineOpenRouter(**engine_params)
3636
elif engine_type == "parasail":
3737
self.engine = LMMEngineParasail(**engine_params)
38+
elif engine_type == "ollama":
39+
# Reuse LMMEngineOpenAI for Ollama
40+
if not engine_params.get("base_url"):
41+
import os
42+
43+
base_url = os.getenv("OLLAMA_HOST")
44+
if base_url:
45+
if not base_url.endswith("/v1"):
46+
base_url = base_url.rstrip("/") + "/v1"
47+
engine_params["base_url"] = base_url
48+
else:
49+
# RAISE ERROR instead of default
50+
raise ValueError(
51+
"Ollama endpoint must be provided via 'base_url' parameter or 'OLLAMA_HOST' environment variable."
52+
)
53+
if not engine_params.get("api_key"):
54+
engine_params["api_key"] = "ollama"
55+
self.engine = LMMEngineOpenAI(**engine_params)
56+
elif engine_type == "deepseek":
57+
if "base_url" not in engine_params:
58+
import os
59+
60+
base_url = os.getenv("DEEPSEEK_ENDPOINT_URL")
61+
if not base_url:
62+
base_url = "https://api.deepseek.com"
63+
if not base_url.endswith("/v1"):
64+
base_url = base_url.rstrip("/") + "/v1"
65+
engine_params["base_url"] = base_url
66+
67+
if not engine_params.get("api_key"):
68+
import os
69+
70+
api_key = os.getenv("DEEPSEEK_API_KEY")
71+
if not api_key:
72+
raise ValueError(
73+
"DeepSeek API key must be provided via 'api_key' parameter or 'DEEPSEEK_API_KEY' environment variable."
74+
)
75+
engine_params["api_key"] = api_key
76+
77+
self.engine = LMMEngineOpenAI(**engine_params)
78+
elif engine_type == "qwen":
79+
if not engine_params.get("base_url"):
80+
import os
81+
82+
base_url = os.getenv("QWEN_ENDPOINT_URL")
83+
if not base_url:
84+
base_url = (
85+
"https://dashscope.aliyuncs.com/compatible-mode/v1"
86+
)
87+
if not base_url.endswith("/v1"):
88+
base_url = base_url.rstrip("/") + "/v1"
89+
engine_params["base_url"] = base_url
90+
91+
if not engine_params.get("api_key"):
92+
import os
93+
94+
api_key = os.getenv("QWEN_API_KEY")
95+
if not api_key:
96+
raise ValueError(
97+
"Qwen API key must be provided via 'api_key' parameter or 'QWEN_API_KEY' environment variable."
98+
)
99+
engine_params["api_key"] = api_key
100+
self.engine = LMMEngineOpenAI(**engine_params)
38101
else:
39102
raise ValueError(f"engine_type '{engine_type}' is not supported")
40103
else:

tests/test_providers.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import os
2+
import unittest
3+
from unittest.mock import patch, MagicMock
4+
from gui_agents.s3.core.mllm import LMMAgent
5+
from gui_agents.s3.core.engine import LMMEngineOpenAI
6+
7+
8+
class TestProviders(unittest.TestCase):
9+
def setUp(self):
10+
# Clear env vars before each test
11+
if "OLLAMA_HOST" in os.environ:
12+
del os.environ["OLLAMA_HOST"]
13+
if "DEEPSEEK_API_KEY" in os.environ:
14+
del os.environ["DEEPSEEK_API_KEY"]
15+
if "QWEN_API_KEY" in os.environ:
16+
del os.environ["QWEN_API_KEY"]
17+
if "DEEPSEEK_ENDPOINT_URL" in os.environ:
18+
del os.environ["DEEPSEEK_ENDPOINT_URL"]
19+
if "QWEN_ENDPOINT_URL" in os.environ:
20+
del os.environ["QWEN_ENDPOINT_URL"]
21+
22+
def test_ollama_missing_config(self):
23+
"""Test that Ollama raises ValueError if no endpoint is provided"""
24+
with self.assertRaises(ValueError) as cm:
25+
LMMAgent(engine_params={"engine_type": "ollama", "model": "llama3"})
26+
self.assertIn("Ollama endpoint must be provided", str(cm.exception))
27+
28+
def test_ollama_valid_config_param(self):
29+
"""Test Ollama init with base_url param"""
30+
agent = LMMAgent(
31+
engine_params={
32+
"engine_type": "ollama",
33+
"model": "llama3",
34+
"base_url": "http://example.com/v1",
35+
}
36+
)
37+
self.assertIsInstance(agent.engine, LMMEngineOpenAI)
38+
self.assertEqual(agent.engine.base_url, "http://example.com/v1")
39+
40+
def test_ollama_valid_config_env(self):
41+
"""Test Ollama init with OLLAMA_HOST env var"""
42+
with patch.dict(os.environ, {"OLLAMA_HOST": "http://env-host:11434"}):
43+
agent = LMMAgent(engine_params={"engine_type": "ollama", "model": "llama3"})
44+
self.assertIsInstance(agent.engine, LMMEngineOpenAI)
45+
# Check for /v1 addition
46+
self.assertEqual(agent.engine.base_url, "http://env-host:11434/v1")
47+
48+
def test_deepseek_init(self):
49+
"""Test DeepSeek initialization"""
50+
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "sk-test"}):
51+
agent = LMMAgent(
52+
engine_params={"engine_type": "deepseek", "model": "deepseek-coder"}
53+
)
54+
self.assertIsInstance(agent.engine, LMMEngineOpenAI)
55+
# Default URL
56+
self.assertEqual(agent.engine.base_url, "https://api.deepseek.com/v1")
57+
# (Note: engine.py logic resolves default at generate() time or if client created,
58+
# but init just stores what's passed. Let's verify prompt generation to ensure it doesn't crash on init)
59+
60+
def test_qwen_init(self):
61+
"""Test Qwen initialization"""
62+
with patch.dict(os.environ, {"QWEN_API_KEY": "sk-qwen"}):
63+
agent = LMMAgent(engine_params={"engine_type": "qwen", "model": "qwen-max"})
64+
self.assertIsInstance(agent.engine, LMMEngineOpenAI)
65+
self.assertEqual(
66+
agent.engine.base_url,
67+
"https://dashscope.aliyuncs.com/compatible-mode/v1",
68+
)
69+
70+
71+
if __name__ == "__main__":
72+
unittest.main()

0 commit comments

Comments
 (0)