
Commit 93b0434

Merge branch 'main' into feat/flush-traces-docs
2 parents c1f0071 + 7a3f6b7

9 files changed

Lines changed: 849 additions & 532 deletions


src/agents/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -203,6 +203,7 @@
     add_trace_processor,
     agent_span,
     custom_span,
+    flush_traces,
     function_span,
     gen_span_id,
     gen_trace_id,
@@ -451,6 +452,7 @@ def enable_verbose_stdout_logging():
     "add_trace_processor",
     "agent_span",
     "custom_span",
+    "flush_traces",
     "function_span",
     "generation_span",
     "get_current_span",

src/agents/extensions/memory/advanced_sqlite_session.py

Lines changed: 487 additions & 466 deletions
Large diffs are not rendered by default.

src/agents/memory/sqlite_session.py

Lines changed: 93 additions & 45 deletions

@@ -4,7 +4,10 @@
 import json
 import sqlite3
 import threading
+from collections.abc import Iterator
+from contextlib import contextmanager
 from pathlib import Path
+from typing import ClassVar
 
 from ..items import TResponseInputItem
 from .session import SessionABC
@@ -20,6 +23,9 @@ class SQLiteSession(SessionABC):
     """
 
     session_settings: SessionSettings | None = None
+    _file_locks: ClassVar[dict[Path, threading.RLock]] = {}
+    _file_lock_counts: ClassVar[dict[Path, int]] = {}
+    _file_locks_guard: ClassVar[threading.Lock] = threading.Lock()
 
     def __init__(
         self,
@@ -46,21 +52,66 @@ def __init__(
         self.sessions_table = sessions_table
         self.messages_table = messages_table
         self._local = threading.local()
-        self._lock = threading.Lock()
 
         # For in-memory databases, we need a shared connection to avoid thread isolation
         # For file databases, we use thread-local connections for better concurrency
         self._is_memory_db = str(db_path) == ":memory:"
+        self._lock_path: Path | None = None
+        self._lock_released = False
         if self._is_memory_db:
-            self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False)
-            self._shared_connection.execute("PRAGMA journal_mode=WAL")
-            self._init_db_for_connection(self._shared_connection)
+            self._lock = threading.RLock()
         else:
-            # For file databases, initialize the schema once since it persists
-            init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
-            init_conn.execute("PRAGMA journal_mode=WAL")
-            self._init_db_for_connection(init_conn)
-            init_conn.close()
+            self._lock_path, self._lock = self._acquire_file_lock(Path(self.db_path))
+
+        try:
+            if self._is_memory_db:
+                self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False)
+                self._shared_connection.execute("PRAGMA journal_mode=WAL")
+                self._init_db_for_connection(self._shared_connection)
+            else:
+                # For file databases, initialize the schema once since it persists
+                with self._lock:
+                    init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
+                    init_conn.execute("PRAGMA journal_mode=WAL")
+                    self._init_db_for_connection(init_conn)
+                    init_conn.close()
+        except Exception:
+            if self._lock_path is not None and not self._lock_released:
+                self._release_file_lock(self._lock_path)
+                self._lock_released = True
+            raise
+
+    @classmethod
+    def _acquire_file_lock(cls, db_path: Path) -> tuple[Path, threading.RLock]:
+        """Return the path key and process-local lock for sessions sharing one SQLite file."""
+        lock_path = db_path.expanduser().resolve()
+        with cls._file_locks_guard:
+            lock = cls._file_locks.get(lock_path)
+            if lock is None:
+                lock = threading.RLock()
+                cls._file_locks[lock_path] = lock
+                cls._file_lock_counts[lock_path] = 0
+            cls._file_lock_counts[lock_path] += 1
+            return lock_path, lock
+
+    @classmethod
+    def _release_file_lock(cls, lock_path: Path) -> None:
+        """Drop the shared lock for a file-backed DB once the last session closes."""
+        with cls._file_locks_guard:
+            ref_count = cls._file_lock_counts.get(lock_path)
+            if ref_count is None:
+                return
+            if ref_count <= 1:
+                cls._file_lock_counts.pop(lock_path, None)
+                cls._file_locks.pop(lock_path, None)
+            else:
+                cls._file_lock_counts[lock_path] = ref_count - 1
+
+    @contextmanager
+    def _locked_connection(self) -> Iterator[sqlite3.Connection]:
+        """Serialize sqlite3 access while each operation runs in a worker thread."""
+        with self._lock:
+            yield self._get_connection()
 
     def _get_connection(self) -> sqlite3.Connection:
         """Get a database connection."""
@@ -114,6 +165,31 @@ def _init_db_for_connection(self, conn: sqlite3.Connection) -> None:
 
         conn.commit()
 
+    def _insert_items(self, conn: sqlite3.Connection, items: list[TResponseInputItem]) -> None:
+        conn.execute(
+            f"""
+            INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?)
+            """,
+            (self.session_id,),
+        )
+
+        message_data = [(self.session_id, json.dumps(item)) for item in items]
+        conn.executemany(
+            f"""
+            INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?)
+            """,
+            message_data,
+        )
+
+        conn.execute(
+            f"""
+            UPDATE {self.sessions_table}
+            SET updated_at = CURRENT_TIMESTAMP
+            WHERE session_id = ?
+            """,
+            (self.session_id,),
+        )
+
     async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
         """Retrieve the conversation history for this session.
 
@@ -127,8 +203,7 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
         session_limit = resolve_session_limit(limit, self.session_settings)
 
         def _get_items_sync():
-            conn = self._get_connection()
-            with self._lock if self._is_memory_db else threading.Lock():
+            with self._locked_connection() as conn:
                 if session_limit is None:
                     # Fetch all items in chronological order
                     cursor = conn.execute(
@@ -180,36 +255,8 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
             return
 
         def _add_items_sync():
-            conn = self._get_connection()
-
-            with self._lock if self._is_memory_db else threading.Lock():
-                # Ensure session exists
-                conn.execute(
-                    f"""
-                    INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?)
-                    """,
-                    (self.session_id,),
-                )
-
-                # Add items
-                message_data = [(self.session_id, json.dumps(item)) for item in items]
-                conn.executemany(
-                    f"""
-                    INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?)
-                    """,
-                    message_data,
-                )
-
-                # Update session timestamp
-                conn.execute(
-                    f"""
-                    UPDATE {self.sessions_table}
-                    SET updated_at = CURRENT_TIMESTAMP
-                    WHERE session_id = ?
-                    """,
-                    (self.session_id,),
-                )
-
+            with self._locked_connection() as conn:
+                self._insert_items(conn, items)
                 conn.commit()
 
         await asyncio.to_thread(_add_items_sync)
@@ -222,8 +269,7 @@ async def pop_item(self) -> TResponseInputItem | None:
         """
 
         def _pop_item_sync():
-            conn = self._get_connection()
-            with self._lock if self._is_memory_db else threading.Lock():
+            with self._locked_connection() as conn:
                 # Use DELETE with RETURNING to atomically delete and return the most recent item
                 cursor = conn.execute(
                     f"""
@@ -259,8 +305,7 @@ async def clear_session(self) -> None:
         """Clear all items for this session."""
 
         def _clear_session_sync():
-            conn = self._get_connection()
-            with self._lock if self._is_memory_db else threading.Lock():
+            with self._locked_connection() as conn:
                 conn.execute(
                     f"DELETE FROM {self.messages_table} WHERE session_id = ?",
                     (self.session_id,),
@@ -281,3 +326,6 @@ def close(self) -> None:
         else:
             if hasattr(self._local, "connection"):
                 self._local.connection.close()
+        if self._lock_path is not None and not self._lock_released:
+            self._release_file_lock(self._lock_path)
+            self._lock_released = True
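
A note on the locking change above: every SQLiteSession that points at the same file-backed database now shares one process-local threading.RLock, keyed by the resolved path and reference-counted so the registry entry is dropped once the last session for that file closes. The standalone sketch below mirrors that acquire/release bookkeeping in isolation; the class and variable names are illustrative and not part of the SDK.

```python
import threading
from pathlib import Path


class SharedFileLockRegistry:
    """Process-local registry mapping a resolved file path to one shared RLock (illustrative)."""

    _locks: dict[Path, threading.RLock] = {}
    _counts: dict[Path, int] = {}
    _guard = threading.Lock()

    @classmethod
    def acquire(cls, path: Path) -> tuple[Path, threading.RLock]:
        key = path.expanduser().resolve()
        with cls._guard:
            lock = cls._locks.get(key)
            if lock is None:
                lock = threading.RLock()
                cls._locks[key] = lock
                cls._counts[key] = 0
            cls._counts[key] += 1
            return key, lock

    @classmethod
    def release(cls, key: Path) -> None:
        with cls._guard:
            count = cls._counts.get(key)
            if count is None:
                return
            if count <= 1:
                # Last holder: drop the registry entry entirely.
                cls._counts.pop(key, None)
                cls._locks.pop(key, None)
            else:
                cls._counts[key] = count - 1


# Two "sessions" pointing at the same file resolve to the same lock object.
key_a, lock_a = SharedFileLockRegistry.acquire(Path("conversations.db"))
key_b, lock_b = SharedFileLockRegistry.acquire(Path("./conversations.db"))
assert lock_a is lock_b

SharedFileLockRegistry.release(key_a)
SharedFileLockRegistry.release(key_b)  # last release drops the registry entry
```

Using an RLock rather than a plain Lock presumably lets the thread that already holds the lock re-enter it safely if one locked operation calls into another; that is an inference from the diff rather than a stated rationale.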

src/agents/tracing/__init__.py

Lines changed: 12 additions & 0 deletions

@@ -42,6 +42,7 @@
     "add_trace_processor",
     "agent_span",
     "custom_span",
+    "flush_traces",
     "function_span",
     "generation_span",
     "get_current_span",
@@ -108,3 +109,14 @@ def set_tracing_export_api_key(api_key: str) -> None:
     Set the OpenAI API key for the backend exporter.
     """
     default_exporter().set_api_key(api_key)
+
+
+def flush_traces() -> None:
+    """Force immediate export of buffered traces and spans.
+
+    The default ``BatchTraceProcessor`` already exports traces periodically in the
+    background. Call this when a worker, background job, or request handler needs
+    traces to be visible immediately after a unit of work finishes instead of
+    waiting for the next scheduled flush.
+    """
+    get_trace_provider().force_flush()
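
For illustration, a minimal usage sketch of the new helper from a short-lived worker; the agent name, instructions, and prompt are invented, and it assumes the SDK's usual Agent/Runner entry points exported from the top-level agents package:

```python
import asyncio

from agents import Agent, Runner, flush_traces


async def handle_job() -> None:
    # Illustrative agent and prompt; any traced run would do here.
    agent = Agent(name="Assistant", instructions="Reply concisely.")
    result = await Runner.run(agent, "Summarize today's queue backlog.")
    print(result.final_output)

    # Push out whatever the batch processor has buffered right now, instead of
    # waiting for its next scheduled background flush.
    flush_traces()


if __name__ == "__main__":
    asyncio.run(handle_job())
```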

src/agents/tracing/processors.py

Lines changed: 18 additions & 16 deletions

@@ -491,6 +491,7 @@ def __init__(
         # We lazily start the background worker thread the first time a span/trace is queued.
         self._worker_thread: threading.Thread | None = None
         self._thread_start_lock = threading.Lock()
+        self._export_lock = threading.Lock()
 
     def _ensure_thread_started(self) -> None:
         # Fast path without holding the lock
@@ -571,25 +572,26 @@ def _export_batches(self, force: bool = False):
         """Drains the queue and exports in batches. If force=True, export everything.
         Otherwise, export up to `max_batch_size` repeatedly until the queue is completely empty.
         """
-        while True:
-            items_to_export: list[Span[Any] | Trace] = []
+        with self._export_lock:
+            while True:
+                items_to_export: list[Span[Any] | Trace] = []
+
+                # Gather a batch of spans up to max_batch_size
+                while not self._queue.empty() and (
+                    force or len(items_to_export) < self._max_batch_size
+                ):
+                    try:
+                        items_to_export.append(self._queue.get_nowait())
+                    except queue.Empty:
+                        # Another thread might have emptied the queue between checks
+                        break
 
-            # Gather a batch of spans up to max_batch_size
-            while not self._queue.empty() and (
-                force or len(items_to_export) < self._max_batch_size
-            ):
-                try:
-                    items_to_export.append(self._queue.get_nowait())
-                except queue.Empty:
-                    # Another thread might have emptied the queue between checks
+                # If we collected nothing, we're done
+                if not items_to_export:
                     break
 
-            # If we collected nothing, we're done
-            if not items_to_export:
-                break
-
-            # Export the batch
-            self._exporter.export(items_to_export)
+                # Export the batch
+                self._exporter.export(items_to_export)
 
 
 # Lazily initialized defaults to avoid creating network clients or threading
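
Context for the _export_lock addition: _export_batches can now be entered from two places at once, the background worker's periodic pass and a caller arriving via the new flush_traces() path, so the whole drain loop is serialized. The toy sketch below uses standalone names (not SDK code) to show the shape of that protection: each thread holds the lock for a full drain pass, so every queued item is exported exactly once and batches never interleave.

```python
import queue
import threading

# Standalone stand-ins for the processor's queue, lock, and exporter.
_item_queue: queue.Queue[int] = queue.Queue()
_export_lock = threading.Lock()
exported_batches: list[list[int]] = []


def export_batches(max_batch_size: int = 3) -> None:
    # The entire drain pass runs under the lock, mirroring the change above.
    with _export_lock:
        while True:
            batch: list[int] = []
            while not _item_queue.empty() and len(batch) < max_batch_size:
                try:
                    batch.append(_item_queue.get_nowait())
                except queue.Empty:
                    break
            if not batch:
                break
            exported_batches.append(batch)


for i in range(10):
    _item_queue.put(i)

# A background worker and a force-flush caller racing on the same queue.
threads = [threading.Thread(target=export_batches) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(exported_batches)  # every item appears exactly once; batches never overlap
```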

src/agents/tracing/provider.py

Lines changed: 26 additions & 2 deletions

@@ -188,9 +188,21 @@
     ) -> Span[TSpanData]:
         """Create a new span."""
 
-    @abstractmethod
+    def force_flush(self) -> None:
+        """Force all registered processors to flush buffered traces/spans immediately.
+
+        The default implementation is a no-op so existing custom ``TraceProvider``
+        implementations continue to work without adding this method.
+        """
+        return None
+
     def shutdown(self) -> None:
-        """Clean up any resources used by the provider."""
+        """Clean up any resources used by the provider.
+
+        The default implementation is a no-op so existing custom ``TraceProvider``
+        implementations continue to work without adding this method.
+        """
+        return None
 
 
 class DefaultTraceProvider(TraceProvider):
@@ -365,7 +377,19 @@ def create_span(
             trace_metadata=trace_metadata,
         )
 
+    def force_flush(self) -> None:
+        """Force all processors to flush their buffers immediately."""
+        self._refresh_disabled_flag()
+        if self._disabled:
+            return
+
+        try:
+            self._multi_processor.force_flush()
+        except Exception as e:
+            logger.error(f"Error flushing trace provider: {e}")
+
     def shutdown(self) -> None:
+        self._refresh_disabled_flag()
         if self._disabled:
             return
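
Putting the pieces together: flush_traces() calls get_trace_provider().force_flush(), and DefaultTraceProvider fans that out to every registered processor via self._multi_processor.force_flush(). The sketch below is a hedged illustration of observing that fan-out with a custom processor; it assumes the SDK exposes a TracingProcessor interface (with force_flush/shutdown hooks) importable from agents.tracing alongside the add_trace_processor registration helper, and the buffering logic is purely illustrative.

```python
from typing import Any

# Import paths assumed; TracingProcessor's exact method signatures may differ slightly.
from agents import add_trace_processor
from agents.tracing import TracingProcessor, flush_traces


class PrintingProcessor(TracingProcessor):
    """Buffers finished traces/spans and reports whenever they are flushed (illustrative)."""

    def __init__(self) -> None:
        self._buffer: list[Any] = []

    def on_trace_start(self, trace: Any) -> None:
        pass

    def on_trace_end(self, trace: Any) -> None:
        self._buffer.append(trace)

    def on_span_start(self, span: Any) -> None:
        pass

    def on_span_end(self, span: Any) -> None:
        self._buffer.append(span)

    def force_flush(self) -> None:
        print(f"force_flush: draining {len(self._buffer)} buffered items")
        self._buffer.clear()

    def shutdown(self) -> None:
        self.force_flush()


add_trace_processor(PrintingProcessor())
flush_traces()  # the provider fans this out to the processor's force_flush()
```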
