diff --git a/tests/test_structured_agents.py b/tests/test_structured_agents.py index ea771a4b03..1b9c8b4333 100644 --- a/tests/test_structured_agents.py +++ b/tests/test_structured_agents.py @@ -11,13 +11,17 @@ import pytest +from tradingagents.agents.analysts.sentiment_analyst import create_sentiment_analyst from tradingagents.agents.managers.research_manager import create_research_manager from tradingagents.agents.schemas import ( PortfolioRating, ResearchPlan, + SentimentBand, + SentimentReport, TraderAction, TraderProposal, render_research_plan, + render_sentiment_report, render_trader_proposal, ) from tradingagents.agents.trader.trader import create_trader @@ -230,3 +234,137 @@ def test_falls_back_to_freetext_when_structured_unavailable(self): rm = create_research_manager(llm) result = rm(_make_rm_state()) assert result["investment_plan"] == plain_response + + +# --------------------------------------------------------------------------- +# SentimentReport schema and render +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestRenderSentimentReport: + def test_required_fields_present(self): + r = SentimentReport( + overall_score=7.2, + overall_band=SentimentBand.BULLISH, + confidence="high", + narrative="TSLA sentiment is strongly positive driven by retail enthusiasm.", + ) + md = render_sentiment_report(r) + assert "**Overall Sentiment:** **Bullish**" in md + assert "Score: 7.2/10" in md + assert "Confidence: high" in md + assert "TSLA sentiment is strongly positive" in md + + def test_score_formatted_to_one_decimal(self): + r = SentimentReport( + overall_score=6.0, + overall_band=SentimentBand.MILDLY_BULLISH, + confidence="medium", + narrative="Mixed signals.", + ) + md = render_sentiment_report(r) + assert "6.0/10" in md + + def test_all_bands_render(self): + for band in SentimentBand: + r = SentimentReport( + overall_score=5.0, + overall_band=band, + confidence="low", + narrative="Test.", + ) + md = render_sentiment_report(r) + assert band.value in md + + def test_narrative_preserved_intact(self): + narrative = "Line one.\n\nLine two.\n\n| Col | Val |\n|-----|-----|\n| A | B |" + r = SentimentReport( + overall_score=4.5, + overall_band=SentimentBand.NEUTRAL, + confidence="medium", + narrative=narrative, + ) + md = render_sentiment_report(r) + assert narrative in md + + def test_score_bounds(self): + import pytest as _pytest + from pydantic import ValidationError + with _pytest.raises(ValidationError): + SentimentReport( + overall_score=11.0, + overall_band=SentimentBand.BULLISH, + confidence="high", + narrative="Out of range.", + ) + with _pytest.raises(ValidationError): + SentimentReport( + overall_score=-1.0, + overall_band=SentimentBand.BEARISH, + confidence="low", + narrative="Out of range.", + ) + + +# --------------------------------------------------------------------------- +# Sentiment Analyst node +# --------------------------------------------------------------------------- + + +def _make_analyst_state(): + from langchain_core.messages import HumanMessage + return { + "messages": [HumanMessage(content="Analyze TSLA sentiment.")], + "company_of_interest": "TSLA", + "trade_date": "2024-01-15", + } + + +def _make_analyst_llm(structured_result): + """LLM mock for the sentiment analyst. + + The analyst uses format_messages then invoke_structured_or_freetext, which + calls structured_llm.invoke(formatted_messages) directly. + """ + llm = MagicMock() + llm.with_structured_output.return_value.invoke.return_value = structured_result + return llm + + +@pytest.mark.unit +class TestSentimentAnalystStructured: + def test_structured_output_populates_sentiment_report(self): + structured = SentimentReport( + overall_score=7.2, + overall_band=SentimentBand.BULLISH, + confidence="high", + narrative="Retail very bullish on TSLA this week.", + ) + llm = _make_analyst_llm(structured) + analyst = create_sentiment_analyst(llm) + result = analyst(_make_analyst_state()) + assert "Score: 7.2/10" in result["sentiment_report"] + assert "Bullish" in result["sentiment_report"] + assert "Retail very bullish" in result["sentiment_report"] + + def test_sentiment_report_in_messages(self): + """The rendered report should also appear in the messages state.""" + structured = SentimentReport( + overall_score=5.0, + overall_band=SentimentBand.NEUTRAL, + confidence="medium", + narrative="Mixed signals this week.", + ) + llm = _make_analyst_llm(structured) + analyst = create_sentiment_analyst(llm) + result = analyst(_make_analyst_state()) + assert result["messages"][-1].content == result["sentiment_report"] + + def test_falls_back_to_freetext_when_structured_unavailable(self): + llm = MagicMock() + llm.with_structured_output.side_effect = NotImplementedError("unsupported") + llm.invoke.return_value = MagicMock(content="Raw prose fallback report.") + analyst = create_sentiment_analyst(llm) + result = analyst(_make_analyst_state()) + assert "Raw prose fallback report." in result["sentiment_report"] diff --git a/tradingagents/agents/analysts/sentiment_analyst.py b/tradingagents/agents/analysts/sentiment_analyst.py index e1e4ee4f41..223f312c47 100644 --- a/tradingagents/agents/analysts/sentiment_analyst.py +++ b/tradingagents/agents/analysts/sentiment_analyst.py @@ -21,12 +21,15 @@ from datetime import datetime, timedelta +from langchain_core.messages import AIMessage from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from tradingagents.agents.schemas import SentimentReport, render_sentiment_report from tradingagents.agents.utils.agent_utils import ( build_instrument_context, get_language_instruction, get_news, ) +from tradingagents.agents.utils.structured import bind_structured, invoke_structured_or_freetext from tradingagents.dataflows.reddit import fetch_reddit_posts from tradingagents.dataflows.stocktwits import fetch_stocktwits_messages @@ -39,9 +42,11 @@ def create_sentiment_analyst(llm): """Create a sentiment analyst node for the trading graph. Pre-fetches news + StockTwits + Reddit data, injects them into the - prompt as structured blocks, and produces a sentiment report in a - single LLM call. + prompt as structured blocks, and produces a structured sentiment report + in a single LLM call. Falls back to free-text when the provider does + not support structured output. """ + structured_llm = bind_structured(llm, SentimentReport, "Sentiment Analyst") def sentiment_analyst_node(state): ticker = state["company_of_interest"] @@ -51,7 +56,7 @@ def sentiment_analyst_node(state): # Pre-fetch all three sources. Each fetcher degrades gracefully and # returns a string (no exceptions surface from here), so the LLM - # always sees something — either real data or a clear placeholder. + # always sees something - either real data or a clear placeholder. news_block = get_news.func(ticker, start_date, end_date) stocktwits_block = fetch_stocktwits_messages(ticker, limit=30) reddit_block = fetch_reddit_posts(ticker) @@ -83,14 +88,20 @@ def sentiment_analyst_node(state): prompt = prompt.partial(current_date=end_date) prompt = prompt.partial(instrument_context=instrument_context) - # No bind_tools — the data is already in the prompt; a single LLM - # call produces the report directly. - chain = prompt | llm - result = chain.invoke(state["messages"]) + # No bind_tools - the data is already in the prompt. Use structured + # output directly since there is no tool-calling phase to work around. + formatted = prompt.format_messages(messages=state["messages"]) + report = invoke_structured_or_freetext( + structured_llm, + llm, + formatted, + render_sentiment_report, + "Sentiment Analyst", + ) return { - "messages": [result], - "sentiment_report": result.content, + "messages": [AIMessage(content=report)], + "sentiment_report": report, } return sentiment_analyst_node diff --git a/tradingagents/agents/schemas.py b/tradingagents/agents/schemas.py index 55f0e3cfb1..48a79a495e 100644 --- a/tradingagents/agents/schemas.py +++ b/tradingagents/agents/schemas.py @@ -19,7 +19,7 @@ from __future__ import annotations from enum import Enum -from typing import Optional +from typing import Literal, Optional from pydantic import BaseModel, Field @@ -226,3 +226,78 @@ def render_pm_decision(decision: PortfolioDecision) -> str: if decision.time_horizon: parts.extend(["", f"**Time Horizon**: {decision.time_horizon}"]) return "\n".join(parts) + + +# --------------------------------------------------------------------------- +# Social Media / Sentiment Analyst +# --------------------------------------------------------------------------- + + +class SentimentBand(str, Enum): + """Categorical sentiment label used by the Social Media Analyst.""" + + BULLISH = "Bullish" + MILDLY_BULLISH = "Mildly Bullish" + NEUTRAL = "Neutral" + MIXED = "Mixed" + MILDLY_BEARISH = "Mildly Bearish" + BEARISH = "Bearish" + + +class SentimentReport(BaseModel): + """Structured sentiment output produced by the Social Media Analyst. + + The analyst collects data via tool calls (get_news), writes a free-form + research narrative, then this schema is used to extract deterministic + summary fields from that narrative. The narrative field preserves the + full prose so downstream debate agents receive the same rich context + they always have. + """ + + overall_score: float = Field( + ge=0, + le=10, + description=( + "Numeric sentiment score from 0 (most bearish) to 10 (most bullish). " + "Use one decimal place, e.g. 7.2." + ), + ) + overall_band: SentimentBand = Field( + description=( + "Categorical sentiment label. Pick the band that best matches the " + "overall_score: Bullish (7.1-10), Mildly Bullish (5.6-7.0), Neutral " + "(4.5-5.5, flat outlook with no strong signal), Mixed " + "(4.5-5.5, high-volume but conflicting signals), " + "Mildly Bearish (3.0-4.4), Bearish (0-2.9)." + ), + ) + confidence: Literal["low", "medium", "high"] = Field( + description=( + "Confidence in the sentiment assessment based on data availability and " + "signal clarity. Use 'low' when sources are sparse or contradictory, " + "'medium' for moderate signal, 'high' for strong consistent signal." + ), + ) + narrative: str = Field( + description=( + "The full analyst report in markdown. Include all research findings, " + "social media analysis, news insights, and a summary table. " + "This is what downstream agents read as context." + ), + ) + + +def render_sentiment_report(report: SentimentReport) -> str: + """Render a SentimentReport to markdown for state storage and downstream agents. + + The narrative (full prose report) is preserved intact so bull/bear researchers + and risk debate agents continue to receive the same rich context. The structured + fields are appended as a summary header for quick downstream parsing. + """ + header = "\n".join([ + f"**Overall Sentiment:** **{report.overall_band.value}** " + f"(Score: {report.overall_score:.1f}/10, Confidence: {report.confidence})", + "", + "", + ]) + return header + report.narrative