Skip to content

Commit bfdd5b3

Browse files
temporal-spring-ai: add side-effect replay tests for chat, activity tools, and @SideEffectTool
Three new tests under src/test/.../replay/: - ChatModelSideEffectTest: register a ChatModel with an AtomicInteger counter. Run a workflow that makes one chat call, assert counter=1. Replay the captured history, assert counter still 1 — the activity result comes from history, not from re-invoking the ChatModel. - ActivityToolSideEffectTest: activity-backed @tool whose impl increments a counter. ToolCallingStubChatModel asks for the tool on the first call and returns final text on the second. Same assertion shape: counter=1 after run, counter=1 after replay. - SideEffectToolReplayTest: @SideEffectTool body increments a counter via a file-scope static. Workflow drives a tool call through ToolCallingStubChatModel. The assertion proves that Workflow.sideEffect's marker is what's consulted on replay rather than re-invoking the @tool method. MCP is intentionally omitted — spring-ai-mcp is compileOnly and adding it just for one test isn't worth the dep weight. MCP tool calls go through the same Temporal activity machinery as ChatModel, which ChatModelSideEffectTest already covers. I verified the SideEffectToolReplayTest catches a real regression by temporarily dropping the Workflow.sideEffect wrap in SideEffectToolCallback; the test correctly failed with `expected: <1> but was: <2>`. Restored before this commit. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 7427cfc commit bfdd5b3

3 files changed

Lines changed: 405 additions & 0 deletions

File tree

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
package io.temporal.springai.replay;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
5+
import io.temporal.activity.ActivityInterface;
6+
import io.temporal.activity.ActivityMethod;
7+
import io.temporal.activity.ActivityOptions;
8+
import io.temporal.client.WorkflowClient;
9+
import io.temporal.client.WorkflowOptions;
10+
import io.temporal.client.WorkflowStub;
11+
import io.temporal.common.WorkflowExecutionHistory;
12+
import io.temporal.springai.activity.ChatModelActivityImpl;
13+
import io.temporal.springai.chat.TemporalChatClient;
14+
import io.temporal.springai.model.ActivityChatModel;
15+
import io.temporal.testing.TestWorkflowEnvironment;
16+
import io.temporal.testing.WorkflowReplayer;
17+
import io.temporal.worker.Worker;
18+
import io.temporal.workflow.Workflow;
19+
import io.temporal.workflow.WorkflowInterface;
20+
import io.temporal.workflow.WorkflowMethod;
21+
import java.time.Duration;
22+
import java.util.List;
23+
import java.util.concurrent.atomic.AtomicInteger;
24+
import org.junit.jupiter.api.AfterEach;
25+
import org.junit.jupiter.api.BeforeEach;
26+
import org.junit.jupiter.api.Test;
27+
import org.springframework.ai.chat.client.ChatClient;
28+
import org.springframework.ai.chat.messages.AssistantMessage;
29+
import org.springframework.ai.chat.model.ChatModel;
30+
import org.springframework.ai.chat.model.ChatResponse;
31+
import org.springframework.ai.chat.model.Generation;
32+
import org.springframework.ai.chat.prompt.Prompt;
33+
import org.springframework.ai.tool.annotation.Tool;
34+
35+
/**
36+
* Asserts that a workflow replay does not re-invoke activity-backed tools. The {@link AddActivity}
37+
* impl holds a counter that increments on each tool call. After the initial run the counter is 1;
38+
* after replaying the captured history, it must still be 1 — activity results for the tool call
39+
* come from history, not from re-invoking the activity impl.
40+
*/
41+
class ActivityToolSideEffectTest {
42+
43+
private static final String TASK_QUEUE = "test-spring-ai-activity-tool-side-effect";
44+
45+
private TestWorkflowEnvironment testEnv;
46+
private WorkflowClient client;
47+
private AddActivityImpl addActivity;
48+
49+
@BeforeEach
50+
void setUp() {
51+
testEnv = TestWorkflowEnvironment.newInstance();
52+
client = testEnv.getWorkflowClient();
53+
addActivity = new AddActivityImpl();
54+
}
55+
56+
@AfterEach
57+
void tearDown() {
58+
testEnv.close();
59+
}
60+
61+
@Test
62+
void activityTool_notReInvokedOnReplay() throws Exception {
63+
Worker worker = testEnv.newWorker(TASK_QUEUE);
64+
worker.registerWorkflowImplementationTypes(ChatWithToolsWorkflowImpl.class);
65+
worker.registerActivitiesImplementations(
66+
new ChatModelActivityImpl(new ToolCallingStubChatModel()), addActivity);
67+
testEnv.start();
68+
69+
ChatWithToolsWorkflow workflow =
70+
client.newWorkflowStub(
71+
ChatWithToolsWorkflow.class,
72+
WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build());
73+
assertEquals("The answer is 5", workflow.chat("What is 2+3?"));
74+
assertEquals(
75+
1, addActivity.callCount.get(), "Tool activity should run once during the initial run");
76+
77+
WorkflowExecutionHistory history =
78+
client.fetchHistory(WorkflowStub.fromTyped(workflow).getExecution().getWorkflowId());
79+
WorkflowReplayer.replayWorkflowExecution(history, ChatWithToolsWorkflowImpl.class);
80+
81+
assertEquals(
82+
1,
83+
addActivity.callCount.get(),
84+
"Tool activity must not be re-invoked during replay — results come from history");
85+
}
86+
87+
@WorkflowInterface
88+
public interface ChatWithToolsWorkflow {
89+
@WorkflowMethod
90+
String chat(String message);
91+
}
92+
93+
@ActivityInterface
94+
public interface AddActivity {
95+
@Tool(description = "Add two numbers")
96+
@ActivityMethod
97+
int add(int a, int b);
98+
}
99+
100+
public static class AddActivityImpl implements AddActivity {
101+
final AtomicInteger callCount = new AtomicInteger(0);
102+
103+
@Override
104+
public int add(int a, int b) {
105+
callCount.incrementAndGet();
106+
return a + b;
107+
}
108+
}
109+
110+
public static class ChatWithToolsWorkflowImpl implements ChatWithToolsWorkflow {
111+
@Override
112+
public String chat(String message) {
113+
ActivityChatModel chatModel = ActivityChatModel.forDefault();
114+
AddActivity addTool =
115+
Workflow.newActivityStub(
116+
AddActivity.class,
117+
ActivityOptions.newBuilder().setStartToCloseTimeout(Duration.ofSeconds(30)).build());
118+
ChatClient chatClient = TemporalChatClient.builder(chatModel).defaultTools(addTool).build();
119+
return chatClient.prompt().user(message).call().content();
120+
}
121+
}
122+
123+
/** First call: request the "add" tool. Second call: return final text. */
124+
private static class ToolCallingStubChatModel implements ChatModel {
125+
private final AtomicInteger callCount = new AtomicInteger(0);
126+
127+
@Override
128+
public ChatResponse call(Prompt prompt) {
129+
if (callCount.getAndIncrement() == 0) {
130+
AssistantMessage toolRequest =
131+
AssistantMessage.builder()
132+
.content("")
133+
.toolCalls(
134+
List.of(
135+
new AssistantMessage.ToolCall(
136+
"call_1", "function", "add", "{\"a\":2,\"b\":3}")))
137+
.build();
138+
return ChatResponse.builder().generations(List.of(new Generation(toolRequest))).build();
139+
}
140+
return ChatResponse.builder()
141+
.generations(List.of(new Generation(new AssistantMessage("The answer is 5"))))
142+
.build();
143+
}
144+
145+
@Override
146+
public reactor.core.publisher.Flux<ChatResponse> stream(Prompt prompt) {
147+
throw new UnsupportedOperationException();
148+
}
149+
}
150+
}
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package io.temporal.springai.replay;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
5+
import io.temporal.client.WorkflowClient;
6+
import io.temporal.client.WorkflowOptions;
7+
import io.temporal.client.WorkflowStub;
8+
import io.temporal.common.WorkflowExecutionHistory;
9+
import io.temporal.springai.activity.ChatModelActivityImpl;
10+
import io.temporal.springai.model.ActivityChatModel;
11+
import io.temporal.testing.TestWorkflowEnvironment;
12+
import io.temporal.testing.WorkflowReplayer;
13+
import io.temporal.worker.Worker;
14+
import io.temporal.workflow.WorkflowInterface;
15+
import io.temporal.workflow.WorkflowMethod;
16+
import java.util.List;
17+
import java.util.concurrent.atomic.AtomicInteger;
18+
import org.junit.jupiter.api.AfterEach;
19+
import org.junit.jupiter.api.BeforeEach;
20+
import org.junit.jupiter.api.Test;
21+
import org.springframework.ai.chat.client.ChatClient;
22+
import org.springframework.ai.chat.messages.AssistantMessage;
23+
import org.springframework.ai.chat.model.ChatModel;
24+
import org.springframework.ai.chat.model.ChatResponse;
25+
import org.springframework.ai.chat.model.Generation;
26+
import org.springframework.ai.chat.prompt.Prompt;
27+
28+
/**
29+
* Asserts that a workflow replay does not re-invoke the underlying {@link ChatModel}. The counter
30+
* lives on the activity's backing ChatModel, which is only reached when the {@code CallChatModel}
31+
* activity is scheduled by the workflow. On replay, the activity result is fetched from history;
32+
* the impl is not re-invoked. If we ever regress by dropping that guarantee — say by adding an
33+
* in-workflow cache that falls back to invoking the model directly — the counter will advance to 2
34+
* and this test will fail.
35+
*/
36+
class ChatModelSideEffectTest {
37+
38+
private static final String TASK_QUEUE = "test-spring-ai-chat-side-effect";
39+
40+
private TestWorkflowEnvironment testEnv;
41+
private WorkflowClient client;
42+
private CountingChatModel model;
43+
44+
@BeforeEach
45+
void setUp() {
46+
testEnv = TestWorkflowEnvironment.newInstance();
47+
client = testEnv.getWorkflowClient();
48+
model = new CountingChatModel("pong");
49+
}
50+
51+
@AfterEach
52+
void tearDown() {
53+
testEnv.close();
54+
}
55+
56+
@Test
57+
void chatModel_notReInvokedOnReplay() throws Exception {
58+
Worker worker = testEnv.newWorker(TASK_QUEUE);
59+
worker.registerWorkflowImplementationTypes(ChatWorkflowImpl.class);
60+
worker.registerActivitiesImplementations(new ChatModelActivityImpl(model));
61+
testEnv.start();
62+
63+
ChatWorkflow workflow =
64+
client.newWorkflowStub(
65+
ChatWorkflow.class, WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build());
66+
assertEquals("pong", workflow.chat("ping"));
67+
assertEquals(
68+
1, model.callCount.get(), "ChatModel should be called once during the initial run");
69+
70+
WorkflowExecutionHistory history =
71+
client.fetchHistory(WorkflowStub.fromTyped(workflow).getExecution().getWorkflowId());
72+
WorkflowReplayer.replayWorkflowExecution(history, ChatWorkflowImpl.class);
73+
74+
assertEquals(
75+
1,
76+
model.callCount.get(),
77+
"ChatModel must not be re-invoked during replay — activity results come from history");
78+
}
79+
80+
@WorkflowInterface
81+
public interface ChatWorkflow {
82+
@WorkflowMethod
83+
String chat(String message);
84+
}
85+
86+
public static class ChatWorkflowImpl implements ChatWorkflow {
87+
@Override
88+
public String chat(String message) {
89+
ActivityChatModel chatModel = ActivityChatModel.forDefault();
90+
ChatClient chatClient = ChatClient.builder(chatModel).build();
91+
return chatClient.prompt().user(message).call().content();
92+
}
93+
}
94+
95+
private static class CountingChatModel implements ChatModel {
96+
final AtomicInteger callCount = new AtomicInteger(0);
97+
private final String response;
98+
99+
CountingChatModel(String response) {
100+
this.response = response;
101+
}
102+
103+
@Override
104+
public ChatResponse call(Prompt prompt) {
105+
callCount.incrementAndGet();
106+
return ChatResponse.builder()
107+
.generations(List.of(new Generation(new AssistantMessage(response))))
108+
.build();
109+
}
110+
111+
@Override
112+
public reactor.core.publisher.Flux<ChatResponse> stream(Prompt prompt) {
113+
throw new UnsupportedOperationException();
114+
}
115+
}
116+
}

0 commit comments

Comments
 (0)