Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ This package integrates Large Language Models (LLMs) into [spaCy](https://spacy.
- **[Anthropic](https://docs.anthropic.com/claude/reference/)**
- **[Google PaLM](https://ai.google/discover/palm2/)**
- **[Microsoft Azure AI](https://azure.microsoft.com/en-us/solutions/ai)**
- **[MiniMax](https://platform.minimaxi.com/)**
- Supports open-source LLMs hosted on Hugging Face 🤗:
- **[Falcon](https://huggingface.co/tiiuae)**
- **[Dolly](https://huggingface.co/databricks)**
Expand Down
3 changes: 2 additions & 1 deletion spacy_llm/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from .hf import dolly_hf, openllama_hf, stablelm_hf
from .langchain import query_langchain
from .rest import anthropic, cohere, noop, openai, palm
from .rest import anthropic, cohere, minimax, noop, openai, palm

__all__ = [
"anthropic",
"cohere",
"minimax",
"openai",
"dolly_hf",
"noop",
Expand Down
3 changes: 2 additions & 1 deletion spacy_llm/models/rest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from . import anthropic, azure, base, cohere, noop, openai
from . import anthropic, azure, base, cohere, minimax, noop, openai

# Names exported by `from spacy_llm.models.rest import *`; each entry is one
# of the submodules imported above.
__all__ = [
"anthropic",
"azure",
"base",
"cohere",
"minimax",
"openai",
"noop",
]
4 changes: 4 additions & 0 deletions spacy_llm/models/rest/minimax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .model import Endpoints, MiniMax
from .registry import minimax_v1

# Public names of the MiniMax package: the model class, its endpoint enum and
# the registered factory function.
__all__ = ["MiniMax", "Endpoints", "minimax_v1"]
119 changes: 119 additions & 0 deletions spacy_llm/models/rest/minimax/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import os
import re
import warnings
from enum import Enum
from typing import Any, Dict, Iterable, List, Sized

import requests # type: ignore[import]
import srsly # type: ignore[import]
from requests import HTTPError

from ..base import REST


class Endpoints(str, Enum):
    """URLs of the MiniMax REST endpoints known to this integration.

    Members are also `str` instances, so they compare equal to their URL.
    """

    # Chat completions endpoint.
    CHAT = "https://api.minimax.io/v1/chat/completions"


class MiniMax(REST):
    """REST client that prompts MiniMax models via the chat completions API.

    One HTTP request is sent per prompt (no batching). The result of
    `__call__` always has one inner list per doc and one string per prompt.
    """

    # Reasoning block emitted by thinking models, together with any whitespace
    # that directly follows it. DOTALL lets the block span several lines.
    _THINK_BLOCK_RE = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL)

    @property
    def credentials(self) -> Dict[str, str]:
        """Build the authorization header from the environment.

        Reads the `MINIMAX_API_KEY` environment variable. If it is not set a
        warning is emitted and the header is still returned (it then carries
        the literal text "Bearer None"), so the failure surfaces on the first
        request rather than here.

        RETURNS (Dict[str, str]): Header dict with a single "Authorization" entry.
        """
        api_key = os.getenv("MINIMAX_API_KEY")
        if api_key is None:
            warnings.warn(
                "Could not find the API key to access the MiniMax API. Ensure you have an API key "
                "set up via https://platform.minimaxi.com/, then make it available as "
                "an environment variable 'MINIMAX_API_KEY'."
            )

        return {
            "Authorization": f"Bearer {api_key}",
        }

    def _verify_auth(self) -> None:
        """Check the credentials by sending one minimal prompt.

        MiniMax doesn't have a dedicated /v1/models endpoint, so a real
        request is the only available probe. Authentication-related failures
        are downgraded to a warning; any other ValueError is re-raised.
        """
        try:
            self([["test"]])
        except ValueError as err:
            message = str(err).lower()
            if "authentication" in message or "api key" in message:
                warnings.warn(
                    "Authentication with provided API key failed. Please double-check you provided "
                    "the correct credentials."
                )
            else:
                raise

    def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]:
        """Prompt the model once per prompt and collect the replies per doc.

        prompts (Iterable[Iterable[str]]): Prompts, grouped per doc.
        RETURNS (Iterable[Iterable[str]]): One list per doc with one string per
            prompt. In non-strict mode a prompt whose response body contains an
            "error" entry yields that body serialized as JSON in its slot.
        RAISES (ValueError): If the HTTP request fails (regardless of `strict`),
            if the API reports an error and `strict` is set, or if the response
            does not contain exactly one choice.
        """
        headers = {
            **self._credentials,
            "Content-Type": "application/json",
        }
        return [
            [self._request_completion(prompt, headers) for prompt in prompts_for_doc]
            for prompts_for_doc in prompts
        ]

    def _request_completion(self, prompt: str, headers: Dict[str, str]) -> str:
        """Send a single prompt and return the text of the model's reply."""
        r = self.retry(
            call_method=requests.post,
            url=self._endpoint,
            headers=headers,
            # User config may override the request body; the model name always wins.
            json={
                "messages": [{"role": "user", "content": prompt}],
                **self._config,
                "model": self._name,
            },
            timeout=self._max_request_time,
        )
        try:
            r.raise_for_status()
        except HTTPError as ex:
            raise ValueError(
                f"Request to MiniMax API failed: {self._http_error_detail(r)}"
            ) from ex
        responses = r.json()

        if "error" in responses:
            if self._strict:
                raise ValueError(f"API call failed: {responses}.")
            # Non-strict: keep the error in this prompt's slot so the
            # per-doc / per-prompt shape of the overall result is preserved.
            return srsly.json_dumps(responses)

        choices = responses.get("choices")
        if not isinstance(choices, list) or len(choices) != 1:
            # Explicit check instead of `assert`, which is stripped under -O.
            raise ValueError(
                f"Expected exactly one choice in MiniMax API response, got: {responses}."
            )

        choice = choices[0]
        content = (choice.get("message") or {}).get("content")
        if not isinstance(content, str):
            # Missing or null content: fall back to the raw choice so the
            # caller still receives a string.
            content = srsly.json_dumps(choice)

        # Strip <think>...</think> tags from thinking models.
        return self._THINK_BLOCK_RE.sub("", content)

    @staticmethod
    def _http_error_detail(response: "requests.Response") -> str:
        """Extract the most specific error text from a failed HTTP response.

        Falls back to the raw body when it is not JSON, and to the whole
        decoded payload when it has no `error.message` entry.
        """
        raw = response.content.decode("utf-8", errors="replace")
        try:
            payload = srsly.json_loads(raw)
        except ValueError:
            return raw

        if isinstance(payload, dict):
            error = payload.get("error", {})
            if isinstance(error, dict):
                return str(error.get("message", payload))
        return str(payload)

    @staticmethod
    def _get_context_lengths() -> Dict[str, int]:
        """Return the context lengths of the MiniMax models listed here, keyed by model name."""
        return {
            "MiniMax-M2.7": 1048576,
            "MiniMax-M2.7-highspeed": 1048576,
            "MiniMax-M2.5": 204800,
            "MiniMax-M2.5-highspeed": 204800,
        }
50 changes: 50 additions & 0 deletions spacy_llm/models/rest/minimax/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Any, Dict, Optional

from confection import SimpleFrozenDict

from ....registry import registry
from .model import Endpoints, MiniMax

_DEFAULT_TEMPERATURE = 0.0


@registry.llm_models("spacy.MiniMax.v1")
def minimax_v1(
    config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
    name: str = "MiniMax-M2.5",
    strict: bool = MiniMax.DEFAULT_STRICT,
    max_tries: int = MiniMax.DEFAULT_MAX_TRIES,
    interval: float = MiniMax.DEFAULT_INTERVAL,
    max_request_time: float = MiniMax.DEFAULT_MAX_REQUEST_TIME,
    endpoint: Optional[str] = None,
    context_length: Optional[int] = None,
) -> MiniMax:
    """Create a MiniMax model that is prompted through the REST API.

    config (Dict[Any, Any]): LLM config passed on to the model's initialization.
    name (str): Name of the model to use. Any model name supported by the MiniMax API is
        accepted, e.g. 'MiniMax-M2.5', 'MiniMax-M2.5-highspeed', 'MiniMax-M2.7' or
        'MiniMax-M2.7-highspeed'.
    strict (bool): If True, a ValueError is raised when the LLM API returns a malformed
        response. If False, API error responses are returned by __call__() instead of raising.
    max_tries (int): Max. number of tries for an API request.
    interval (float): Time interval (in seconds) between API retries. A base 2 exponential
        backoff is applied at each retry.
    max_request_time (float): Max. time (in seconds) to wait for a request to terminate before
        raising an exception.
    endpoint (Optional[str]): Endpoint to use. Defaults to the standard chat endpoint.
    context_length (Optional[int]): Context length for this model. Only necessary for sharding
        and if no context length is natively provided by spacy-llm.
    RETURNS (MiniMax): MiniMax instance.

    DOCS: https://platform.minimaxi.com/
    """
    # A missing (or empty) endpoint falls back to the default chat endpoint.
    resolved_endpoint = endpoint if endpoint else Endpoints.CHAT.value
    return MiniMax(
        name=name,
        endpoint=resolved_endpoint,
        config=config,
        strict=strict,
        max_tries=max_tries,
        interval=interval,
        max_request_time=max_request_time,
        context_length=context_length,
    )
1 change: 1 addition & 0 deletions spacy_llm/tests/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
# Each flag is True when the corresponding provider's API key is present in
# the environment.
has_cohere_key = os.getenv("CO_API_KEY") is not None
has_azure_openai_key = os.getenv("AZURE_OPENAI_KEY") is not None
has_palm_key = os.getenv("PALM_API_KEY") is not None
has_minimax_key = os.getenv("MINIMAX_API_KEY") is not None
66 changes: 66 additions & 0 deletions spacy_llm/tests/models/test_minimax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# mypy: ignore-errors
import pytest

from spacy_llm.models.rest.minimax import Endpoints, MiniMax, minimax_v1

from ..compat import has_minimax_key


@pytest.mark.external
@pytest.mark.skipif(not has_minimax_key, reason="MiniMax API key not available")
@pytest.mark.parametrize("name", ("MiniMax-M2.5", "MiniMax-M2.5-highspeed"))
def test_minimax_api_response_is_correct(name: str):
    """Each doc should come back as a list holding exactly one string reply."""
    model = minimax_v1(name=name, config={"temperature": 0.0})
    # Three docs, each carrying the same single prompt.
    batch = [["Count the number of characters in this string: hello"]] * 3
    for doc_responses in model(prompts=batch):
        assert isinstance(doc_responses, list)
        assert len(doc_responses) == 1
        assert isinstance(doc_responses[0], str)


@pytest.mark.external
@pytest.mark.skipif(not has_minimax_key, reason="MiniMax API key not available")
def test_minimax_api_response_when_error():
    """An invalid config value should surface as a ValueError with the API error prefix."""
    bad_config = {"temperature": "one"}  # should be a float
    with pytest.raises(ValueError, match="Request to MiniMax API failed:"):
        minimax_v1(name="MiniMax-M2.5", config=bad_config)


@pytest.mark.external
@pytest.mark.skipif(not has_minimax_key, reason="MiniMax API key not available")
def test_minimax_error_unsupported_model():
    """An unknown model name should be reported as a ValueError, not crash differently."""
    with pytest.raises(ValueError, match="Request to MiniMax API failed:"):
        minimax_v1(name="x-nonexistent-model")


def test_minimax_context_lengths():
    """The context-length table should list every known model with the expected sizes."""
    lengths = MiniMax._get_context_lengths()
    known_models = (
        "MiniMax-M2.5",
        "MiniMax-M2.5-highspeed",
        "MiniMax-M2.7",
        "MiniMax-M2.7-highspeed",
    )
    for model_name in known_models:
        assert model_name in lengths
    assert lengths["MiniMax-M2.5"] == 204800
    assert lengths["MiniMax-M2.7"] == 1048576


def test_minimax_endpoints():
    """The chat endpoint should point at the expected MiniMax URL."""
    expected_url = "https://api.minimax.io/v1/chat/completions"
    assert Endpoints.CHAT.value == expected_url
23 changes: 23 additions & 0 deletions usage_examples/textcat_minimax/zeroshot.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Text classification example that uses a MiniMax model through the "llm"
# pipeline component.
[nlp]
lang = "en"
pipeline = ["llm"]
batch_size = 128

[components]

[components.llm]
factory = "llm"

# Model: resolved through the "spacy.MiniMax.v1" registry entry.
[components.llm.model]
@llm_models = "spacy.MiniMax.v1"
name = "MiniMax-M2.5"
config = {"temperature": 0.0}

# Task: no examples are supplied (`examples = null`), i.e. zero-shot.
[components.llm.task]
@llm_tasks = "spacy.TextCat.v2"
labels = COMPLIMENT,INSULT
examples = null
exclusive_classes = true

[components.llm.task.normalizer]
@misc = "spacy.LowercaseNormalizer.v1"