Skip to content

Commit 561688e

Browse files
ericapisani authored and claude committed
feat(langchain): Broaden AI provider detection beyond OpenAI and Anthropic
Extract _get_ai_system() to generically detect AI providers from LangChain's _type field instead of hardcoding only "anthropic" and "openai". The function splits on "-" and skips non-provider segments (cloud prefixes like "azure" and descriptors like "chat"/"llm") to return the actual provider name. This adds support for Cohere, Ollama, Mistral, Fireworks, HuggingFace, Groq, NVIDIA, xAI, DeepSeek, Google, and any future LangChain providers. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent dc65e13 commit 561688e

File tree

2 files changed

+119
-10
lines changed

2 files changed

+119
-10
lines changed

sentry_sdk/integrations/langchain.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,31 @@
108108
OllamaEmbeddings = None
109109

110110

111+
_NON_PROVIDER_PARTS = frozenset({"azure", "aws", "gcp", "vertex", "chat", "llm"})
112+
113+
114+
def _get_ai_system(all_params: "Dict[str, Any]") -> "Optional[str]":
115+
"""Extract the AI provider from the ``_type`` field in LangChain params.
116+
117+
Splits on ``-`` and skips generic segments (cloud prefixes and model-type
118+
descriptors like ``chat`` / ``llm``) to return the actual provider name.
119+
"""
120+
ai_type = all_params.get("_type")
121+
122+
if not ai_type or not isinstance(ai_type, str):
123+
return None
124+
125+
parts = [p.strip().lower() for p in ai_type.split("-") if p.strip()]
126+
if not parts:
127+
return None
128+
129+
for part in parts:
130+
if part not in _NON_PROVIDER_PARTS:
131+
return part
132+
133+
return parts[0]
134+
135+
111136
DATA_FIELDS = {
112137
"frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
113138
"function_call": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
@@ -381,11 +406,9 @@ def on_llm_start(
381406
model,
382407
)
383408

384-
ai_type = all_params.get("_type", "")
385-
if "anthropic" in ai_type:
386-
span.set_data(SPANDATA.GEN_AI_SYSTEM, "anthropic")
387-
elif "openai" in ai_type:
388-
span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")
409+
ai_system = _get_ai_system(all_params)
410+
if ai_system:
411+
span.set_data(SPANDATA.GEN_AI_SYSTEM, ai_system)
389412

390413
for key, attribute in DATA_FIELDS.items():
391414
if key in all_params and all_params[key] is not None:
@@ -449,11 +472,9 @@ def on_chat_model_start(
449472
if model:
450473
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)
451474

452-
ai_type = all_params.get("_type", "")
453-
if "anthropic" in ai_type:
454-
span.set_data(SPANDATA.GEN_AI_SYSTEM, "anthropic")
455-
elif "openai" in ai_type:
456-
span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")
475+
ai_system = _get_ai_system(all_params)
476+
if ai_system:
477+
span.set_data(SPANDATA.GEN_AI_SYSTEM, ai_system)
457478

458479
agent_name = _get_current_agent()
459480
if agent_name:

tests/integrations/langchain/test_langchain.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2000,6 +2000,94 @@ def test_transform_google_file_data(self):
20002000
}
20012001

20022002

2003+
@pytest.mark.parametrize(
    "ai_type,expected_system",
    [
        # Real LangChain _type values (from _llm_type properties)
        # OpenAI
        ("openai-chat", "openai"),
        ("openai", "openai"),
        # Azure OpenAI
        ("azure-openai-chat", "openai"),
        ("azure", "azure"),
        # Anthropic
        ("anthropic-chat", "anthropic"),
        # Google
        ("vertexai", "vertexai"),
        ("chat-google-generative-ai", "google"),
        ("google_gemini", "google_gemini"),
        # AWS Bedrock (underscore-separated, no split)
        ("amazon_bedrock_chat", "amazon_bedrock_chat"),
        ("amazon_bedrock", "amazon_bedrock"),
        # Cohere
        ("cohere-chat", "cohere"),
        # Ollama
        ("chat-ollama", "ollama"),
        ("ollama-llm", "ollama"),
        # Mistral
        ("mistralai-chat", "mistralai"),
        # Fireworks
        ("fireworks-chat", "fireworks"),
        ("fireworks", "fireworks"),
        # HuggingFace
        ("huggingface-chat-wrapper", "huggingface"),
        # Groq
        ("groq-chat", "groq"),
        # NVIDIA
        ("chat-nvidia-ai-playground", "nvidia"),
        # xAI
        ("xai-chat", "xai"),
        # DeepSeek
        ("chat-deepseek", "deepseek"),
        # Edge cases
        ("", None),
        (None, None),
    ],
)
def test_langchain_ai_system_detection(
    sentry_init, capture_events, ai_type, expected_system
):
    """End-to-end check that ``gen_ai.system`` span data is derived from the
    ``_type`` invocation param across many LangChain providers.

    Drives the callback directly (on_llm_start / on_llm_end) inside a
    transaction, then inspects the captured transaction's spans for the
    expected provider name — or its absence for the edge cases.
    """
    sentry_init(
        integrations=[LangchainIntegration()],
        traces_sample_rate=1.0,
    )
    events = capture_events()

    callback = SentryLangchainCallback(max_span_map_size=100, include_prompts=True)

    run_id = "test-ai-system-uuid"
    # When ai_type is None, omit the key entirely to exercise the
    # "missing _type" path rather than an explicit None value.
    serialized = {"_type": ai_type} if ai_type is not None else {}
    prompts = ["Test prompt"]

    with start_transaction():
        callback.on_llm_start(
            serialized=serialized,
            prompts=prompts,
            run_id=run_id,
            invocation_params={"_type": ai_type, "model": "test-model"},
        )

        # Minimal mocked LLMResult shape: one generation with plain text.
        generation = Mock(text="Test response", message=None)
        response = Mock(generations=[[generation]])
        callback.on_llm_end(response=response, run_id=run_id)

    assert len(events) > 0
    tx = events[0]
    assert tx["type"] == "transaction"

    llm_spans = [
        span for span in tx.get("spans", []) if span.get("op") == "gen_ai.pipeline"
    ]
    assert len(llm_spans) > 0

    llm_span = llm_spans[0]

    if expected_system is not None:
        assert llm_span["data"][SPANDATA.GEN_AI_SYSTEM] == expected_system
    else:
        # Edge cases ("" / missing _type) must not set gen_ai.system at all.
        assert SPANDATA.GEN_AI_SYSTEM not in llm_span.get("data", {})
2089+
2090+
20032091
class TestTransformLangchainMessageContent:
20042092
"""Tests for _transform_langchain_message_content function."""
20052093

0 commit comments

Comments (0)