From 31d02c02a904b483fa915cf4da8cae71cd0f2916 Mon Sep 17 00:00:00 2001 From: Octopus Date: Wed, 1 Apr 2026 21:08:14 +0800 Subject: [PATCH] Add MiniMax as LLM provider for layout generation - Add MiniMax-M2.7, MiniMax-M2.7-highspeed, MiniMax-M2.5, MiniMax-M2.5-highspeed to model_names list - Detect minimax models in get_llm_kwargs(): use https://api.minimax.io/v1 base URL and MINIMAX_API_KEY env var for authentication - Update get_layout() to use chat/completions endpoint and message content parsing for MiniMax models (OpenAI-compatible API) - Add minimax_api_key to utils/api_key.py with MINIMAX_API_KEY env var support - Add README section documenting MiniMax usage (Option 1b) - Add tests/test_minimax.py: 25 unit tests + 3 integration tests MiniMax (https://www.minimax.io/) provides powerful cloud LLM models (MiniMax-M2.7, MiniMax-M2.5, 204K context) via an OpenAI-compatible API. --- README.md | 8 + tests/test_minimax.py | 352 ++++++++++++++++++++++++++++++++++++++++++ utils/api_key.py | 6 + utils/llm.py | 13 +- 4 files changed, 376 insertions(+), 3 deletions(-) create mode 100644 tests/test_minimax.py diff --git a/README.md b/README.md index d7bbefb..e8b7433 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,14 @@ python prompt_batch.py --prompt-type demo --model gpt-4 --auto-query --always-sa You can visualize the bounding boxes in `img_generations/imgs_demo_templatev0.1`. +### Option 1b (automated): Use a MiniMax API key +[MiniMax](https://www.minimax.io/) offers powerful cloud LLM models (MiniMax-M2.7, MiniMax-M2.7-highspeed) via an OpenAI-compatible API. Set your API key and run: +``` +export MINIMAX_API_KEY=your_minimax_api_key +python prompt_batch.py --prompt-type demo --model MiniMax-M2.7 --auto-query --always-save --template_version v0.1 +``` +Available MiniMax models: `MiniMax-M2.7`, `MiniMax-M2.7-highspeed`, `MiniMax-M2.5`, `MiniMax-M2.5-highspeed`. + ### Option 2 (free): Manually copy and paste to ChatGPT ``` python prompt_batch.py --prompt-type demo --model gpt-4 --always-save --template_version v0.1 diff --git a/tests/test_minimax.py b/tests/test_minimax.py new file mode 100644 index 0000000..74102a1 --- /dev/null +++ b/tests/test_minimax.py @@ -0,0 +1,352 @@ +"""Tests for MiniMax LLM provider support.""" +import os +import sys +import json +import unittest +from unittest.mock import patch, MagicMock + +# Add parent directory to path so we can import project modules +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +class TestMiniMaxModelNames(unittest.TestCase): + """Unit tests for MiniMax model name handling.""" + + def test_model_names_contains_minimax_m27(self): + from utils.llm import model_names + self.assertIn("MiniMax-M2.7", model_names) + + def test_model_names_contains_minimax_m27_highspeed(self): + from utils.llm import model_names + self.assertIn("MiniMax-M2.7-highspeed", model_names) + + def test_model_names_contains_minimax_m25(self): + from utils.llm import model_names + self.assertIn("MiniMax-M2.5", model_names) + + def test_model_names_contains_minimax_m25_highspeed(self): + from utils.llm import model_names + self.assertIn("MiniMax-M2.5-highspeed", model_names) + + def test_existing_models_still_present(self): + from utils.llm import model_names + for model in ["gpt-3.5-turbo", "gpt-4", "vicuna-13b"]: + self.assertIn(model, model_names) + + +class TestGetFullModelName(unittest.TestCase): + """Unit tests for get_full_model_name().""" + + def test_minimax_m27_unchanged(self): + from utils.llm import get_full_model_name + self.assertEqual(get_full_model_name("MiniMax-M2.7"), "MiniMax-M2.7") + + def test_minimax_m27_highspeed_unchanged(self): + from utils.llm import get_full_model_name + self.assertEqual(get_full_model_name("MiniMax-M2.7-highspeed"), "MiniMax-M2.7-highspeed") + + def test_minimax_m25_unchanged(self): + from utils.llm import get_full_model_name + self.assertEqual(get_full_model_name("MiniMax-M2.5"), "MiniMax-M2.5") + + def test_minimax_m25_highspeed_unchanged(self): + from utils.llm import get_full_model_name + self.assertEqual(get_full_model_name("MiniMax-M2.5-highspeed"), "MiniMax-M2.5-highspeed") + + def test_gpt35_shorthand_still_expands(self): + from utils.llm import get_full_model_name + self.assertEqual(get_full_model_name("gpt-3.5"), "gpt-3.5-turbo") + + +class TestMiniMaxApiKey(unittest.TestCase): + """Unit tests for api_key.py MiniMax key handling.""" + + def test_minimax_api_key_from_env(self): + with patch.dict(os.environ, {"MINIMAX_API_KEY": "test-minimax-key-123"}): + import importlib + import utils.api_key as api_key_module + importlib.reload(api_key_module) + self.assertEqual(api_key_module.minimax_api_key, "test-minimax-key-123") + + def test_minimax_api_key_default_without_env(self): + env = {k: v for k, v in os.environ.items() if k != "MINIMAX_API_KEY"} + with patch.dict(os.environ, env, clear=True): + import importlib + import utils.api_key as api_key_module + importlib.reload(api_key_module) + self.assertEqual(api_key_module.minimax_api_key, "YOUR_MINIMAX_API_KEY") + + +class TestGetLlmKwargsMiniMax(unittest.TestCase): + """Unit tests for get_llm_kwargs() with MiniMax models.""" + + def _get_kwargs(self, model_name, minimax_key="fake-minimax-key"): + with patch.dict(os.environ, {"MINIMAX_API_KEY": minimax_key}): + import importlib + import utils.api_key as api_key_module + importlib.reload(api_key_module) + with patch("utils.llm.get_llm_kwargs.__globals__", {}): + pass + from utils.llm import get_llm_kwargs + import utils.llm as llm_module + # Patch the imported minimax_api_key inside the module + with patch.object(api_key_module, "minimax_api_key", minimax_key): + model, kwargs = get_llm_kwargs(model_name, "v0.1") + return model, kwargs + + def test_minimax_m27_uses_minimax_api_base(self): + with patch.dict(os.environ, {"MINIMAX_API_KEY": "fake-key"}): + import importlib, utils.api_key + importlib.reload(utils.api_key) + from utils.llm import get_llm_kwargs + model, kwargs = get_llm_kwargs("MiniMax-M2.7", "v0.1") + self.assertEqual(kwargs.api_base, "https://api.minimax.io/v1") + + def test_minimax_m27_highspeed_uses_minimax_api_base(self): + with patch.dict(os.environ, {"MINIMAX_API_KEY": "fake-key"}): + import importlib, utils.api_key + importlib.reload(utils.api_key) + from utils.llm import get_llm_kwargs + model, kwargs = get_llm_kwargs("MiniMax-M2.7-highspeed", "v0.1") + self.assertEqual(kwargs.api_base, "https://api.minimax.io/v1") + + def test_minimax_m25_uses_minimax_api_base(self): + with patch.dict(os.environ, {"MINIMAX_API_KEY": "fake-key"}): + import importlib, utils.api_key + importlib.reload(utils.api_key) + from utils.llm import get_llm_kwargs + model, kwargs = get_llm_kwargs("MiniMax-M2.5", "v0.1") + self.assertEqual(kwargs.api_base, "https://api.minimax.io/v1") + + def test_minimax_uses_bearer_auth_header(self): + with patch.dict(os.environ, {"MINIMAX_API_KEY": "my-minimax-key-xyz"}): + import importlib, utils.api_key + importlib.reload(utils.api_key) + from utils.llm import get_llm_kwargs + model, kwargs = get_llm_kwargs("MiniMax-M2.7", "v0.1") + self.assertIn("Authorization", kwargs.headers) + self.assertIn("my-minimax-key-xyz", kwargs.headers["Authorization"]) + + def test_minimax_temperature_valid(self): + with patch.dict(os.environ, {"MINIMAX_API_KEY": "fake-key"}): + import importlib, utils.api_key + importlib.reload(utils.api_key) + from utils.llm import get_llm_kwargs + model, kwargs = get_llm_kwargs("MiniMax-M2.7", "v0.1") + self.assertGreater(kwargs.temperature, 0.0) + self.assertLessEqual(kwargs.temperature, 1.0) + + def test_minimax_model_name_preserved(self): + with patch.dict(os.environ, {"MINIMAX_API_KEY": "fake-key"}): + import importlib, utils.api_key + importlib.reload(utils.api_key) + from utils.llm import get_llm_kwargs + model, kwargs = get_llm_kwargs("MiniMax-M2.7", "v0.1") + self.assertEqual(model, "MiniMax-M2.7") + self.assertEqual(kwargs.model, "MiniMax-M2.7") + + def test_openai_still_uses_openai_base(self): + with patch.dict(os.environ, {"OPENAI_API_KEY": "fake-openai-key"}): + import importlib, utils.api_key + importlib.reload(utils.api_key) + from utils.llm import get_llm_kwargs + model, kwargs = get_llm_kwargs("gpt-4", "v0.1") + self.assertEqual(kwargs.api_base, "https://api.openai.com/v1") + + +class TestGetLayoutMiniMax(unittest.TestCase): + """Unit tests for get_layout() using MiniMax (chat completions format).""" + + def _make_chat_response(self, content): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [{"message": {"content": content}}] + } + return mock_response + + def test_minimax_uses_chat_completions_endpoint(self): + from utils.llm import get_layout + from easydict import EasyDict + + mock_resp = self._make_chat_response("Objects: [('a cat', [50, 50, 100, 100])]\nBackground prompt: A room\nNegative prompt: ") + with patch("requests.post", return_value=mock_resp) as mock_post: + kwargs = EasyDict( + model="MiniMax-M2.7", + template="{prompt}", + api_base="https://api.minimax.io/v1", + max_tokens=900, + temperature=0.25, + stop=None, + headers={"Authorization": "Bearer fake-key"}, + ) + result = get_layout("a cat in a room", kwargs) + + call_url = mock_post.call_args[0][0] + self.assertIn("chat/completions", call_url) + + def test_minimax_response_parsed_from_message_content(self): + from utils.llm import get_layout + from easydict import EasyDict + + expected_content = "Objects: [('a dog', [10, 10, 200, 200])]\nBackground prompt: park\nNegative prompt: " + mock_resp = self._make_chat_response(expected_content) + with patch("requests.post", return_value=mock_resp): + kwargs = EasyDict( + model="MiniMax-M2.7", + template="{prompt}", + api_base="https://api.minimax.io/v1", + max_tokens=900, + temperature=0.25, + stop=None, + headers={"Authorization": "Bearer fake-key"}, + ) + result = get_layout("a dog in a park", kwargs) + + self.assertEqual(result, expected_content) + + def test_minimax_highspeed_uses_chat_completions(self): + from utils.llm import get_layout + from easydict import EasyDict + + mock_resp = self._make_chat_response("Objects: []\nBackground prompt: sky\nNegative prompt: ") + with patch("requests.post", return_value=mock_resp) as mock_post: + kwargs = EasyDict( + model="MiniMax-M2.7-highspeed", + template="{prompt}", + api_base="https://api.minimax.io/v1", + max_tokens=900, + temperature=0.25, + stop=None, + headers={"Authorization": "Bearer fake-key"}, + ) + get_layout("a clear sky", kwargs) + + call_url = mock_post.call_args[0][0] + self.assertIn("chat/completions", call_url) + + def test_minimax_request_body_contains_model(self): + from utils.llm import get_layout + from easydict import EasyDict + + mock_resp = self._make_chat_response("Objects: []\nBackground prompt: beach\nNegative prompt: ") + with patch("requests.post", return_value=mock_resp) as mock_post: + kwargs = EasyDict( + model="MiniMax-M2.5", + template="{prompt}", + api_base="https://api.minimax.io/v1", + max_tokens=900, + temperature=0.5, + stop=None, + headers={"Authorization": "Bearer fake-key"}, + ) + get_layout("a beach", kwargs) + + call_json = mock_post.call_args[1]["json"] + self.assertEqual(call_json["model"], "MiniMax-M2.5") + + def test_minimax_request_body_contains_messages(self): + from utils.llm import get_layout + from easydict import EasyDict + + mock_resp = self._make_chat_response("Objects: []\nBackground prompt: forest\nNegative prompt: ") + with patch("requests.post", return_value=mock_resp) as mock_post: + kwargs = EasyDict( + model="MiniMax-M2.7", + template="{prompt}", + api_base="https://api.minimax.io/v1", + max_tokens=900, + temperature=0.25, + stop=None, + headers={"Authorization": "Bearer fake-key"}, + ) + get_layout("a forest", kwargs) + + call_json = mock_post.call_args[1]["json"] + self.assertIn("messages", call_json) + self.assertEqual(call_json["messages"][0]["role"], "user") + + def test_gpt_still_uses_chat_completions(self): + from utils.llm import get_layout + from easydict import EasyDict + + mock_resp = self._make_chat_response("Objects: []\nBackground prompt: city\nNegative prompt: ") + with patch("requests.post", return_value=mock_resp) as mock_post: + kwargs = EasyDict( + model="gpt-4", + template="{prompt}", + api_base="https://api.openai.com/v1", + max_tokens=900, + temperature=0.25, + stop=None, + headers={"Authorization": "Bearer fake-openai-key"}, + ) + get_layout("a city", kwargs) + + call_url = mock_post.call_args[0][0] + self.assertIn("chat/completions", call_url) + + +class TestMiniMaxIntegration(unittest.TestCase): + """Integration tests for MiniMax API (require MINIMAX_API_KEY env var).""" + + @unittest.skipUnless(os.environ.get("MINIMAX_API_KEY"), "MINIMAX_API_KEY not set") + def test_minimax_m27_chat_completions_live(self): + """Live test: verify MiniMax-M2.7 returns a valid chat completion.""" + import requests + api_key = os.environ["MINIMAX_API_KEY"] + resp = requests.post( + "https://api.minimax.io/v1/chat/completions", + json={ + "model": "MiniMax-M2.7", + "messages": [{"role": "user", "content": "Say hello in one word."}], + "max_tokens": 10, + "temperature": 0.5, + }, + headers={"Authorization": f"Bearer {api_key}"}, + timeout=30, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertIn("choices", data) + self.assertIn("message", data["choices"][0]) + self.assertIn("content", data["choices"][0]["message"]) + + @unittest.skipUnless(os.environ.get("MINIMAX_API_KEY"), "MINIMAX_API_KEY not set") + def test_minimax_m27_highspeed_chat_completions_live(self): + """Live test: verify MiniMax-M2.7-highspeed returns a valid chat completion.""" + import requests + api_key = os.environ["MINIMAX_API_KEY"] + resp = requests.post( + "https://api.minimax.io/v1/chat/completions", + json={ + "model": "MiniMax-M2.7-highspeed", + "messages": [{"role": "user", "content": "Say hello in one word."}], + "max_tokens": 10, + "temperature": 0.5, + }, + headers={"Authorization": f"Bearer {api_key}"}, + timeout=30, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertIn("choices", data) + + @unittest.skipUnless(os.environ.get("MINIMAX_API_KEY"), "MINIMAX_API_KEY not set") + def test_get_llm_kwargs_minimax_live_request(self): + """Integration test: get_llm_kwargs() + get_layout() with MiniMax.""" + import importlib, utils.api_key + importlib.reload(utils.api_key) + from utils.llm import get_llm_kwargs, get_layout + + model, kwargs = get_llm_kwargs("MiniMax-M2.7", "v0.1") + self.assertEqual(kwargs.api_base, "https://api.minimax.io/v1") + + prompt = "A cat sitting on a table" + response = get_layout(prompt, kwargs) + self.assertIsInstance(response, str) + self.assertGreater(len(response), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/utils/api_key.py b/utils/api_key.py index bffda02..31d8149 100644 --- a/utils/api_key.py +++ b/utils/api_key.py @@ -5,3 +5,9 @@ api_key = os.environ["OPENAI_API_KEY"] else: api_key = "YOUR_API_KEY" + +# MiniMax API key: set `MINIMAX_API_KEY` environment variable or replace "YOUR_MINIMAX_API_KEY" below +if "MINIMAX_API_KEY" in os.environ: + minimax_api_key = os.environ["MINIMAX_API_KEY"] +else: + minimax_api_key = "YOUR_MINIMAX_API_KEY" diff --git a/utils/llm.py b/utils/llm.py index b3bc29f..6f35fe8 100644 --- a/utils/llm.py +++ b/utils/llm.py @@ -6,7 +6,7 @@ import traceback import time -model_names = ["vicuna", "vicuna-13b", "vicuna-13b-v1.3", "vicuna-33b-v1.3", "Llama-2-7b-hf", "Llama-2-13b-hf", "Llama-2-70b-hf", "FreeWilly2", "StableBeluga2", "gpt-3.5-turbo", "gpt-3.5", "gpt-4", "text-davinci-003", "Mixtral-8x7B-Instruct-v0.1"] +model_names = ["vicuna", "vicuna-13b", "vicuna-13b-v1.3", "vicuna-33b-v1.3", "Llama-2-7b-hf", "Llama-2-13b-hf", "Llama-2-70b-hf", "FreeWilly2", "StableBeluga2", "gpt-3.5-turbo", "gpt-3.5", "gpt-4", "text-davinci-003", "Mixtral-8x7B-Instruct-v0.1", "MiniMax-M2.5", "MiniMax-M2.5-highspeed", "MiniMax-M2.7", "MiniMax-M2.7-highspeed"] def get_full_prompt(template, prompt, suffix=None): full_prompt = template.format(prompt=prompt) @@ -36,6 +36,13 @@ def get_llm_kwargs(model, template_version): max_tokens = 900 temperature = 0.25 headers = {} + elif "minimax" in model.lower(): + from utils.api_key import minimax_api_key + + api_base = "https://api.minimax.io/v1" + max_tokens = 900 + temperature = 0.25 + headers = {"Authorization": f"Bearer {minimax_api_key}"} else: from utils.api_key import api_key @@ -56,7 +63,7 @@ def get_layout(prompt, llm_kwargs, suffix=""): done = False attempts = 0 while not done: - if "gpt" in model: + if "gpt" in model or "minimax" in model.lower(): r = requests.post(f'{api_base}/chat/completions', json={ "model": model, "messages": [{"role": "user", "content": get_full_prompt(template, prompt, suffix).strip()}], @@ -85,7 +92,7 @@ def get_layout(prompt, llm_kwargs, suffix=""): print("Exiting due to many non-successful attempts") exit() - if "gpt" in model: + if "gpt" in model or "minimax" in model.lower(): response = r.json()['choices'][0]['message']['content'] else: response = r.json()['choices'][0]['text']