Skip to content

Commit 2119851

Browse files
th0br0copybara-github
authored andcommitted
Support async methods in scoped_trace in googletest Rust.
This change adds support for async methods in `scoped_trace` by using a runtime-agnostic `InstrumentedFuture` that manages the thread-local trace stack across yield points. PiperOrigin-RevId: 914068550
1 parent e489b8c commit 2119851

7 files changed

Lines changed: 227 additions & 16 deletions

File tree

googletest/src/internal/scoped_trace.rs

Lines changed: 104 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
use std::cell::RefCell;
1616
use std::sync::atomic::{AtomicUsize, Ordering};
17+
use std::sync::Arc;
1718

1819
static NEXT_TRACE_ID: AtomicUsize = AtomicUsize::new(0);
1920

@@ -26,19 +27,22 @@ pub struct TraceInfo {
2627
pub message: String,
2728
}
2829

30+
struct TraceNode {
31+
info: TraceInfo,
32+
parent: Option<Arc<TraceNode>>,
33+
}
34+
2935
thread_local! {
30-
static TRACE_STACK: RefCell<Vec<TraceInfo>> = const { RefCell::new(Vec::new()) };
36+
static TRACE_STACK: RefCell<Option<Arc<TraceNode>>> = const { RefCell::new(None) };
3137
}
3238

3339
/// RAII guard to manage the push and pop of trace information.
3440
///
35-
/// This struct is `!Send` and `!Sync` to prevent it from being held across
36-
/// `.await` points in async tests, which would cause incorrect trace tracking
37-
/// if the task moves between threads.
41+
/// This struct is `Send` to allow it to be used in async tests, but it should
42+
/// not be held across `.await` points directly unless the future is instrumented.
3843
#[doc(hidden)]
3944
pub struct ScopedTraceGuard {
4045
id: usize,
41-
_phantom: std::marker::PhantomData<*mut ()>,
4246
}
4347

4448
impl ScopedTraceGuard {
@@ -47,32 +51,119 @@ impl ScopedTraceGuard {
4751
pub fn new(message: String) -> Self {
4852
let caller = std::panic::Location::caller();
4953
let id = NEXT_TRACE_ID.fetch_add(1, Ordering::Relaxed);
54+
let info = TraceInfo { id, file: caller.file(), line: caller.line(), message };
55+
5056
TRACE_STACK.with(|stack| {
51-
// Use try_borrow_mut to avoid double panic if called during unwinding.
5257
if let Ok(mut s) = stack.try_borrow_mut() {
53-
s.push(TraceInfo { id, file: caller.file(), line: caller.line(), message });
58+
let prev = s.clone();
59+
*s = Some(Arc::new(TraceNode { info, parent: prev }));
5460
}
5561
});
56-
Self { id, _phantom: std::marker::PhantomData }
62+
Self { id }
5763
}
5864
}
5965

6066
impl Drop for ScopedTraceGuard {
6167
fn drop(&mut self) {
6268
TRACE_STACK.with(|stack| {
63-
// Use try_borrow_mut to avoid double panic if called during unwinding.
6469
if let Ok(mut s) = stack.try_borrow_mut() {
65-
if let Some(pos) = s.iter().rposition(|t| t.id == self.id) {
66-
s.remove(pos);
67-
}
70+
*s = remove_from_list(s.clone(), self.id);
6871
}
6972
});
7073
}
7174
}
7275

76+
/// Removes a node with the given `id` from the trace stack list.
77+
///
78+
/// Because the trace stack is structured as a persistent, reverse-linked list
79+
/// pointing to parent nodes, removing a node from the middle involves rebuilding
80+
/// the chain from the removed node up to the current head.
81+
fn remove_from_list(head: Option<Arc<TraceNode>>, id: usize) -> Option<Arc<TraceNode>> {
82+
let head = head?;
83+
if head.info.id == id {
84+
return head.parent.clone();
85+
}
86+
87+
let mut current = head.parent.clone();
88+
let mut nodes_to_recreate = vec![head.info.clone()];
89+
90+
while let Some(node) = current {
91+
if node.info.id == id {
92+
let mut new_head = node.parent.clone();
93+
for info in nodes_to_recreate.into_iter().rev() {
94+
new_head = Some(Arc::new(TraceNode { info, parent: new_head }));
95+
}
96+
return new_head;
97+
}
98+
nodes_to_recreate.push(node.info.clone());
99+
current = node.parent.clone();
100+
}
101+
Some(head)
102+
}
103+
73104
/// Retrieves a clone of the current thread's trace stack.
74105
pub fn get_scoped_traces() -> Vec<TraceInfo> {
75-
TRACE_STACK.with(|stack| stack.try_borrow().map(|s| s.clone()).unwrap_or_default())
106+
let mut traces = Vec::new();
107+
let mut current = TRACE_STACK.with(|stack| stack.try_borrow().ok().and_then(|s| s.clone()));
108+
while let Some(node) = current {
109+
traces.push(node.info.clone());
110+
current = node.parent.clone();
111+
}
112+
traces.reverse();
113+
traces
114+
}
115+
116+
/// A future that instruments another future with a set of traces.
117+
pub struct InstrumentedFuture<F> {
118+
inner: F,
119+
traces: Option<Arc<TraceNode>>,
120+
}
121+
122+
impl<F> InstrumentedFuture<F> {
123+
pub fn new(inner: F) -> Self {
124+
Self { inner, traces: None }
125+
}
126+
}
127+
128+
impl<F: std::future::Future> std::future::Future for InstrumentedFuture<F> {
129+
type Output = F::Output;
130+
131+
fn poll(
132+
self: std::pin::Pin<&mut Self>,
133+
cx: &mut std::task::Context<'_>,
134+
) -> std::task::Poll<Self::Output> {
135+
// SAFETY: `InstrumentedFuture` provides structural pinning. If `self` is
136+
// pinned, it is safe to pin the `inner` future.
137+
let this = unsafe { self.get_unchecked_mut() };
138+
139+
struct TraceSwapGuard<'a>(&'a mut Option<Arc<TraceNode>>);
140+
impl<'a> TraceSwapGuard<'a> {
141+
fn new(traces: &'a mut Option<Arc<TraceNode>>) -> Self {
142+
TRACE_STACK.with(|stack| {
143+
if let Ok(mut s) = stack.try_borrow_mut() {
144+
std::mem::swap(&mut *s, traces);
145+
}
146+
});
147+
Self(traces)
148+
}
149+
}
150+
impl<'a> Drop for TraceSwapGuard<'a> {
151+
fn drop(&mut self) {
152+
TRACE_STACK.with(|stack| {
153+
if let Ok(mut s) = stack.try_borrow_mut() {
154+
std::mem::swap(&mut *s, self.0);
155+
}
156+
});
157+
}
158+
}
159+
160+
// Swap traces into thread-local, ensuring they are swapped back on return/panic.
161+
let _guard = TraceSwapGuard::new(&mut this.traces);
162+
163+
// Poll inner future
164+
// SAFETY: As explained above, `this.inner` is properly pinned.
165+
unsafe { std::pin::Pin::new_unchecked(&mut this.inner) }.poll(cx)
166+
}
76167
}
77168

78169
// Test-only state and helpers, hidden from production API.

googletest/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ macro_rules! scoped_trace {
8181
pub use googletest_macro::gtest;
8282
pub use googletest_macro::test;
8383

84+
pub use internal::scoped_trace::InstrumentedFuture;
8485
use internal::test_outcome::{TestAssertionFailure, TestOutcome};
8586

8687
/// A `Result` whose `Err` variant indicates a test failure.

googletest_macro/src/lib.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,14 @@ pub fn gtest(
136136
let mut invocation = quote! {
137137
(#maybe_async move || -> #result_type {
138138
#block
139-
})() #maybe_await
139+
})()
140140
};
141141
if output_type.is_none() {
142-
invocation = quote! { { let () = #invocation; googletest::Result::<()>::Ok(()) } };
142+
if is_async {
143+
invocation = quote! { async move { let () = #invocation.await; googletest::Result::<()>::Ok(()) } };
144+
} else {
145+
invocation = quote! { { let () = #invocation; googletest::Result::<()>::Ok(()) } };
146+
}
143147
}
144148
invocation
145149
} else {
@@ -155,9 +159,19 @@ pub fn gtest(
155159
(#maybe_async move || -> #result_type {
156160
#sig { #block }
157161
#closure_body
158-
})() #maybe_await
162+
})()
159163
}
160164
};
165+
166+
let invocation = if is_async {
167+
quote! {
168+
::googletest::internal::scoped_trace::InstrumentedFuture::new(
169+
#invocation
170+
).await
171+
}
172+
} else {
173+
quote! { #invocation }
174+
};
161175
if !attrs.iter().any(is_test_attribute) && !is_rstest_enabled {
162176
let test_attr: Attribute = parse_quote! {
163177
#[::core::prelude::v1::test]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Google Test trace:
2+
async_scoped_trace_panic_test.rs:4: Sync trace message
3+
async_scoped_trace_panic_test.rs:14: Outer async trace message
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#!/bin/bash
2+
BINARY_PATH="$1"
3+
4+
if [ -z "$BINARY_PATH" ]; then
5+
echo "Usage: $0 <binary_path>"
6+
exit 1
7+
fi
8+
9+
# Run the binary and capture output.
10+
set +e
11+
OUTPUT=$("${BINARY_PATH}" 2>&1)
12+
STATUS=$?
13+
set -e
14+
15+
# 101 is the standard Rust panic exit code.
16+
if [[ ${STATUS} -ne 101 ]]; then
17+
echo "Expected panic (exit code 101), but script exited with ${STATUS}."
18+
echo "Output:"
19+
echo "${OUTPUT}"
20+
exit 1
21+
fi
22+
23+
echo "${OUTPUT}" | \
24+
sed -n -e '/^Google Test trace:/p' -e '/^ .*trace/p' | \
25+
sed 's#^.*async_scoped_trace_panic_test.rs# async_scoped_trace_panic_test.rs#g'
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
use googletest::prelude::*;
2+
use googletest::InstrumentedFuture;
3+
fn sync_function_with_trace() {
4+
scoped_trace!("Sync trace message");
5+
panic!("Intentional panic in sync function");
6+
}
7+
8+
#[tokio::main]
9+
async fn main() {
10+
// Initialize Google Test to install panic hook
11+
googletest::internal::test_outcome::TestOutcome::init_current_test_outcome();
12+
13+
let fut = async {
14+
scoped_trace!("Outer async trace message");
15+
16+
// Yield to verify async traces are preserved across yields
17+
tokio::task::yield_now().await;
18+
19+
// Call sync function that adds a trace and panics
20+
sync_function_with_trace();
21+
};
22+
23+
// Wrap in InstrumentedFuture to preserve traces across yields
24+
let fut = InstrumentedFuture::new(fut);
25+
26+
fut.await;
27+
}

integration_tests/src/async_test_with_expect_that.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,54 @@ mod tests {
7575
f.0 = 6;
7676
Ok(())
7777
}
78+
79+
#[gtest]
80+
#[tokio::test]
81+
async fn async_test_with_scoped_trace() -> Result<()> {
82+
scoped_trace!("Outer trace");
83+
brief_sleep().await;
84+
{
85+
scoped_trace!("Inner trace");
86+
brief_sleep().await;
87+
let traces = googletest::internal::scoped_trace::get_scoped_traces();
88+
verify_eq!(traces.len(), 2)?;
89+
verify_eq!(traces[0].message, "Outer trace")?;
90+
verify_eq!(traces[1].message, "Inner trace")?;
91+
}
92+
brief_sleep().await;
93+
let traces = googletest::internal::scoped_trace::get_scoped_traces();
94+
verify_eq!(traces.len(), 1)?;
95+
verify_eq!(traces[0].message, "Outer trace")?;
96+
Ok(())
97+
}
98+
99+
fn sync_subroutine_with_trace() {
100+
scoped_trace!("Sync subroutine trace");
101+
let traces = googletest::internal::scoped_trace::get_scoped_traces();
102+
assert!(traces.iter().any(|t| t.message == "Sync subroutine trace"));
103+
}
104+
105+
#[gtest]
106+
#[tokio::test]
107+
async fn async_test_mixed_scoped_trace() -> Result<()> {
108+
scoped_trace!("Outer async trace");
109+
brief_sleep().await;
110+
111+
sync_subroutine_with_trace();
112+
113+
{
114+
scoped_trace!("Inner async trace");
115+
brief_sleep().await;
116+
let traces = googletest::internal::scoped_trace::get_scoped_traces();
117+
verify_eq!(traces.len(), 2)?;
118+
verify_eq!(traces[0].message, "Outer async trace")?;
119+
verify_eq!(traces[1].message, "Inner async trace")?;
120+
}
121+
122+
brief_sleep().await;
123+
let traces = googletest::internal::scoped_trace::get_scoped_traces();
124+
verify_eq!(traces.len(), 1)?;
125+
verify_eq!(traces[0].message, "Outer async trace")?;
126+
Ok(())
127+
}
78128
}

0 commit comments

Comments
 (0)