Skip to content

Commit ff8c876

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent fd740c9 commit ff8c876

4 files changed

Lines changed: 43 additions & 38 deletions

File tree

llm/climate_negotiation/app.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
import matplotlib.pyplot as plt
55
import pandas as pd
66
import solara
7+
from climate_negotiation.agents import CountryAgent
8+
from climate_negotiation.model import ClimateNegotiationModel
79
from dotenv import load_dotenv
810
from mesa.visualization import SolaraViz, make_plot_component
911
from mesa.visualization.utils import update_counter
10-
11-
from climate_negotiation.agents import CountryAgent
12-
from climate_negotiation.model import ClimateNegotiationModel
1312
from mesa_llm.reasoning.react import ReActReasoning
1413

1514
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic.main")
@@ -78,7 +77,10 @@ def PledgeBarChart(model):
7877
bar.get_x() + bar.get_width() / 2,
7978
bar.get_height() + 1.2,
8079
f"{pledge:.0f}%",
81-
ha="center", va="bottom", fontsize=9, fontweight="bold",
80+
ha="center",
81+
va="bottom",
82+
fontsize=9,
83+
fontweight="bold",
8284
)
8385

8486
plt.tight_layout()
@@ -105,13 +107,15 @@ def CoalitionStatusPanel(model):
105107
rows = []
106108
for a in sorted(countries, key=lambda x: x.country_name):
107109
coalition = [id_to_name.get(i, str(i)) for i in a.coalition_members]
108-
rows.append({
109-
"Country": a.country_name,
110-
"Pledge": f"{a.current_pledge:.1f}%",
111-
"Accepted": "✓" if a.accepted_treaty else "—",
112-
"Coalition": ", ".join(coalition) or "—",
113-
"Proposals": a.proposals_made,
114-
})
110+
rows.append(
111+
{
112+
"Country": a.country_name,
113+
"Pledge": f"{a.current_pledge:.1f}%",
114+
"Accepted": "✓" if a.accepted_treaty else "—",
115+
"Coalition": ", ".join(coalition) or "—",
116+
"Proposals": a.proposals_made,
117+
}
118+
)
115119

116120
solara.DataFrame(pd.DataFrame(rows))
117121

@@ -133,9 +137,7 @@ def PledgeTrajectoriesChart(model):
133137
return solara.FigureMatplotlib(fig)
134138

135139
id_to_name = {
136-
a.unique_id: a.country_name
137-
for a in model.agents
138-
if isinstance(a, CountryAgent)
140+
a.unique_id: a.country_name for a in model.agents if isinstance(a, CountryAgent)
139141
}
140142

141143
if isinstance(df.index, pd.MultiIndex):
@@ -146,7 +148,9 @@ def PledgeTrajectoriesChart(model):
146148
return solara.FigureMatplotlib(fig)
147149

148150
for country in pledge_df.columns:
149-
ax.plot(pledge_df.index, pledge_df[country], marker="o", label=country, linewidth=2)
151+
ax.plot(
152+
pledge_df.index, pledge_df[country], marker="o", label=country, linewidth=2
153+
)
150154

151155
ax.set_xlabel("Round", fontsize=11)
152156
ax.set_ylabel("Reduction Pledge (%)", fontsize=11)

llm/climate_negotiation/climate_negotiation/agents.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ def get_negotiation_history(agent, max_messages: int = 6) -> str:
3939
sender_agent = next(
4040
a for a in agent.model.agents if a.unique_id == sender_id
4141
)
42-
sender_name = getattr(sender_agent, "country_name", f"Agent {sender_id}")
42+
sender_name = getattr(
43+
sender_agent, "country_name", f"Agent {sender_id}"
44+
)
4345
except StopIteration:
4446
sender_name = f"Agent {sender_id}"
4547
messages.append(f" {sender_name}: {msg}")

llm/climate_negotiation/climate_negotiation/model.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
import logging
22
import os
3+
34
from mesa.datacollection import DataCollector
45
from mesa.model import Model
6+
from mesa_llm.reasoning.react import ReActReasoning
7+
from mesa_llm.reasoning.reasoning import Reasoning
58
from rich import print as rprint
69

710
from .agents import CountryAgent
8-
from mesa_llm.reasoning.react import ReActReasoning
9-
from mesa_llm.reasoning.reasoning import Reasoning
1011

1112
_log_path = os.environ.get("CLIMATE_LOG_FILE", "climate_negotiation.log")
1213
_file_handler = logging.FileHandler(_log_path, mode="w", encoding="utf-8")
13-
_file_handler.setFormatter(logging.Formatter("%(asctime)s | %(levelname)s | %(message)s"))
14+
_file_handler.setFormatter(
15+
logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")
16+
)
1417

1518
sim_logger = logging.getLogger("climate_negotiation")
1619
sim_logger.setLevel(logging.DEBUG)
@@ -162,7 +165,6 @@ def __init__(
162165
},
163166
)
164167

165-
166168
def _treaty_reached(self) -> bool:
167169
"""Return True when at least 2/3 of countries have accepted."""
168170
agents = list(self.agents)
@@ -182,18 +184,12 @@ def _largest_coalition(self) -> int:
182184
"""Size (including self) of the largest active coalition."""
183185
if not self.agents:
184186
return 0
185-
return max(
186-
len(getattr(a, "coalition_members", [])) + 1 for a in self.agents
187-
)
188-
187+
return max(len(getattr(a, "coalition_members", [])) + 1 for a in self.agents)
189188

190189
def step(self):
191190
self.datacollector.collect(self)
192191
round_num = self.steps
193-
rprint(
194-
f"\n[bold cyan]- Climate Summit Round {round_num} "
195-
f"[/bold cyan]"
196-
)
192+
rprint(f"\n[bold cyan]- Climate Summit Round {round_num} [/bold cyan]")
197193
sim_logger.info("=" * 60)
198194
sim_logger.info(f"ROUND {round_num} START")
199195
sim_logger.info("=" * 60)

llm/climate_negotiation/climate_negotiation/tools.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
import logging
1212
from typing import TYPE_CHECKING
1313

14-
from .agents import country_tool_manager
1514
from mesa_llm.tools.tool_decorator import tool
1615

16+
from .agents import country_tool_manager
17+
1718
if TYPE_CHECKING:
1819
from mesa_llm.llm_agent import LLMAgent
1920

@@ -138,10 +139,13 @@ def form_coalition(
138139
"""
139140
if isinstance(partner_ids, str):
140141
import json
142+
141143
try:
142144
partner_ids = json.loads(partner_ids)
143145
except (json.JSONDecodeError, ValueError):
144-
partner_ids = [int(x.strip()) for x in partner_ids.strip("[]").split(",") if x.strip()]
146+
partner_ids = [
147+
int(x.strip()) for x in partner_ids.strip("[]").split(",") if x.strip()
148+
]
145149
partner_ids = [int(pid) for pid in (partner_ids or [])]
146150

147151
# Filter out hallucinated IDs - only keep IDs that map to real agents.
@@ -168,10 +172,9 @@ def form_coalition(
168172
):
169173
partner.coalition_members.append(agent.unique_id)
170174

171-
member_names = (
172-
[getattr(a, "country_name", str(a.unique_id)) for a in partner_agents]
173-
+ [agent.country_name]
174-
)
175+
member_names = [
176+
getattr(a, "country_name", str(a.unique_id)) for a in partner_agents
177+
] + [agent.country_name]
175178

176179
agent.send_message(
177180
f"[COALITION] {agent.country_name} proposes the '{coalition_name}'. "
@@ -207,11 +210,11 @@ def reject_and_counter(
207210
"""
208211
counter_reduction_percent = float(counter_reduction_percent or 20.0)
209212
proposer_id = int(proposer_id or 0)
210-
proposer = next(
211-
(a for a in agent.model.agents if a.unique_id == proposer_id), None
212-
)
213+
proposer = next((a for a in agent.model.agents if a.unique_id == proposer_id), None)
213214
proposer_name = (
214-
getattr(proposer, "country_name", str(proposer_id)) if proposer else str(proposer_id)
215+
getattr(proposer, "country_name", str(proposer_id))
216+
if proposer
217+
else str(proposer_id)
215218
)
216219

217220
counter_msg = (

0 commit comments

Comments
 (0)