Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions tests/test_structured_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@

import pytest

from tradingagents.agents.analysts.sentiment_analyst import create_sentiment_analyst
from tradingagents.agents.managers.research_manager import create_research_manager
from tradingagents.agents.schemas import (
PortfolioRating,
ResearchPlan,
SentimentBand,
SentimentReport,
TraderAction,
TraderProposal,
render_research_plan,
render_sentiment_report,
render_trader_proposal,
)
from tradingagents.agents.trader.trader import create_trader
Expand Down Expand Up @@ -230,3 +234,137 @@ def test_falls_back_to_freetext_when_structured_unavailable(self):
rm = create_research_manager(llm)
result = rm(_make_rm_state())
assert result["investment_plan"] == plain_response


# ---------------------------------------------------------------------------
# SentimentReport schema and render
# ---------------------------------------------------------------------------


@pytest.mark.unit
class TestRenderSentimentReport:
    """Unit tests for the SentimentReport schema and its markdown renderer."""

    def test_required_fields_present(self):
        """Band, score, confidence and narrative all show up in the output."""
        r = SentimentReport(
            overall_score=7.2,
            overall_band=SentimentBand.BULLISH,
            confidence="high",
            narrative="TSLA sentiment is strongly positive driven by retail enthusiasm.",
        )
        md = render_sentiment_report(r)
        assert "**Overall Sentiment:** **Bullish**" in md
        assert "Score: 7.2/10" in md
        assert "Confidence: high" in md
        assert "TSLA sentiment is strongly positive" in md

    def test_score_formatted_to_one_decimal(self):
        """A whole-number score is still rendered with one decimal place."""
        r = SentimentReport(
            overall_score=6.0,
            overall_band=SentimentBand.MILDLY_BULLISH,
            confidence="medium",
            narrative="Mixed signals.",
        )
        md = render_sentiment_report(r)
        assert "6.0/10" in md

    # One test case per band so a failure names the offending member.
    @pytest.mark.parametrize("band", list(SentimentBand))
    def test_all_bands_render(self, band):
        """Every SentimentBand member renders as the bolded band label."""
        r = SentimentReport(
            overall_score=5.0,
            overall_band=band,
            confidence="low",
            narrative="Test.",
        )
        md = render_sentiment_report(r)
        # Match the bold markers too: a bare substring check would accept
        # "Mildly Bullish" when "Bullish" was expected.
        assert f"**{band.value}**" in md

    def test_narrative_preserved_intact(self):
        """Multi-paragraph markdown (including a table) passes through unchanged."""
        narrative = "Line one.\n\nLine two.\n\n| Col | Val |\n|-----|-----|\n| A | B |"
        r = SentimentReport(
            overall_score=4.5,
            overall_band=SentimentBand.NEUTRAL,
            confidence="medium",
            narrative=narrative,
        )
        md = render_sentiment_report(r)
        assert narrative in md

    def test_score_bounds(self):
        """Scores above 10 or below 0 are rejected by schema validation."""
        from pydantic import ValidationError

        with pytest.raises(ValidationError):
            SentimentReport(
                overall_score=11.0,
                overall_band=SentimentBand.BULLISH,
                confidence="high",
                narrative="Out of range.",
            )
        with pytest.raises(ValidationError):
            SentimentReport(
                overall_score=-1.0,
                overall_band=SentimentBand.BEARISH,
                confidence="low",
                narrative="Out of range.",
            )


# ---------------------------------------------------------------------------
# Sentiment Analyst node
# ---------------------------------------------------------------------------


def _make_analyst_state():
    """Return a minimal graph state dict for invoking the sentiment analyst node.

    The state carries a single human message plus the ticker and trade date
    used throughout these tests (TSLA on 2024-01-15).
    """
    from langchain_core.messages import HumanMessage

    opening_message = HumanMessage(content="Analyze TSLA sentiment.")
    state = dict(
        messages=[opening_message],
        company_of_interest="TSLA",
        trade_date="2024-01-15",
    )
    return state


def _make_analyst_llm(structured_result):
"""LLM mock for the sentiment analyst.

The analyst uses format_messages then invoke_structured_or_freetext, which
calls structured_llm.invoke(formatted_messages) directly.
"""
llm = MagicMock()
llm.with_structured_output.return_value.invoke.return_value = structured_result
return llm


@pytest.mark.unit
class TestSentimentAnalystStructured:
    """Tests for the sentiment analyst node: structured path and free-text fallback."""

    @pytest.fixture(autouse=True)
    def _stub_data_sources(self, monkeypatch):
        """Replace the node's three data fetchers with canned strings.

        The node calls ``get_news.func``, ``fetch_stocktwits_messages`` and
        ``fetch_reddit_posts`` before it talks to the LLM. Stubbing them in
        the analyst module keeps these unit tests from depending on the real
        data sources (presumably network-backed; verify against the
        dataflows modules).
        """
        target = "tradingagents.agents.analysts.sentiment_analyst"
        monkeypatch.setattr(
            f"{target}.get_news",
            MagicMock(**{"func.return_value": "Stub news block."}),
        )
        monkeypatch.setattr(
            f"{target}.fetch_stocktwits_messages",
            MagicMock(return_value="Stub StockTwits block."),
        )
        monkeypatch.setattr(
            f"{target}.fetch_reddit_posts",
            MagicMock(return_value="Stub Reddit block."),
        )

    def test_structured_output_populates_sentiment_report(self):
        """A structured result is rendered into the ``sentiment_report`` field."""
        structured = SentimentReport(
            overall_score=7.2,
            overall_band=SentimentBand.BULLISH,
            confidence="high",
            narrative="Retail very bullish on TSLA this week.",
        )
        llm = _make_analyst_llm(structured)
        analyst = create_sentiment_analyst(llm)
        result = analyst(_make_analyst_state())
        assert "Score: 7.2/10" in result["sentiment_report"]
        assert "Bullish" in result["sentiment_report"]
        assert "Retail very bullish" in result["sentiment_report"]

    def test_sentiment_report_in_messages(self):
        """The rendered report should also appear in the messages state."""
        structured = SentimentReport(
            overall_score=5.0,
            overall_band=SentimentBand.NEUTRAL,
            confidence="medium",
            narrative="Mixed signals this week.",
        )
        llm = _make_analyst_llm(structured)
        analyst = create_sentiment_analyst(llm)
        result = analyst(_make_analyst_state())
        assert result["messages"][-1].content == result["sentiment_report"]

    def test_falls_back_to_freetext_when_structured_unavailable(self):
        """When structured output is unsupported, the plain LLM reply is used."""
        llm = MagicMock()
        llm.with_structured_output.side_effect = NotImplementedError("unsupported")
        llm.invoke.return_value = MagicMock(content="Raw prose fallback report.")
        analyst = create_sentiment_analyst(llm)
        result = analyst(_make_analyst_state())
        assert "Raw prose fallback report." in result["sentiment_report"]
29 changes: 20 additions & 9 deletions tradingagents/agents/analysts/sentiment_analyst.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@

from datetime import datetime, timedelta

from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.schemas import SentimentReport, render_sentiment_report
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_language_instruction,
get_news,
)
from tradingagents.agents.utils.structured import bind_structured, invoke_structured_or_freetext
from tradingagents.dataflows.reddit import fetch_reddit_posts
from tradingagents.dataflows.stocktwits import fetch_stocktwits_messages

Expand All @@ -39,9 +42,11 @@ def create_sentiment_analyst(llm):
"""Create a sentiment analyst node for the trading graph.

Pre-fetches news + StockTwits + Reddit data, injects them into the
prompt as structured blocks, and produces a sentiment report in a
single LLM call.
prompt as structured blocks, and produces a structured sentiment report
in a single LLM call. Falls back to free-text when the provider does
not support structured output.
"""
structured_llm = bind_structured(llm, SentimentReport, "Sentiment Analyst")

def sentiment_analyst_node(state):
ticker = state["company_of_interest"]
Expand All @@ -51,7 +56,7 @@ def sentiment_analyst_node(state):

# Pre-fetch all three sources. Each fetcher degrades gracefully and
# returns a string (no exceptions surface from here), so the LLM
# always sees something either real data or a clear placeholder.
# always sees something - either real data or a clear placeholder.
news_block = get_news.func(ticker, start_date, end_date)
stocktwits_block = fetch_stocktwits_messages(ticker, limit=30)
reddit_block = fetch_reddit_posts(ticker)
Expand Down Expand Up @@ -83,14 +88,20 @@ def sentiment_analyst_node(state):
prompt = prompt.partial(current_date=end_date)
prompt = prompt.partial(instrument_context=instrument_context)

# No bind_tools — the data is already in the prompt; a single LLM
# call produces the report directly.
chain = prompt | llm
result = chain.invoke(state["messages"])
# No bind_tools - the data is already in the prompt. Use structured
# output directly since there is no tool-calling phase to work around.
formatted = prompt.format_messages(messages=state["messages"])
report = invoke_structured_or_freetext(
structured_llm,
llm,
formatted,
render_sentiment_report,
"Sentiment Analyst",
)

return {
"messages": [result],
"sentiment_report": result.content,
"messages": [AIMessage(content=report)],
"sentiment_report": report,
}

return sentiment_analyst_node
Expand Down
77 changes: 76 additions & 1 deletion tradingagents/agents/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from __future__ import annotations

from enum import Enum
from typing import Optional
from typing import Literal, Optional

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -226,3 +226,78 @@ def render_pm_decision(decision: PortfolioDecision) -> str:
if decision.time_horizon:
parts.extend(["", f"**Time Horizon**: {decision.time_horizon}"])
return "\n".join(parts)


# ---------------------------------------------------------------------------
# Social Media / Sentiment Analyst
# ---------------------------------------------------------------------------


class SentimentBand(str, Enum):
    """Categorical sentiment label used by the Social Media Analyst.

    The ``str`` mixin makes every member a string equal to its value, and
    the value is the human-readable label that ends up in rendered reports.
    """

    # Declared roughly from most bullish to most bearish; iteration follows
    # this declaration order.
    BULLISH = "Bullish"
    MILDLY_BULLISH = "Mildly Bullish"
    # Neutral and Mixed cover the same middle score range: Neutral is a flat
    # outlook with no strong signal, Mixed is high-volume conflicting signals.
    NEUTRAL = "Neutral"
    MIXED = "Mixed"
    MILDLY_BEARISH = "Mildly Bearish"
    BEARISH = "Bearish"


class SentimentReport(BaseModel):
"""Structured sentiment output produced by the Social Media Analyst.

The analyst pre-fetches news, StockTwits and Reddit data, injects it into
the prompt, and the LLM fills in this schema in a single structured-output
call (there is no tool-calling phase). The narrative field preserves the
full prose so downstream debate agents receive the same rich context
they always have.
"""

overall_score: float = Field(
ge=0,
le=10,
description=(
"Numeric sentiment score from 0 (most bearish) to 10 (most bullish). "
"Use one decimal place, e.g. 7.2."
),
)
overall_band: SentimentBand = Field(
description=(
"Categorical sentiment label. Pick the band that best matches the "
"overall_score: Bullish (7.1-10), Mildly Bullish (5.6-7.0), Neutral "
"(4.5-5.5, flat outlook with no strong signal), Mixed "
"(4.5-5.5, high-volume but conflicting signals), "
"Mildly Bearish (3.0-4.4), Bearish (0-2.9)."
),
)
Comment on lines +265 to +273
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The description for overall_band contains overlapping numeric ranges (e.g., 7.0 is in both Bullish and Mildly Bullish) and groups Neutral/Mixed together even though they are separate members in the SentimentBand enum. Providing non-overlapping ranges and a clear distinction between 'Neutral' and 'Mixed' will help the LLM select the correct enum value more reliably.

'Neutral' typically implies a lack of strong signal or a flat outlook, while 'Mixed' implies high-volume but conflicting signals.

    overall_band: SentimentBand = Field(
        description=(
            "Categorical sentiment label. Pick the band that best matches the "
            "overall_score: Bullish (7.1-10), Mildly Bullish (5.6-7.0), Neutral "
            "(4.5-5.5, flat outlook), Mixed (4.5-5.5, conflicting signals), "
            "Mildly Bearish (3.0-4.4), Bearish (0-2.9)."
        ),
    )

Copy link
Copy Markdown
Author

@HrushiYadav HrushiYadav May 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 1ec8997 - non-overlapping ranges applied and Neutral vs Mixed distinction clarified.

confidence: Literal["low", "medium", "high"] = Field(
description=(
"Confidence in the sentiment assessment based on data availability and "
"signal clarity. Use 'low' when sources are sparse or contradictory, "
"'medium' for moderate signal, 'high' for strong consistent signal."
),
)
narrative: str = Field(
description=(
"The full analyst report in markdown. Include all research findings, "
"social media analysis, news insights, and a summary table. "
"This is what downstream agents read as context."
),
)


def render_sentiment_report(report: SentimentReport) -> str:
"""Render a SentimentReport to markdown for state storage and downstream agents.

The narrative (full prose report) is preserved intact so bull/bear researchers
and risk debate agents continue to receive the same rich context. The structured
fields are prepended as a one-line summary header for quick downstream parsing.
"""
header = "\n".join([
f"**Overall Sentiment:** **{report.overall_band.value}** "
f"(Score: {report.overall_score:.1f}/10, Confidence: {report.confidence})",
"",
"",
])
Comment on lines +297 to +302
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of render_sentiment_report results in only a single newline between the structured header and the narrative text. In Markdown, a blank line (two newlines) is standard for separating blocks of text. Adding an additional empty string to the join list will ensure proper spacing for downstream display.

Suggested change
header = "\n".join([
f"**Overall Sentiment:** **{report.overall_band.value}** "
f"(Score: {report.overall_score:.1f}/10, Confidence: {report.confidence})",
"",
])
header = "\n".join([
f"**Overall Sentiment:** **{report.overall_band.value}** "
f"(Score: {report.overall_score:.1f}/10, Confidence: {report.confidence})",
"",
"",
])

Copy link
Copy Markdown
Author

@HrushiYadav HrushiYadav May 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 1ec8997 - added the extra blank line for proper markdown block separation.

return header + report.narrative