|
| 1 | +package io.temporal.springai; |
| 2 | + |
| 3 | +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; |
| 4 | +import static org.junit.jupiter.api.Assertions.assertEquals; |
| 5 | +import static org.junit.jupiter.api.Assertions.assertThrows; |
| 6 | +import static org.junit.jupiter.api.Assertions.assertTrue; |
| 7 | + |
| 8 | +import io.temporal.client.WorkflowClient; |
| 9 | +import io.temporal.client.WorkflowException; |
| 10 | +import io.temporal.client.WorkflowOptions; |
| 11 | +import io.temporal.failure.ApplicationFailure; |
| 12 | +import io.temporal.springai.activity.ChatModelActivityImpl; |
| 13 | +import io.temporal.springai.model.ActivityChatModel; |
| 14 | +import io.temporal.springai.model.ChatModelTypes; |
| 15 | +import io.temporal.testing.TestWorkflowEnvironment; |
| 16 | +import io.temporal.worker.Worker; |
| 17 | +import io.temporal.workflow.WorkflowInterface; |
| 18 | +import io.temporal.workflow.WorkflowMethod; |
| 19 | +import java.net.URI; |
| 20 | +import java.util.List; |
| 21 | +import org.junit.jupiter.api.AfterEach; |
| 22 | +import org.junit.jupiter.api.BeforeEach; |
| 23 | +import org.junit.jupiter.api.Test; |
| 24 | +import org.springframework.ai.chat.messages.AssistantMessage; |
| 25 | +import org.springframework.ai.chat.messages.UserMessage; |
| 26 | +import org.springframework.ai.chat.model.ChatModel; |
| 27 | +import org.springframework.ai.chat.model.ChatResponse; |
| 28 | +import org.springframework.ai.chat.model.Generation; |
| 29 | +import org.springframework.ai.chat.prompt.Prompt; |
| 30 | +import org.springframework.ai.content.Media; |
| 31 | +import org.springframework.core.io.ByteArrayResource; |
| 32 | +import org.springframework.util.MimeType; |
| 33 | +import org.springframework.util.MimeTypeUtils; |
| 34 | + |
| 35 | +/** |
| 36 | + * Unit tests around {@link ChatModelTypes#checkMediaSize(byte[])} plus integration-style tests |
| 37 | + * against a live TestWorkflowEnvironment to make sure the guard fires on both the inbound (workflow |
| 38 | + * → activity) and outbound (activity → workflow) conversion paths. |
| 39 | + */ |
| 40 | +class MediaSizeGuardTest { |
| 41 | + |
| 42 | + private static final String TASK_QUEUE = "test-spring-ai-media-size-guard"; |
| 43 | + |
| 44 | + private TestWorkflowEnvironment testEnv; |
| 45 | + private WorkflowClient client; |
| 46 | + |
| 47 | + @BeforeEach |
| 48 | + void setUp() { |
| 49 | + testEnv = TestWorkflowEnvironment.newInstance(); |
| 50 | + client = testEnv.getWorkflowClient(); |
| 51 | + } |
| 52 | + |
| 53 | + @AfterEach |
| 54 | + void tearDown() { |
| 55 | + testEnv.close(); |
| 56 | + } |
| 57 | + |
| 58 | + @Test |
| 59 | + void checkMediaSize_smallPayload_passes() { |
| 60 | + byte[] small = new byte[500 * 1024]; // 500 KiB, well under 1 MiB |
| 61 | + assertDoesNotThrow(() -> ChatModelTypes.checkMediaSize(small)); |
| 62 | + } |
| 63 | + |
| 64 | + @Test |
| 65 | + void checkMediaSize_oversizedPayload_throwsNonRetryableApplicationFailure() { |
| 66 | + byte[] big = new byte[(int) ChatModelTypes.MAX_MEDIA_BYTES_IN_HISTORY + 1]; |
| 67 | + ApplicationFailure ex = |
| 68 | + assertThrows(ApplicationFailure.class, () -> ChatModelTypes.checkMediaSize(big)); |
| 69 | + assertTrue(ex.isNonRetryable(), "guard must throw a non-retryable ApplicationFailure"); |
| 70 | + assertEquals(ChatModelTypes.MEDIA_SIZE_EXCEEDED_FAILURE_TYPE, ex.getType()); |
| 71 | + String msg = ex.getOriginalMessage(); |
| 72 | + assertTrue(msg.contains("URI"), "message should point at the URI alternative: " + msg); |
| 73 | + assertTrue( |
| 74 | + msg.contains("io.temporal.springai.maxMediaBytes"), |
| 75 | + "message should mention the override system property: " + msg); |
| 76 | + } |
| 77 | + |
| 78 | + @Test |
| 79 | + void checkMediaSize_null_passes() { |
| 80 | + assertDoesNotThrow(() -> ChatModelTypes.checkMediaSize(null)); |
| 81 | + } |
| 82 | + |
| 83 | + @Test |
| 84 | + void inboundPath_oversizedUserMessageMedia_failsTheWorkflow() { |
| 85 | + // Workflow → activity direction: the workflow builds a Prompt with a huge byte[] media, |
| 86 | + // ActivityChatModel.createActivityInput calls toMediaContent → checkMediaSize throws. |
| 87 | + Worker worker = testEnv.newWorker(TASK_QUEUE); |
| 88 | + worker.registerWorkflowImplementationTypes(BigInboundMediaWorkflowImpl.class); |
| 89 | + worker.registerActivitiesImplementations(new ChatModelActivityImpl(new StubChatModel())); |
| 90 | + testEnv.start(); |
| 91 | + |
| 92 | + ChatWorkflow workflow = |
| 93 | + client.newWorkflowStub( |
| 94 | + ChatWorkflow.class, WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build()); |
| 95 | + WorkflowException ex = assertThrows(WorkflowException.class, () -> workflow.chat("hi")); |
| 96 | + String message = rootMessage(ex); |
| 97 | + assertTrue( |
| 98 | + message.contains(ChatModelTypes.MEDIA_SIZE_EXCEEDED_FAILURE_TYPE) |
| 99 | + || message.contains("-byte limit"), |
| 100 | + "expected size-guard failure, got: " + message); |
| 101 | + } |
| 102 | + |
| 103 | + @Test |
| 104 | + void inboundPath_smallMedia_passes() { |
| 105 | + Worker worker = testEnv.newWorker(TASK_QUEUE); |
| 106 | + worker.registerWorkflowImplementationTypes(SmallInboundMediaWorkflowImpl.class); |
| 107 | + worker.registerActivitiesImplementations(new ChatModelActivityImpl(new StubChatModel())); |
| 108 | + testEnv.start(); |
| 109 | + |
| 110 | + ChatWorkflow workflow = |
| 111 | + client.newWorkflowStub( |
| 112 | + ChatWorkflow.class, WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build()); |
| 113 | + assertEquals("pong", workflow.chat("hi")); |
| 114 | + } |
| 115 | + |
| 116 | + @Test |
| 117 | + void inboundPath_uriMedia_passes_regardlessOfSize() { |
| 118 | + // URI-based media is not subject to the byte[] guard — bytes stay out of workflow history. |
| 119 | + Worker worker = testEnv.newWorker(TASK_QUEUE); |
| 120 | + worker.registerWorkflowImplementationTypes(UriMediaWorkflowImpl.class); |
| 121 | + worker.registerActivitiesImplementations(new ChatModelActivityImpl(new StubChatModel())); |
| 122 | + testEnv.start(); |
| 123 | + |
| 124 | + ChatWorkflow workflow = |
| 125 | + client.newWorkflowStub( |
| 126 | + ChatWorkflow.class, WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build()); |
| 127 | + assertEquals("pong", workflow.chat("hi")); |
| 128 | + } |
| 129 | + |
| 130 | + @Test |
| 131 | + void outboundPath_assistantEchoesOversizedMedia_failsTheActivity() { |
| 132 | + // Activity → workflow direction: the stub ChatModel returns an assistant message with a |
| 133 | + // huge byte[] media, ChatModelActivityImpl.fromMedia → checkMediaSize throws. |
| 134 | + Worker worker = testEnv.newWorker(TASK_QUEUE); |
| 135 | + worker.registerWorkflowImplementationTypes(EchoMediaWorkflowImpl.class); |
| 136 | + worker.registerActivitiesImplementations( |
| 137 | + new ChatModelActivityImpl(new BigOutboundMediaChatModel())); |
| 138 | + testEnv.start(); |
| 139 | + |
| 140 | + ChatWorkflow workflow = |
| 141 | + client.newWorkflowStub( |
| 142 | + ChatWorkflow.class, WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build()); |
| 143 | + WorkflowException ex = assertThrows(WorkflowException.class, () -> workflow.chat("hi")); |
| 144 | + String message = rootMessage(ex); |
| 145 | + assertTrue( |
| 146 | + message.contains("exceeds the") && message.contains("-byte limit"), |
| 147 | + "expected size-guard failure on return path, got: " + message); |
| 148 | + } |
| 149 | + |
| 150 | + private static String rootMessage(Throwable t) { |
| 151 | + Throwable cur = t; |
| 152 | + while (cur.getCause() != null) { |
| 153 | + cur = cur.getCause(); |
| 154 | + } |
| 155 | + return cur.getMessage() == null ? "" : cur.getMessage(); |
| 156 | + } |
| 157 | + |
| 158 | + @WorkflowInterface |
| 159 | + public interface ChatWorkflow { |
| 160 | + @WorkflowMethod |
| 161 | + String chat(String message); |
| 162 | + } |
| 163 | + |
| 164 | + public static class BigInboundMediaWorkflowImpl implements ChatWorkflow { |
| 165 | + @Override |
| 166 | + public String chat(String message) { |
| 167 | + byte[] big = new byte[(int) ChatModelTypes.MAX_MEDIA_BYTES_IN_HISTORY + 1]; |
| 168 | + UserMessage userMessage = |
| 169 | + UserMessage.builder() |
| 170 | + .text(message) |
| 171 | + .media(List.of(new Media(MimeTypeUtils.IMAGE_PNG, new ByteArrayResource(big)))) |
| 172 | + .build(); |
| 173 | + ActivityChatModel chatModel = ActivityChatModel.forDefault(); |
| 174 | + return chatModel.call(new Prompt(List.of(userMessage))).getResult().getOutput().getText(); |
| 175 | + } |
| 176 | + } |
| 177 | + |
| 178 | + public static class SmallInboundMediaWorkflowImpl implements ChatWorkflow { |
| 179 | + @Override |
| 180 | + public String chat(String message) { |
| 181 | + byte[] small = new byte[16 * 1024]; // 16 KiB |
| 182 | + UserMessage userMessage = |
| 183 | + UserMessage.builder() |
| 184 | + .text(message) |
| 185 | + .media(List.of(new Media(MimeTypeUtils.IMAGE_PNG, new ByteArrayResource(small)))) |
| 186 | + .build(); |
| 187 | + ActivityChatModel chatModel = ActivityChatModel.forDefault(); |
| 188 | + return chatModel.call(new Prompt(List.of(userMessage))).getResult().getOutput().getText(); |
| 189 | + } |
| 190 | + } |
| 191 | + |
| 192 | + public static class UriMediaWorkflowImpl implements ChatWorkflow { |
| 193 | + @Override |
| 194 | + public String chat(String message) { |
| 195 | + UserMessage userMessage = |
| 196 | + UserMessage.builder() |
| 197 | + .text(message) |
| 198 | + .media( |
| 199 | + List.of( |
| 200 | + new Media( |
| 201 | + MimeTypeUtils.IMAGE_PNG, URI.create("https://cdn.example.com/huge.png")))) |
| 202 | + .build(); |
| 203 | + ActivityChatModel chatModel = ActivityChatModel.forDefault(); |
| 204 | + return chatModel.call(new Prompt(List.of(userMessage))).getResult().getOutput().getText(); |
| 205 | + } |
| 206 | + } |
| 207 | + |
| 208 | + public static class EchoMediaWorkflowImpl implements ChatWorkflow { |
| 209 | + @Override |
| 210 | + public String chat(String message) { |
| 211 | + ActivityChatModel chatModel = ActivityChatModel.forDefault(); |
| 212 | + return chatModel.call(new Prompt(message)).getResult().getOutput().getText(); |
| 213 | + } |
| 214 | + } |
| 215 | + |
| 216 | + /** Returns "pong" — used to verify non-failing paths. */ |
| 217 | + private static class StubChatModel implements ChatModel { |
| 218 | + @Override |
| 219 | + public ChatResponse call(Prompt prompt) { |
| 220 | + return ChatResponse.builder() |
| 221 | + .generations(List.of(new Generation(new AssistantMessage("pong")))) |
| 222 | + .build(); |
| 223 | + } |
| 224 | + |
| 225 | + @Override |
| 226 | + public reactor.core.publisher.Flux<ChatResponse> stream(Prompt prompt) { |
| 227 | + throw new UnsupportedOperationException(); |
| 228 | + } |
| 229 | + } |
| 230 | + |
| 231 | + /** Returns an assistant message carrying a huge byte[] media, to trip the outbound guard. */ |
| 232 | + private static class BigOutboundMediaChatModel implements ChatModel { |
| 233 | + @Override |
| 234 | + public ChatResponse call(Prompt prompt) { |
| 235 | + byte[] big = new byte[(int) ChatModelTypes.MAX_MEDIA_BYTES_IN_HISTORY + 1]; |
| 236 | + AssistantMessage assistant = |
| 237 | + AssistantMessage.builder() |
| 238 | + .content("") |
| 239 | + .media(List.of(new Media(MimeType.valueOf("image/png"), new ByteArrayResource(big)))) |
| 240 | + .build(); |
| 241 | + return ChatResponse.builder().generations(List.of(new Generation(assistant))).build(); |
| 242 | + } |
| 243 | + |
| 244 | + @Override |
| 245 | + public reactor.core.publisher.Flux<ChatResponse> stream(Prompt prompt) { |
| 246 | + throw new UnsupportedOperationException(); |
| 247 | + } |
| 248 | + } |
| 249 | +} |
0 commit comments