diff --git a/README.md b/README.md index 381235b90..e11ef5812 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ This package integrates Large Language Models (LLMs) into [spaCy](https://spacy. - **[Anthropic](https://docs.anthropic.com/claude/reference/)** - **[Google PaLM](https://ai.google/discover/palm2/)** - **[Microsoft Azure AI](https://azure.microsoft.com/en-us/solutions/ai)** + - **[MiniMax](https://platform.minimaxi.com/)** - Supports open-source LLMs hosted on Hugging Face 🤗: - **[Falcon](https://huggingface.co/tiiuae)** - **[Dolly](https://huggingface.co/databricks)** diff --git a/spacy_llm/models/__init__.py b/spacy_llm/models/__init__.py index c14270097..eb35571fb 100644 --- a/spacy_llm/models/__init__.py +++ b/spacy_llm/models/__init__.py @@ -1,10 +1,11 @@ from .hf import dolly_hf, openllama_hf, stablelm_hf from .langchain import query_langchain -from .rest import anthropic, cohere, noop, openai, palm +from .rest import anthropic, cohere, minimax, noop, openai, palm __all__ = [ "anthropic", "cohere", + "minimax", "openai", "dolly_hf", "noop", diff --git a/spacy_llm/models/rest/__init__.py b/spacy_llm/models/rest/__init__.py index 96263967c..1dcc43ec0 100644 --- a/spacy_llm/models/rest/__init__.py +++ b/spacy_llm/models/rest/__init__.py @@ -1,10 +1,11 @@ -from . import anthropic, azure, base, cohere, noop, openai +from . import anthropic, azure, base, cohere, minimax, noop, openai __all__ = [ "anthropic", "azure", "base", "cohere", + "minimax", "openai", "noop", ] diff --git a/spacy_llm/models/rest/minimax/__init__.py b/spacy_llm/models/rest/minimax/__init__.py new file mode 100644 index 000000000..1cf84e86a --- /dev/null +++ b/spacy_llm/models/rest/minimax/__init__.py @@ -0,0 +1,4 @@ +from .model import Endpoints, MiniMax +from .registry import minimax_v1 + +__all__ = ["MiniMax", "Endpoints", "minimax_v1"] diff --git a/spacy_llm/models/rest/minimax/model.py b/spacy_llm/models/rest/minimax/model.py new file mode 100644 index 000000000..1a04b0a05 --- /dev/null +++ b/spacy_llm/models/rest/minimax/model.py @@ -0,0 +1,119 @@ +import os +import re +import warnings +from enum import Enum +from typing import Any, Dict, Iterable, List, Sized + +import requests # type: ignore[import] +import srsly # type: ignore[import] +from requests import HTTPError + +from ..base import REST + + +class Endpoints(str, Enum): + CHAT = "https://api.minimax.io/v1/chat/completions" + + +class MiniMax(REST): + @property + def credentials(self) -> Dict[str, str]: + api_key = os.getenv("MINIMAX_API_KEY") + if api_key is None: + warnings.warn( + "Could not find the API key to access the MiniMax API. Ensure you have an API key " + "set up via https://platform.minimaxi.com/, then make it available as " + "an environment variable 'MINIMAX_API_KEY'." + ) + + return { + "Authorization": f"Bearer {api_key}", + } + + def _verify_auth(self) -> None: + # MiniMax doesn't have a dedicated /v1/models endpoint, so we verify + # auth by making a minimal test request. + try: + self([["test"]]) + except ValueError as err: + if "authentication" in str(err).lower() or "api key" in str(err).lower(): + warnings.warn( + "Authentication with provided API key failed. Please double-check you provided " + "the correct credentials." + ) + else: + raise err + + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: + headers = { + **self._credentials, + "Content-Type": "application/json", + } + all_api_responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + api_responses: List[str] = [] + prompts_for_doc = list(prompts_for_doc) + + def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: + r = self.retry( + call_method=requests.post, + url=self._endpoint, + headers=headers, + json={**json_data, **self._config, "model": self._name}, + timeout=self._max_request_time, + ) + try: + r.raise_for_status() + except HTTPError as ex: + res_content = srsly.json_loads(r.content.decode("utf-8")) + raise ValueError( + f"Request to MiniMax API failed: " + f"{res_content.get('error', {}).get('message', str(res_content))}" + ) from ex + responses = r.json() + + if "error" in responses: + if self._strict: + raise ValueError(f"API call failed: {responses}.") + else: + assert isinstance(prompts_for_doc, Sized) + return { + "error": [srsly.json_dumps(responses)] + * len(prompts_for_doc) + } + + return responses + + # MiniMax uses an OpenAI-compatible chat completions API, + # so we send individual requests per prompt (no batching). + for prompt in prompts_for_doc: + responses = _request( + {"messages": [{"role": "user", "content": prompt}]} + ) + if "error" in responses: + return responses["error"] + + assert len(responses["choices"]) == 1 + response = responses["choices"][0] + content = response.get("message", {}).get( + "content", srsly.json_dumps(response) + ) + # Strip ... tags from thinking models. + content = re.sub( + r".*?\s*", "", content, flags=re.DOTALL + ) + api_responses.append(content) + + all_api_responses.append(api_responses) + + return all_api_responses + + @staticmethod + def _get_context_lengths() -> Dict[str, int]: + return { + "MiniMax-M2.7": 1048576, + "MiniMax-M2.7-highspeed": 1048576, + "MiniMax-M2.5": 204800, + "MiniMax-M2.5-highspeed": 204800, + } diff --git a/spacy_llm/models/rest/minimax/registry.py b/spacy_llm/models/rest/minimax/registry.py new file mode 100644 index 000000000..aded7f248 --- /dev/null +++ b/spacy_llm/models/rest/minimax/registry.py @@ -0,0 +1,50 @@ +from typing import Any, Dict, Optional + +from confection import SimpleFrozenDict + +from ....registry import registry +from .model import Endpoints, MiniMax + +_DEFAULT_TEMPERATURE = 0.0 + + +@registry.llm_models("spacy.MiniMax.v1") +def minimax_v1( + config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE), + name: str = "MiniMax-M2.5", + strict: bool = MiniMax.DEFAULT_STRICT, + max_tries: int = MiniMax.DEFAULT_MAX_TRIES, + interval: float = MiniMax.DEFAULT_INTERVAL, + max_request_time: float = MiniMax.DEFAULT_MAX_REQUEST_TIME, + endpoint: Optional[str] = None, + context_length: Optional[int] = None, +) -> MiniMax: + """Returns MiniMax instance for MiniMax models using REST to prompt API. + + config (Dict[Any, Any]): LLM config passed on to the model's initialization. + name (str): Model name to use. Can be any model name supported by the MiniMax API - e.g. + 'MiniMax-M2.5', 'MiniMax-M2.5-highspeed', 'MiniMax-M2.7', 'MiniMax-M2.7-highspeed'. + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response. + If False, the API error responses are returned by __call__(), but no error will be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 + exponential backoff at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising + an exception. + endpoint (Optional[str]): Endpoint to set. Defaults to standard endpoint. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and + if no context length natively provided by spacy-llm. + RETURNS (MiniMax): MiniMax instance. + + DOCS: https://platform.minimaxi.com/ + """ + return MiniMax( + name=name, + endpoint=endpoint or Endpoints.CHAT.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + ) diff --git a/spacy_llm/tests/compat.py b/spacy_llm/tests/compat.py index 21688a71f..74e5ef55f 100644 --- a/spacy_llm/tests/compat.py +++ b/spacy_llm/tests/compat.py @@ -5,3 +5,4 @@ has_cohere_key = os.getenv("CO_API_KEY") is not None has_azure_openai_key = os.getenv("AZURE_OPENAI_KEY") is not None has_palm_key = os.getenv("PALM_API_KEY") is not None +has_minimax_key = os.getenv("MINIMAX_API_KEY") is not None diff --git a/spacy_llm/tests/models/test_minimax.py b/spacy_llm/tests/models/test_minimax.py new file mode 100644 index 000000000..04af000b1 --- /dev/null +++ b/spacy_llm/tests/models/test_minimax.py @@ -0,0 +1,66 @@ +# mypy: ignore-errors +import pytest + +from spacy_llm.models.rest.minimax import Endpoints, MiniMax, minimax_v1 + +from ..compat import has_minimax_key + + +@pytest.mark.external +@pytest.mark.skipif( + has_minimax_key is False, reason="MiniMax API key not available" +) +@pytest.mark.parametrize( + "name", ("MiniMax-M2.5", "MiniMax-M2.5-highspeed") +) +def test_minimax_api_response_is_correct(name: str): + """Check if we're getting the expected response and parsing it properly""" + model = minimax_v1(name=name, config={"temperature": 0.0}) + prompt = "Count the number of characters in this string: hello" + num_prompts = 3 + responses = model(prompts=[[prompt]] * num_prompts) + for response in responses: + assert isinstance(response, list) + assert len(response) == 1 + assert isinstance(response[0], str) + + +@pytest.mark.external +@pytest.mark.skipif( + has_minimax_key is False, reason="MiniMax API key not available" +) +def test_minimax_api_response_when_error(): + """Check if error message shows up properly given incorrect config""" + incorrect_temperature = "one" # should be a float + with pytest.raises(ValueError, match="Request to MiniMax API failed:"): + minimax_v1( + name="MiniMax-M2.5", + config={"temperature": incorrect_temperature}, + ) + + +@pytest.mark.external +@pytest.mark.skipif( + has_minimax_key is False, reason="MiniMax API key not available" +) +def test_minimax_error_unsupported_model(): + """Ensure graceful handling of error when model is not supported""" + incorrect_model = "x-nonexistent-model" + with pytest.raises(ValueError, match="Request to MiniMax API failed:"): + minimax_v1(name=incorrect_model) + + +def test_minimax_context_lengths(): + """Verify context length definitions for MiniMax models""" + ctx = MiniMax._get_context_lengths() + assert "MiniMax-M2.5" in ctx + assert "MiniMax-M2.5-highspeed" in ctx + assert "MiniMax-M2.7" in ctx + assert "MiniMax-M2.7-highspeed" in ctx + assert ctx["MiniMax-M2.5"] == 204800 + assert ctx["MiniMax-M2.7"] == 1048576 + + +def test_minimax_endpoints(): + """Verify endpoint definitions""" + assert Endpoints.CHAT.value == "https://api.minimax.io/v1/chat/completions" diff --git a/usage_examples/textcat_minimax/zeroshot.cfg b/usage_examples/textcat_minimax/zeroshot.cfg new file mode 100644 index 000000000..960d359c6 --- /dev/null +++ b/usage_examples/textcat_minimax/zeroshot.cfg @@ -0,0 +1,23 @@ +[nlp] +lang = "en" +pipeline = ["llm"] +batch_size = 128 + +[components] + +[components.llm] +factory = "llm" + +[components.llm.model] +@llm_models = "spacy.MiniMax.v1" +name = "MiniMax-M2.5" +config = {"temperature": 0.0} + +[components.llm.task] +@llm_tasks = "spacy.TextCat.v2" +labels = COMPLIMENT,INSULT +examples = null +exclusive_classes = true + +[components.llm.task.normalizer] +@misc = "spacy.LowercaseNormalizer.v1"