diff --git a/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/main/java/datadog/trace/instrumentation/java/lang/jdk21/VirtualThreadInstrumentation.java b/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/main/java/datadog/trace/instrumentation/java/lang/jdk21/VirtualThreadInstrumentation.java index 1293e7c2398..c55f24fdf35 100644 --- a/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/main/java/datadog/trace/instrumentation/java/lang/jdk21/VirtualThreadInstrumentation.java +++ b/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/main/java/datadog/trace/instrumentation/java/lang/jdk21/VirtualThreadInstrumentation.java @@ -1,9 +1,11 @@ package datadog.trace.instrumentation.java.lang.jdk21; import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.named; -import static datadog.trace.bootstrap.instrumentation.java.concurrent.AdviceUtils.capture; -import static datadog.trace.bootstrap.instrumentation.java.concurrent.AdviceUtils.endTaskScope; -import static datadog.trace.bootstrap.instrumentation.java.concurrent.AdviceUtils.startTaskScope; +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.namedOneOf; +import static datadog.trace.bootstrap.instrumentation.api.AgentTracer.activeSpan; +import static datadog.trace.bootstrap.instrumentation.java.concurrent.ConcurrentState.activateAndContinueContinuation; +import static datadog.trace.bootstrap.instrumentation.java.concurrent.ConcurrentState.captureContinuation; +import static datadog.trace.bootstrap.instrumentation.java.concurrent.ConcurrentState.closeScope; import static datadog.trace.bootstrap.instrumentation.java.lang.VirtualThreadHelper.AGENT_SCOPE_CLASS_NAME; import static datadog.trace.bootstrap.instrumentation.java.lang.VirtualThreadHelper.VIRTUAL_THREAD_CLASS_NAME; import static net.bytebuddy.matcher.ElementMatchers.isConstructor; @@ -16,7 +18,7 @@ import datadog.trace.bootstrap.ContextStore; import datadog.trace.bootstrap.InstrumentationContext; import datadog.trace.bootstrap.instrumentation.api.AgentScope; -import datadog.trace.bootstrap.instrumentation.java.concurrent.State; +import datadog.trace.bootstrap.instrumentation.java.concurrent.ConcurrentState; import java.util.HashMap; import java.util.Map; import net.bytebuddy.asm.Advice; @@ -24,14 +26,32 @@ import net.bytebuddy.asm.Advice.OnMethodExit; /** - * Instruments {@code VirtualThread} to capture active state at creation, activate it on - * continuation mount, and close the scope from activation on continuation unmount. + * Instruments {@code VirtualThread} to capture active state at creation, activate it on mount, + * close the scope on unmount, and cancel the continuation on thread termination. + * + *

The lifecycle is as follows: + * + *

    + *
  1. {@code init()}: captures and holds a continuation from the active context (span due to + * legacy API). + *
  2. {@code mount()}: activates the held continuation, restoring the context on the current + * carrier thread. + *
  3. {@code unmount()}: closes the scope. The continuation survives as still hold. + *
  4. Steps 2-3 repeat on each park/unpark cycle, potentially on different carrier threads. + *
  5. {@code afterTerminate()} (for early versions of JDK 21 and 22 before GA), {@code afterDone} + * (for JDK 21 GA above): cancels the held continuation to let the context scope to be closed. + *
* *

The instrumentation uses two context stores. The first from {@link Runnable} (as {@code - * VirtualThread} inherits from {@link Runnable}) to store the captured {@link State} to restore - * later. It additionally stores the {@link AgentScope} to be able to close it later as activation / - * close is not done around the same method (so passing the scope from {@link OnMethodEnter} / - * {@link OnMethodExit} using advice return value is not possible). + * VirtualThread} inherits from {@link Runnable}) to store the captured {@link ConcurrentState} to + * restore later. It additionally stores the {@link AgentScope} to be able to close it later as + * activation / close is not done around the same method (so passing the scope from {@link + * OnMethodEnter} / {@link OnMethodExit} using advice return value is not possible). + * + *

{@link ConcurrentState} is used instead of {@code State} because virtual threads can mount and + * unmount multiple times across different carrier threads. The held continuation in {@link + * ConcurrentState} survives multiple activate/close cycles without being consumed, and is + * explicitly canceled on thread termination. * *

Instrumenting the internal {@code VirtualThread.runContinuation()} method does not work as the * current thread is still the carrier thread and not a virtual thread. Activating the state when on @@ -62,7 +82,7 @@ public boolean isEnabled() { @Override public Map contextStore() { Map contextStore = new HashMap<>(); - contextStore.put(Runnable.class.getName(), State.class.getName()); + contextStore.put(Runnable.class.getName(), ConcurrentState.class.getName()); contextStore.put(VIRTUAL_THREAD_CLASS_NAME, AGENT_SCOPE_CLASS_NAME); return contextStore; } @@ -72,36 +92,54 @@ public void methodAdvice(MethodTransformer transformer) { transformer.applyAdvice(isConstructor(), getClass().getName() + "$Construct"); transformer.applyAdvice(isMethod().and(named("mount")), getClass().getName() + "$Activate"); transformer.applyAdvice(isMethod().and(named("unmount")), getClass().getName() + "$Close"); + transformer.applyAdvice( + isMethod().and(namedOneOf("afterTerminate", "afterDone")), + getClass().getName() + "$Terminate"); } public static final class Construct { @OnMethodExit(suppress = Throwable.class) public static void captureScope(@Advice.This Object virtualThread) { - capture(InstrumentationContext.get(Runnable.class, State.class), (Runnable) virtualThread); + captureContinuation( + InstrumentationContext.get(Runnable.class, ConcurrentState.class), + (Runnable) virtualThread, + activeSpan()); } } public static final class Activate { @OnMethodExit(suppress = Throwable.class) public static void activate(@Advice.This Object virtualThread) { - ContextStore stateStore = - InstrumentationContext.get(Runnable.class, State.class); - ContextStore scopeStore = + AgentScope scope = + activateAndContinueContinuation( + InstrumentationContext.get(Runnable.class, ConcurrentState.class), + (Runnable) virtualThread); + ContextStore scopeStore = InstrumentationContext.get(VIRTUAL_THREAD_CLASS_NAME, AGENT_SCOPE_CLASS_NAME); - AgentScope agentScope = startTaskScope(stateStore, (Runnable) virtualThread); - scopeStore.put(virtualThread, agentScope); + scopeStore.put(virtualThread, scope); } } public static final class Close { @OnMethodEnter(suppress = Throwable.class) public static void close(@Advice.This Object virtualThread) { - ContextStore scopeStore = + ContextStore scopeStore = InstrumentationContext.get(VIRTUAL_THREAD_CLASS_NAME, AGENT_SCOPE_CLASS_NAME); - Object agentScope = scopeStore.get(virtualThread); - if (agentScope instanceof AgentScope) { - endTaskScope((AgentScope) agentScope); - } + AgentScope scope = scopeStore.remove(virtualThread); + closeScope( + InstrumentationContext.get(Runnable.class, ConcurrentState.class), + (Runnable) virtualThread, + scope, + null); + } + } + + public static final class Terminate { + @OnMethodEnter(suppress = Throwable.class) + public static void terminate(@Advice.This Object virtualThread) { + ConcurrentState.cancelAndClearContinuation( + InstrumentationContext.get(Runnable.class, ConcurrentState.class), + (Runnable) virtualThread); } } } diff --git a/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/test/java/testdog/trace/instrumentation/java/lang/jdk21/VirtualThreadApiInstrumentationTest.java b/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/test/java/testdog/trace/instrumentation/java/lang/jdk21/VirtualThreadApiInstrumentationTest.java index b6359e826d0..7a067ca6825 100644 --- a/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/test/java/testdog/trace/instrumentation/java/lang/jdk21/VirtualThreadApiInstrumentationTest.java +++ b/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/test/java/testdog/trace/instrumentation/java/lang/jdk21/VirtualThreadApiInstrumentationTest.java @@ -11,6 +11,7 @@ import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; +/** Test the {@code VirtualThread} and {@code Thread.Builder} API. */ public class VirtualThreadApiInstrumentationTest extends AbstractInstrumentationTest { @DisplayName("test Thread.Builder.OfVirtual.start()") diff --git a/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/test/java/testdog/trace/instrumentation/java/lang/jdk21/VirtualThreadLifeCycleTest.java b/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/test/java/testdog/trace/instrumentation/java/lang/jdk21/VirtualThreadLifeCycleTest.java new file mode 100644 index 00000000000..8a5b9008565 --- /dev/null +++ b/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/test/java/testdog/trace/instrumentation/java/lang/jdk21/VirtualThreadLifeCycleTest.java @@ -0,0 +1,174 @@ +package testdog.trace.instrumentation.java.lang.jdk21; + +import static datadog.trace.agent.test.assertions.SpanMatcher.span; +import static datadog.trace.agent.test.assertions.TraceMatcher.trace; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import datadog.trace.agent.test.AbstractInstrumentationTest; +import datadog.trace.api.CorrelationIdentifier; +import datadog.trace.api.GlobalTracer; +import datadog.trace.api.Trace; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +/** Test context tracking through {@code VirtualThread} lifecycle - park/unpark (remount) cycles. */ +public class VirtualThreadLifeCycleTest extends AbstractInstrumentationTest { + private static final Duration TIMEOUT = Duration.ofSeconds(10); + + @DisplayName("test context restored after virtual thread remounts") + @Test + void testContextRestoredAfterVirtualThreadRemount() { + int remountCount = 5; + String[] spanId = new String[1]; + String[] spanIdBeforeUnmount = new String[1]; + String[] spanIdsAfterRemount = new String[remountCount]; + + new Runnable() { + @Override + @Trace(operationName = "parent") + public void run() { + spanId[0] = GlobalTracer.get().getSpanId(); + + Thread thread = + Thread.startVirtualThread( + () -> { + spanIdBeforeUnmount[0] = GlobalTracer.get().getSpanId(); + for (int remount = 0; remount < remountCount; remount++) { + tryUnmount(); + spanIdsAfterRemount[remount] = GlobalTracer.get().getSpanId(); + } + }); + try { + thread.join(TIMEOUT); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }.run(); + + assertEquals( + spanId[0], + spanIdBeforeUnmount[0], + "context should be inherited from the parent execution unit"); + for (int i = 0; i < remountCount; i++) { + assertEquals( + spanId[0], + spanIdsAfterRemount[i], + "context should be restored after virtual thread remounts"); + } + + assertTraces(trace(span().root().operationName("parent"))); + } + + @DisplayName("test context restored as implicit parent span after remount") + @Test + void testContextRestoredAsImplicitParentSpanAfterRemount() { + new Runnable() { + @Override + @Trace(operationName = "parent") + public void run() { + Thread thread = + Thread.startVirtualThread( + () -> { + tryUnmount(); + // Runnable to create child span, not async related + new Runnable() { + @Override + @Trace(operationName = "child") + public void run() {} + }.run(); + }); + try { + thread.join(TIMEOUT); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + blockUntilChildSpansFinished(1); + } + }.run(); + + assertTraces( + trace( + span().root().operationName("parent"), + span().childOfPrevious().operationName("child"))); + } + + @DisplayName("test concurrent virtual threads with remount") + @Test + void testConcurrentVirtualThreadsWithRemount() { + int threadCount = 5; + String[] spanId = new String[1]; + String[] spanIdsAfterRemount = new String[threadCount]; + + new Runnable() { + @Override + @Trace(operationName = "parent") + public void run() { + spanId[0] = CorrelationIdentifier.getSpanId(); + + List threads = new ArrayList<>(); + for (int i = 0; i < threadCount; i++) { + int index = i; + threads.add( + Thread.startVirtualThread( + () -> { + tryUnmount(); + spanIdsAfterRemount[index] = CorrelationIdentifier.getSpanId(); + })); + } + + for (Thread thread : threads) { + try { + thread.join(TIMEOUT); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + }.run(); + + for (int i = 0; i < threadCount; i++) { + assertEquals( + spanId[0], + spanIdsAfterRemount[i], + "context should be restored after virtual thread #" + i + "remounts"); + } + + assertTraces(trace(span().root().operationName("parent"))); + } + + @DisplayName("test no context virtual thread remount") + @Test + void testNoContextVirtualThreadRemount() throws InterruptedException { + AtomicReference spanIdBeforeUnmount = new AtomicReference<>(); + AtomicReference spanIdAfterRemount = new AtomicReference<>(); + + Thread.startVirtualThread( + () -> { + spanIdBeforeUnmount.set(CorrelationIdentifier.getSpanId()); + tryUnmount(); + spanIdAfterRemount.set(CorrelationIdentifier.getSpanId()); + }) + .join(TIMEOUT); + + assertEquals( + "0", spanIdBeforeUnmount.get(), "there should be no active context before unmount"); + assertEquals("0", spanIdAfterRemount.get(), "there should be no active context after remount"); + } + + private static void tryUnmount() { + try { + // Multiple sleeps to expect triggering repeated park/unpark cycles. + // This is not guaranteed to work, but there is no API to force mount/unmount. + for (int i = 0; i < 5; i++) { + Thread.sleep(10); + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } +}