diff --git a/temporal-spring-ai/README.md b/temporal-spring-ai/README.md index a19e96a4f..7f6b2405d 100644 --- a/temporal-spring-ai/README.md +++ b/temporal-spring-ai/README.md @@ -114,6 +114,22 @@ public class MyTools { Auto-detected and executed as Nexus operations, similar to activity stubs. +## Media in messages + +If you attach media (images, audio, etc.) to a `UserMessage` or an `AssistantMessage`, prefer passing it by URI rather than raw bytes: + +```java +// Good — only the URL crosses the activity boundary. +Media image = new Media(MimeTypeUtils.IMAGE_PNG, URI.create("https://cdn.example.com/pic.png")); + +// Works, but size-limited — see below. +Media image = new Media(MimeTypeUtils.IMAGE_PNG, new ByteArrayResource(bytes)); +``` + +Raw `byte[]` media gets serialized into every chat activity's input *and* result payloads, which end up inside Temporal workflow history events. Server-side history events have a fixed 2 MiB size limit; to leave headroom for messages, tool definitions, and options, the plugin enforces a **1 MiB default cap** on inline media bytes and fails fast with a non-retryable `ApplicationFailure` (type `MediaSizeExceeded`) pointing you at the URI alternative. + +Override the cap by setting the system property `io.temporal.springai.maxMediaBytes` before your worker starts (pass a positive integer; `0` disables the check). For anything larger than a small thumbnail, the URI route is the right answer — have an activity write the bytes to blob storage, then pass only the URL into the conversation. 
+ ## Optional Integrations Auto-configured when their dependencies are on the classpath: diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivityImpl.java b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivityImpl.java index 4eca09e67..15e7ffefd 100644 --- a/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivityImpl.java +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivityImpl.java @@ -239,6 +239,7 @@ private ChatModelTypes.MediaContent fromMedia(Media media) { if (media.getData() instanceof String uri) { return new ChatModelTypes.MediaContent(mimeType, uri); } else if (media.getData() instanceof byte[] data) { + ChatModelTypes.checkMediaSize(data); return new ChatModelTypes.MediaContent(mimeType, data); } throw new IllegalArgumentException( diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/model/ActivityChatModel.java b/temporal-spring-ai/src/main/java/io/temporal/springai/model/ActivityChatModel.java index 5a86f03ab..d36da0ba3 100644 --- a/temporal-spring-ai/src/main/java/io/temporal/springai/model/ActivityChatModel.java +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/model/ActivityChatModel.java @@ -372,6 +372,7 @@ private ChatModelTypes.MediaContent toMediaContent(Media media) { if (media.getData() instanceof String uri) { return new ChatModelTypes.MediaContent(mimeType, uri); } else if (media.getData() instanceof byte[] data) { + ChatModelTypes.checkMediaSize(data); return new ChatModelTypes.MediaContent(mimeType, data); } throw new IllegalArgumentException( diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/model/ChatModelTypes.java b/temporal-spring-ai/src/main/java/io/temporal/springai/model/ChatModelTypes.java index c1f57317d..9ce9d8517 100644 --- a/temporal-spring-ai/src/main/java/io/temporal/springai/model/ChatModelTypes.java +++ 
b/temporal-spring-ai/src/main/java/io/temporal/springai/model/ChatModelTypes.java @@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import io.temporal.failure.ApplicationFailure; import java.time.Duration; import java.util.List; import javax.annotation.Nullable; @@ -16,6 +17,50 @@ */ public final class ChatModelTypes { + /** + * Maximum size, in bytes, of a single {@link MediaContent#data()} byte array carried across the + * chat activity boundary. Bytes above this threshold land inside workflow history events, which + * have a fixed 2 MiB per-event limit on the Temporal server. 1 MiB leaves headroom for the rest + * of a chat payload (messages, tool definitions, options). + * + *

<p>Users who want to raise or lower the cap can set the system property {@code + * io.temporal.springai.maxMediaBytes} to a positive integer at JVM startup (read once); values + * <= 0 disable the guard entirely. For most workloads, pass media by URI instead — write the + * bytes to a binary store from an activity, and pass only the URL across the conversation. + */ + public static final long MAX_MEDIA_BYTES_IN_HISTORY = + Long.getLong("io.temporal.springai.maxMediaBytes", 1L * 1024 * 1024); + + /** Failure type on the {@link ApplicationFailure} thrown by {@link #checkMediaSize(byte[])}. */ + public static final String MEDIA_SIZE_EXCEEDED_FAILURE_TYPE = "MediaSizeExceeded"; + + /** + * Throws a non-retryable {@link ApplicationFailure} if {@code data} exceeds {@link + * #MAX_MEDIA_BYTES_IN_HISTORY}. Non-retryable because this is a permanent, programmer-level error + * — retrying the same oversized payload will never succeed, and using a plain {@link + * RuntimeException} here would cause the workflow task to be retried forever (or the activity to + * churn through its {@code maxAttempts}) rather than surfacing the real problem. The failure + * message points the caller at the URI-based {@code Media} constructor. Pass-through otherwise. + */ + public static void checkMediaSize(byte[] data) { + if (data == null) { + return; + } + long limit = MAX_MEDIA_BYTES_IN_HISTORY; + if (limit > 0 && data.length > limit) { + throw ApplicationFailure.newNonRetryableFailure( + "Media byte[] is " + + data.length + + " bytes, which exceeds the " + + limit + + "-byte limit for inline media in Temporal workflow history. Pass the media by " + + "URI instead: store the bytes outside the workflow (e.g. S3) and construct " + + "Media(mimeType, URI). 
Set the system property " + + "'io.temporal.springai.maxMediaBytes' to override this limit (or 0 to disable).", + MEDIA_SIZE_EXCEEDED_FAILURE_TYPE); + } + } + private ChatModelTypes() {} /** diff --git a/temporal-spring-ai/src/test/java/io/temporal/springai/MediaSizeGuardTest.java b/temporal-spring-ai/src/test/java/io/temporal/springai/MediaSizeGuardTest.java new file mode 100644 index 000000000..32d84a508 --- /dev/null +++ b/temporal-spring-ai/src/test/java/io/temporal/springai/MediaSizeGuardTest.java @@ -0,0 +1,249 @@ +package io.temporal.springai; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.temporal.client.WorkflowClient; +import io.temporal.client.WorkflowException; +import io.temporal.client.WorkflowOptions; +import io.temporal.failure.ApplicationFailure; +import io.temporal.springai.activity.ChatModelActivityImpl; +import io.temporal.springai.model.ActivityChatModel; +import io.temporal.springai.model.ChatModelTypes; +import io.temporal.testing.TestWorkflowEnvironment; +import io.temporal.worker.Worker; +import io.temporal.workflow.WorkflowInterface; +import io.temporal.workflow.WorkflowMethod; +import java.net.URI; +import java.util.List; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.content.Media; +import org.springframework.core.io.ByteArrayResource; +import org.springframework.util.MimeType; +import 
org.springframework.util.MimeTypeUtils; + +/** + * Unit tests around {@link ChatModelTypes#checkMediaSize(byte[])} plus integration-style tests + * against a live TestWorkflowEnvironment to make sure the guard fires on both the inbound (workflow + * → activity) and outbound (activity → workflow) conversion paths. + */ +class MediaSizeGuardTest { + + private static final String TASK_QUEUE = "test-spring-ai-media-size-guard"; + + private TestWorkflowEnvironment testEnv; + private WorkflowClient client; + + @BeforeEach + void setUp() { + testEnv = TestWorkflowEnvironment.newInstance(); + client = testEnv.getWorkflowClient(); + } + + @AfterEach + void tearDown() { + testEnv.close(); + } + + @Test + void checkMediaSize_smallPayload_passes() { + byte[] small = new byte[500 * 1024]; // 500 KiB, well under 1 MiB + assertDoesNotThrow(() -> ChatModelTypes.checkMediaSize(small)); + } + + @Test + void checkMediaSize_oversizedPayload_throwsNonRetryableApplicationFailure() { + byte[] big = new byte[(int) ChatModelTypes.MAX_MEDIA_BYTES_IN_HISTORY + 1]; + ApplicationFailure ex = + assertThrows(ApplicationFailure.class, () -> ChatModelTypes.checkMediaSize(big)); + assertTrue(ex.isNonRetryable(), "guard must throw a non-retryable ApplicationFailure"); + assertEquals(ChatModelTypes.MEDIA_SIZE_EXCEEDED_FAILURE_TYPE, ex.getType()); + String msg = ex.getOriginalMessage(); + assertTrue(msg.contains("URI"), "message should point at the URI alternative: " + msg); + assertTrue( + msg.contains("io.temporal.springai.maxMediaBytes"), + "message should mention the override system property: " + msg); + } + + @Test + void checkMediaSize_null_passes() { + assertDoesNotThrow(() -> ChatModelTypes.checkMediaSize(null)); + } + + @Test + void inboundPath_oversizedUserMessageMedia_failsTheWorkflow() { + // Workflow → activity direction: the workflow builds a Prompt with a huge byte[] media, + // ActivityChatModel.createActivityInput calls toMediaContent → checkMediaSize throws. 
+ Worker worker = testEnv.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes(BigInboundMediaWorkflowImpl.class); + worker.registerActivitiesImplementations(new ChatModelActivityImpl(new StubChatModel())); + testEnv.start(); + + ChatWorkflow workflow = + client.newWorkflowStub( + ChatWorkflow.class, WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build()); + WorkflowException ex = assertThrows(WorkflowException.class, () -> workflow.chat("hi")); + String message = rootMessage(ex); + assertTrue( + message.contains(ChatModelTypes.MEDIA_SIZE_EXCEEDED_FAILURE_TYPE) + || message.contains("-byte limit"), + "expected size-guard failure, got: " + message); + } + + @Test + void inboundPath_smallMedia_passes() { + Worker worker = testEnv.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes(SmallInboundMediaWorkflowImpl.class); + worker.registerActivitiesImplementations(new ChatModelActivityImpl(new StubChatModel())); + testEnv.start(); + + ChatWorkflow workflow = + client.newWorkflowStub( + ChatWorkflow.class, WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build()); + assertEquals("pong", workflow.chat("hi")); + } + + @Test + void inboundPath_uriMedia_passes_regardlessOfSize() { + // URI-based media is not subject to the byte[] guard — bytes stay out of workflow history. 
+ Worker worker = testEnv.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes(UriMediaWorkflowImpl.class); + worker.registerActivitiesImplementations(new ChatModelActivityImpl(new StubChatModel())); + testEnv.start(); + + ChatWorkflow workflow = + client.newWorkflowStub( + ChatWorkflow.class, WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build()); + assertEquals("pong", workflow.chat("hi")); + } + + @Test + void outboundPath_assistantEchoesOversizedMedia_failsTheActivity() { + // Activity → workflow direction: the stub ChatModel returns an assistant message with a + // huge byte[] media, ChatModelActivityImpl.fromMedia → checkMediaSize throws. + Worker worker = testEnv.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes(EchoMediaWorkflowImpl.class); + worker.registerActivitiesImplementations( + new ChatModelActivityImpl(new BigOutboundMediaChatModel())); + testEnv.start(); + + ChatWorkflow workflow = + client.newWorkflowStub( + ChatWorkflow.class, WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build()); + WorkflowException ex = assertThrows(WorkflowException.class, () -> workflow.chat("hi")); + String message = rootMessage(ex); + assertTrue( + message.contains("exceeds the") && message.contains("-byte limit"), + "expected size-guard failure on return path, got: " + message); + } + + private static String rootMessage(Throwable t) { + Throwable cur = t; + while (cur.getCause() != null) { + cur = cur.getCause(); + } + return cur.getMessage() == null ? 
"" : cur.getMessage(); + } + + @WorkflowInterface + public interface ChatWorkflow { + @WorkflowMethod + String chat(String message); + } + + public static class BigInboundMediaWorkflowImpl implements ChatWorkflow { + @Override + public String chat(String message) { + byte[] big = new byte[(int) ChatModelTypes.MAX_MEDIA_BYTES_IN_HISTORY + 1]; + UserMessage userMessage = + UserMessage.builder() + .text(message) + .media(List.of(new Media(MimeTypeUtils.IMAGE_PNG, new ByteArrayResource(big)))) + .build(); + ActivityChatModel chatModel = ActivityChatModel.forDefault(); + return chatModel.call(new Prompt(List.of(userMessage))).getResult().getOutput().getText(); + } + } + + public static class SmallInboundMediaWorkflowImpl implements ChatWorkflow { + @Override + public String chat(String message) { + byte[] small = new byte[16 * 1024]; // 16 KiB + UserMessage userMessage = + UserMessage.builder() + .text(message) + .media(List.of(new Media(MimeTypeUtils.IMAGE_PNG, new ByteArrayResource(small)))) + .build(); + ActivityChatModel chatModel = ActivityChatModel.forDefault(); + return chatModel.call(new Prompt(List.of(userMessage))).getResult().getOutput().getText(); + } + } + + public static class UriMediaWorkflowImpl implements ChatWorkflow { + @Override + public String chat(String message) { + UserMessage userMessage = + UserMessage.builder() + .text(message) + .media( + List.of( + new Media( + MimeTypeUtils.IMAGE_PNG, URI.create("https://cdn.example.com/huge.png")))) + .build(); + ActivityChatModel chatModel = ActivityChatModel.forDefault(); + return chatModel.call(new Prompt(List.of(userMessage))).getResult().getOutput().getText(); + } + } + + public static class EchoMediaWorkflowImpl implements ChatWorkflow { + @Override + public String chat(String message) { + ActivityChatModel chatModel = ActivityChatModel.forDefault(); + return chatModel.call(new Prompt(message)).getResult().getOutput().getText(); + } + } + + /** Returns "pong" — used to verify non-failing paths. 
*/ + private static class StubChatModel implements ChatModel { + @Override + public ChatResponse call(Prompt prompt) { + return ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("pong")))) + .build(); + } + + @Override + public reactor.core.publisher.Flux stream(Prompt prompt) { + throw new UnsupportedOperationException(); + } + } + + /** Returns an assistant message carrying a huge byte[] media, to trip the outbound guard. */ + private static class BigOutboundMediaChatModel implements ChatModel { + @Override + public ChatResponse call(Prompt prompt) { + byte[] big = new byte[(int) ChatModelTypes.MAX_MEDIA_BYTES_IN_HISTORY + 1]; + AssistantMessage assistant = + AssistantMessage.builder() + .content("") + .media(List.of(new Media(MimeType.valueOf("image/png"), new ByteArrayResource(big)))) + .build(); + return ChatResponse.builder().generations(List.of(new Generation(assistant))).build(); + } + + @Override + public reactor.core.publisher.Flux stream(Prompt prompt) { + throw new UnsupportedOperationException(); + } + } +}