diff --git a/.doc_gen/metadata/bedrock-runtime_metadata.yaml b/.doc_gen/metadata/bedrock-runtime_metadata.yaml index 37ed64b2f44..a0c867e1753 100644 --- a/.doc_gen/metadata/bedrock-runtime_metadata.yaml +++ b/.doc_gen/metadata/bedrock-runtime_metadata.yaml @@ -108,6 +108,14 @@ bedrock-runtime_Converse_AmazonNovaText: - description: Send a text message to Amazon Nova, using Bedrock's Converse API. snippet_tags: - javascript.v3.bedrock-runtime.Converse_AmazonNovaText + Kotlin: + versions: + - sdk_version: 1 + github: kotlin/services/bedrock-runtime + excerpts: + - description: Send a text message to Amazon Nova, using Bedrock's Converse API. + snippet_tags: + - bedrock-runtime.kotlin.Converse_AmazonNovaText .NET: versions: - sdk_version: 3 @@ -439,6 +447,14 @@ bedrock-runtime_ConverseStream_AmazonNovaText: - description: Send a text message to Amazon Nova using Bedrock's Converse API and process the response stream in real-time. snippet_tags: - javascript.v3.bedrock-runtime.ConverseStream_AmazonNovaText + Kotlin: + versions: + - sdk_version: 1 + github: kotlin/services/bedrock-runtime + excerpts: + - description: Send a text message to Amazon Nova using Bedrock's Converse API and process the response stream in real-time. + snippet_tags: + - bedrock-runtime.kotlin.ConverseStream_AmazonNovaText .NET: versions: - sdk_version: 3 diff --git a/.gitignore b/.gitignore index 90d910d34d9..5f57c026395 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,4 @@ kotlin/services/**/build/ kotlin/services/**/gradle/ kotlin/services/**/gradlew kotlin/services/**/gradlew.bat +kotlin/services/**/.kotlin/ diff --git a/kotlin/services/bedrock-runtime/README.md b/kotlin/services/bedrock-runtime/README.md index 2d5be2d977c..cff05f35426 100644 --- a/kotlin/services/bedrock-runtime/README.md +++ b/kotlin/services/bedrock-runtime/README.md @@ -30,9 +30,18 @@ For prerequisites, see the [README](../../README.md#Prerequisites) in the `kotli > ⚠ You must request access to a model before you can use it. If you try to use the model (with the API or console) before you have requested access to it, you will receive an error message. For more information, see [Model access](https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html). +### Amazon Nova + +- [Converse](src/main/kotlin/com/example/bedrockruntime/models/amazon/nova/text/Converse.kt#L6) +- [ConverseStream](src/main/kotlin/com/example/bedrockruntime/models/amazon/nova/text/ConverseStream.kt#L6) + +### Amazon Nova Canvas + +- [InvokeModel](src/main/kotlin/com/example/bedrockruntime/models/amazon/nova/canvas/InvokeModel.kt#L6) + ### Amazon Titan Text -- [InvokeModel](src/main/kotlin/com/example/bedrockruntime/InvokeModel.kt#L6) +- [InvokeModel](src/main/kotlin/com/example/bedrockruntime/models/amazon/titan/text/InvokeModel.kt#L6) diff --git a/kotlin/services/bedrock-runtime/build.gradle.kts b/kotlin/services/bedrock-runtime/build.gradle.kts index b51ca6849ff..c842e24d1ea 100644 --- a/kotlin/services/bedrock-runtime/build.gradle.kts +++ b/kotlin/services/bedrock-runtime/build.gradle.kts @@ -1,54 +1,41 @@ plugins { kotlin("jvm") version "2.1.10" id("org.jetbrains.kotlin.plugin.serialization") version "2.1.10" - id("org.jlleitschuh.gradle.ktlint") version "11.3.1" apply true + id("org.jlleitschuh.gradle.ktlint") version "12.1.1" application } group = "com.example.bedrockruntime" version = "1.0-SNAPSHOT" +val awsSdkVersion = "1.4.27" +val junitVersion = "5.12.0" + repositories { mavenCentral() } -buildscript { - repositories { - maven("https://plugins.gradle.org/m2/") - } - dependencies { - classpath("org.jlleitschuh.gradle:ktlint-gradle:11.3.1") - } -} - dependencies { - implementation("aws.sdk.kotlin:bedrockruntime:1.4.11") + implementation("aws.sdk.kotlin:bedrockruntime:$awsSdkVersion") implementation("org.jetbrains.kotlinx:kotlinx-serialization-json-jvm:1.8.0") - testImplementation("org.junit.jupiter:junit-jupiter:5.11.4") -} -application { - mainClass.set("com.example.bedrockruntime.InvokeModelKt") + testImplementation("org.junit.jupiter:junit-jupiter:$junitVersion") + testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.10.1") + testImplementation("org.jetbrains.kotlin:kotlin-reflect") + testRuntimeOnly("org.junit.platform:junit-platform-launcher") } -// Java and Kotlin configuration kotlin { jvmToolchain(21) } -java { - toolchain { - languageVersion = JavaLanguageVersion.of(21) - } -} - tasks.test { useJUnitPlatform() testLogging { events("passed", "skipped", "failed") } +} - // Define the test source set - testClassesDirs += files("build/classes/kotlin/test") - classpath += files("build/classes/kotlin/main", "build/resources/main") +application { + mainClass.set("com.example.bedrockruntime.InvokeModelKt") } diff --git a/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/InvokeModel.kt b/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/InvokeModel.kt deleted file mode 100644 index 167bccac5b0..00000000000 --- a/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/InvokeModel.kt +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package com.example.bedrockruntime - -// snippet-start:[bedrock-runtime.kotlin.InvokeModel_AmazonTitanText] -import aws.sdk.kotlin.services.bedrockruntime.BedrockRuntimeClient -import aws.sdk.kotlin.services.bedrockruntime.model.InvokeModelRequest -import kotlinx.serialization.Serializable -import kotlinx.serialization.json.Json - -/** - * Before running this Kotlin code example, set up your development environment, including your credentials. - * - * This example demonstrates how to invoke the Titan Text model (amazon.titan-text-lite-v1). - * Remember that you must enable the model before you can use it. See notes in the README.md file. - * - * For more information, see the following documentation topic: - * https://docs.aws.amazon.com/sdk-for-kotlin/latest/developer-guide/setup.html - */ -suspend fun main() { - val prompt = """ - Write a short, funny story about a time-traveling cat who - ends up in ancient Egypt at the time of the pyramids. - """.trimIndent() - - val response = invokeModel(prompt, "amazon.titan-text-lite-v1") - println("Generated story:\n$response") -} - -suspend fun invokeModel(prompt: String, modelId: String): String { - BedrockRuntimeClient { region = "eu-central-1" }.use { client -> - val request = InvokeModelRequest { - this.modelId = modelId - contentType = "application/json" - accept = "application/json" - body = """ - { - "inputText": "${prompt.replace(Regex("\\s+"), " ").trim()}", - "textGenerationConfig": { - "maxTokenCount": 1000, - "stopSequences": [], - "temperature": 1, - "topP": 0.7 - } - } - """.trimIndent().toByteArray() - } - - val response = client.invokeModel(request) - val responseBody = response.body.toString(Charsets.UTF_8) - - val jsonParser = Json { ignoreUnknownKeys = true } - return jsonParser - .decodeFromString(responseBody) - .results - .first() - .outputText - } -} - -@Serializable -private data class BedrockResponse(val results: List) - -@Serializable -private data class Result(val outputText: String) -// snippet-end:[bedrock-runtime.kotlin.InvokeModel_AmazonTitanText] diff --git a/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/libs/ImageTools.kt b/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/libs/ImageTools.kt new file mode 100644 index 00000000000..98f9a72750e --- /dev/null +++ b/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/libs/ImageTools.kt @@ -0,0 +1,39 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package com.example.bedrockruntime.libs + +import java.io.ByteArrayInputStream +import java.io.IOException +import javax.imageio.ImageIO +import javax.swing.ImageIcon +import javax.swing.JFrame +import javax.swing.JLabel + +/** + * Utility object for handling image-related operations. + */ +object ImageTools { + /** + * Displays a byte array as an image in a new window. + * + * Creates a new JFrame window that displays the image represented by the provided byte array. + * The window will close the application when closed (EXIT_ON_CLOSE). + * + * @param imageData The image data as a byte array + * @throws RuntimeException if there is an error reading the image data + */ + fun displayImage(imageData: ByteArray) { + try { + val image = ImageIO.read(ByteArrayInputStream(imageData)) + JFrame("Image").apply { + defaultCloseOperation = JFrame.EXIT_ON_CLOSE + contentPane.add(JLabel(ImageIcon(image))) + pack() + isVisible = true + } + } catch (e: IOException) { + throw RuntimeException(e) + } + } +} diff --git a/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/models/amazon/nova/canvas/InvokeModel.kt b/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/models/amazon/nova/canvas/InvokeModel.kt new file mode 100644 index 00000000000..a90807ba041 --- /dev/null +++ b/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/models/amazon/nova/canvas/InvokeModel.kt @@ -0,0 +1,94 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package com.example.bedrockruntime.models.amazon.nova.canvas + +// snippet-start:[bedrock-runtime.kotlin.InvokeModel_AmazonNovaImageGeneration] + +import aws.sdk.kotlin.services.bedrockruntime.BedrockRuntimeClient +import aws.sdk.kotlin.services.bedrockruntime.model.InvokeModelRequest +import com.example.bedrockruntime.libs.ImageTools.displayImage +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json +import java.util.* + +/** + * This example demonstrates how to use Amazon Nova Canvas to generate images. + * It shows how to: + * - Set up the Amazon Bedrock runtime client + * - Configure the image generation parameters + * - Send a request to generate an image + * - Process the response and display the generated image + */ +suspend fun main() { + println("Generating image. This may take a few seconds...") + val imageData = invokeModel() + displayImage(imageData) +} + +// Data class for parsing the model's response +@Serializable +private data class Response(val images: List) + +// Configure JSON parser to ignore unknown fields in the response +private val json = Json { ignoreUnknownKeys = true } + +suspend fun invokeModel(): ByteArray { + // Create and configure the Bedrock runtime client + BedrockRuntimeClient { region = "us-east-1" }.use { client -> + + // Specify the model ID. For the latest available models, see: + // https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html + val modelId = "amazon.nova-canvas-v1:0" + + // Configure the generation parameters and create the request + // First, set the main parameters: + // - prompt: Text description of the image to generate + // - seed: Random number for reproducible generation (0 to 858,993,459) + val prompt = "A stylized picture of a cute old steampunk robot" + val seed = (0..858_993_459).random() + + // Then, create the request using a template with the following structure: + // - taskType: TEXT_IMAGE (specifies text-to-image generation) + // - textToImageParams: Contains the text prompt + // - imageGenerationConfig: Contains optional generation settings (seed, quality, etc.) + // For a list of available request parameters, see: + // https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html + val request = """ + { + "taskType": "TEXT_IMAGE", + "textToImageParams": { + "text": "$prompt" + }, + "imageGenerationConfig": { + "seed": $seed, + "quality": "standard" + } + } + """.trimIndent() + + // Send the request and process the model's response + runCatching { + // Send the request to the model + val response = client.invokeModel( + InvokeModelRequest { + this.modelId = modelId + body = request.toByteArray() + }, + ) + + // Parse the response and extract the generated image + val jsonResponse = response.body.toString(Charsets.UTF_8) + val parsedResponse = json.decodeFromString(jsonResponse) + + // Extract the generated image and return it as a byte array for better handling + val base64Image = parsedResponse.images.first() + return Base64.getDecoder().decode(base64Image) + }.getOrElse { error -> + System.err.println("ERROR: Can't invoke '$modelId'. Reason: ${error.message}") + throw RuntimeException("Failed to generate image with model $modelId", error) + } + } +} + +// snippet-end:[bedrock-runtime.kotlin.InvokeModel_AmazonNovaImageGeneration] diff --git a/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/models/amazon/nova/text/Converse.kt b/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/models/amazon/nova/text/Converse.kt new file mode 100644 index 00000000000..20e196cf5a7 --- /dev/null +++ b/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/models/amazon/nova/text/Converse.kt @@ -0,0 +1,62 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package com.example.bedrockruntime.models.amazon.nova.text + +// snippet-start:[bedrock-runtime.kotlin.Converse_AmazonNovaText] + +import aws.sdk.kotlin.services.bedrockruntime.BedrockRuntimeClient +import aws.sdk.kotlin.services.bedrockruntime.model.ContentBlock +import aws.sdk.kotlin.services.bedrockruntime.model.ConversationRole +import aws.sdk.kotlin.services.bedrockruntime.model.ConverseRequest +import aws.sdk.kotlin.services.bedrockruntime.model.Message + +/** + * This example demonstrates how to use the Amazon Nova foundation models to generate text. + * It shows how to: + * - Set up the Amazon Bedrock runtime client + * - Create a message + * - Configure and send a request + * - Process the response + */ +suspend fun main() { + converse().also { println(it) } +} + +suspend fun converse(): String { + // Create and configure the Bedrock runtime client + BedrockRuntimeClient { region = "us-east-1" }.use { client -> + + // Specify the model ID. For the latest available models, see: + // https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html + val modelId = "amazon.nova-lite-v1:0" + + // Create the message with the user's prompt + val prompt = "Describe the purpose of a 'hello world' program in one line." + val message = Message { + role = ConversationRole.User + content = listOf(ContentBlock.Text(prompt)) + } + + // Configure the request with optional model parameters + val request = ConverseRequest { + this.modelId = modelId + messages = listOf(message) + inferenceConfig { + maxTokens = 500 // Maximum response length + temperature = 0.5F // Lower values: more focused output + // topP = 0.8F // Alternative to temperature + } + } + + // Send the request and process the model's response + runCatching { + val response = client.converse(request) + return response.output!!.asMessage().content.first().asText() + }.getOrElse { error -> + error.message?.let { e -> System.err.println("ERROR: Can't invoke '$modelId'. Reason: $e") } + throw RuntimeException("Failed to generate text with model $modelId", error) + } + } +} +// snippet-end:[bedrock-runtime.kotlin.Converse_AmazonNovaText] diff --git a/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/models/amazon/nova/text/ConverseStream.kt b/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/models/amazon/nova/text/ConverseStream.kt new file mode 100644 index 00000000000..932978e7c28 --- /dev/null +++ b/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/models/amazon/nova/text/ConverseStream.kt @@ -0,0 +1,83 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package com.example.bedrockruntime.models.amazon.nova.text + +// snippet-start:[bedrock-runtime.kotlin.ConverseStream_AmazonNovaText] + +import aws.sdk.kotlin.services.bedrockruntime.BedrockRuntimeClient +import aws.sdk.kotlin.services.bedrockruntime.model.ContentBlock +import aws.sdk.kotlin.services.bedrockruntime.model.ConversationRole +import aws.sdk.kotlin.services.bedrockruntime.model.ConverseStreamOutput +import aws.sdk.kotlin.services.bedrockruntime.model.ConverseStreamRequest +import aws.sdk.kotlin.services.bedrockruntime.model.Message + +/** + * This example demonstrates how to use the Amazon Nova foundation models + * to generate streaming text responses. + * It shows how to: + * - Set up the Amazon Bedrock runtime client + * - Create a message with a prompt + * - Configure a streaming request with parameters + * - Process the response stream in real time + */ +suspend fun main() { + converseStream() +} + +suspend fun converseStream(): String { + // A buffer to collect the complete response + val completeResponseBuffer = StringBuilder() + + // Create and configure the Bedrock runtime client + BedrockRuntimeClient { region = "us-east-1" }.use { client -> + + // Specify the model ID. For the latest available models, see: + // https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html + val modelId = "amazon.nova-lite-v1:0" + + // Create the message with the user's prompt + val prompt = "Describe the purpose of a 'hello world' program in a paragraph." + val message = Message { + role = ConversationRole.User + content = listOf(ContentBlock.Text(prompt)) + } + + // Configure the request with optional model parameters + val request = ConverseStreamRequest { + this.modelId = modelId + messages = listOf(message) + inferenceConfig { + maxTokens = 500 // Maximum response length + temperature = 0.5F // Lower values: more focused output + // topP = 0.8F // Alternative to temperature + } + } + + // Process the streaming response + runCatching { + client.converseStream(request) { response -> + response.stream?.collect { chunk -> + when (chunk) { + is ConverseStreamOutput.ContentBlockDelta -> { + // Process each text chunk as it arrives + chunk.value.delta?.asText()?.let { text -> + print(text) + System.out.flush() // Ensure immediate output + completeResponseBuffer.append(text) + } + } + else -> {} // Other output block types can be handled as needed + } + } + } + }.onFailure { error -> + error.message?.let { e -> System.err.println("ERROR: Can't invoke '$modelId'. Reason: $e") } + throw RuntimeException("Failed to generate text with model $modelId: $error", error) + } + } + + return completeResponseBuffer.toString() +} + +// snippet-end:[bedrock-runtime.kotlin.ConverseStream_AmazonNovaText] diff --git a/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/models/amazon/titan/text/InvokeModel.kt b/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/models/amazon/titan/text/InvokeModel.kt new file mode 100644 index 00000000000..d5ae1b07398 --- /dev/null +++ b/kotlin/services/bedrock-runtime/src/main/kotlin/com/example/bedrockruntime/models/amazon/titan/text/InvokeModel.kt @@ -0,0 +1,86 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package com.example.bedrockruntime.models.amazon.titan.text + +// snippet-start:[bedrock-runtime.kotlin.InvokeModel_AmazonTitanText] + +import aws.sdk.kotlin.services.bedrockruntime.BedrockRuntimeClient +import aws.sdk.kotlin.services.bedrockruntime.model.InvokeModelRequest +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json + +/** + * This example demonstrates how to use the Amazon Titan foundation models to generate text. + * It shows how to: + * - Set up the Amazon Bedrock runtime client + * - Create a request payload + * - Configure and send a request + * - Process the response + */ +suspend fun main() { + invokeModel().also { println(it) } +} + +// Data class for parsing the model's response +@Serializable +private data class BedrockResponse(val results: List) { + @Serializable + data class Result( + val outputText: String, + ) +} + +// Initialize JSON parser with relaxed configuration +private val json = Json { ignoreUnknownKeys = true } + +suspend fun invokeModel(): String { + // Create and configure the Bedrock runtime client + BedrockRuntimeClient { region = "us-east-1" }.use { client -> + + // Specify the model ID. For the latest available models, see: + // https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html + val modelId = "amazon.titan-text-lite-v1" + + // Create the request payload with optional configuration parameters + // For detailed parameter descriptions, see: + // https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html + val prompt = "Describe the purpose of a 'hello world' program in one line." + val request = """ + { + "inputText": "$prompt", + "textGenerationConfig": { + "maxTokenCount": 500, + "temperature": 0.5 + } + } + """.trimIndent() + + // Send the request and process the model's response + runCatching { + // Send the request to the model + val response = client.invokeModel( + InvokeModelRequest { + this.modelId = modelId + body = request.toByteArray() + }, + ) + + // Convert the response bytes to a JSON string + val jsonResponse = response.body.toString(Charsets.UTF_8) + + // Parse the JSON into a Kotlin object + val parsedResponse = json.decodeFromString(jsonResponse) + + // Extract and return the generated text + return parsedResponse.results.firstOrNull()!!.outputText + }.getOrElse { error -> + error.message?.let { msg -> + System.err.println("ERROR: Can't invoke '$modelId'. Reason: $msg") + } + throw RuntimeException("Failed to generate text with model $modelId", error) + } + } +} + +// snippet-end:[bedrock-runtime.kotlin.InvokeModel_AmazonTitanText] diff --git a/kotlin/services/bedrock-runtime/src/test/kotlin/InvokeModelTest.kt b/kotlin/services/bedrock-runtime/src/test/kotlin/InvokeModelTest.kt deleted file mode 100644 index 21d6eb43eb1..00000000000 --- a/kotlin/services/bedrock-runtime/src/test/kotlin/InvokeModelTest.kt +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -import com.example.bedrockruntime.invokeModel -import kotlinx.coroutines.runBlocking -import org.junit.jupiter.api.Assertions.assertTrue -import org.junit.jupiter.api.MethodOrderer.OrderAnnotation -import org.junit.jupiter.api.Order -import org.junit.jupiter.api.Test -import org.junit.jupiter.api.TestInstance -import org.junit.jupiter.api.TestMethodOrder - -@TestInstance(TestInstance.Lifecycle.PER_CLASS) -@TestMethodOrder(OrderAnnotation::class) -class InvokeModelTest { - @Test - @Order(1) - fun listFoundationModels() = runBlocking { - val prompt = "What is the capital of France?" - - val answer = invokeModel(prompt, "amazon.titan-text-lite-v1") - assertTrue(answer.isNotBlank()) - } -} diff --git a/kotlin/services/bedrock-runtime/src/test/kotlin/models/AbstractModelTest.kt b/kotlin/services/bedrock-runtime/src/test/kotlin/models/AbstractModelTest.kt new file mode 100644 index 00000000000..ccd8ee85a79 --- /dev/null +++ b/kotlin/services/bedrock-runtime/src/test/kotlin/models/AbstractModelTest.kt @@ -0,0 +1,86 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package models + +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertNotEquals +import org.junit.jupiter.api.TestInstance +import org.junit.jupiter.api.fail +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import java.util.stream.Stream + +/** + * Abstract base class for testing Bedrock model invocations. + * + * Example usage: + * ``` + * class TestMyModel : AbstractModelTest() { + * override fun modelProvider(): Stream { + * return listOf( + * ModelTest("My Model", ::myModelFunction) + * ).stream() + * } + * } + * ``` + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +abstract class AbstractModelTest { + + /** + * Provides the model test configurations to be executed. + * Implementing classes should override this to return a stream of [ModelTest] instances, + * each representing a specific model function to be tested. + * + * @return A stream of [ModelTest] configurations + */ + protected abstract fun modelProvider(): Stream + + /** + * Executes the test for a given model configuration. + * This method runs the invocation function and validates its output using [validateResult]. + * + * @param model The [ModelTest] configuration to execute + */ + @ParameterizedTest(name = "Test {0}") + @MethodSource("modelProvider") + fun testModel(model: ModelTest) = runBlocking { + try { + val result = model.function.invoke() + validateResult(result, model.name) + } catch (e: Exception) { + fail("Test failed for ${model.name}: ${e.message}", e) + } + } + + /** + * Validates the result returned by a model invocation. + * Default implementation ensures that String results are non-empty and ByteArray results have non-zero length. + * Subclasses can override this to implement custom validation logic. + * + * @param result The result returned by the model invocation + * @param modelName The name of the model being tested + * @throws AssertionError if the result is invalid + */ + protected open fun validateResult(result: Any?, modelName: String) { + when (result) { + is String -> assertFalse(result.trim().isEmpty()) { "Empty result from $modelName" } + is ByteArray -> assertNotEquals(0, result.size) { "Empty result from $modelName" } + else -> fail("Unexpected result type from $modelName: ${result?.javaClass}") + } + } + + /** + * Data class representing a model test configuration. + * Encapsulates a model invocation function and its descriptive name. + * + * @property name The descriptive name of the model being tested + * @property function The suspend function to be tested, which should return a String + */ + data class ModelTest( + val name: String, + val function: suspend () -> Any, + ) +} diff --git a/kotlin/services/bedrock-runtime/src/test/kotlin/models/TestConverse.kt b/kotlin/services/bedrock-runtime/src/test/kotlin/models/TestConverse.kt new file mode 100644 index 00000000000..7fe40b8ace3 --- /dev/null +++ b/kotlin/services/bedrock-runtime/src/test/kotlin/models/TestConverse.kt @@ -0,0 +1,20 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package models + +import java.util.stream.Stream + +/** + * Test class for text generation on Amazon Bedrock using the Converse API. + */ +class TestConverse : AbstractModelTest() { + /** + * Provides test configurations for Amazon Bedrock text generation models. + * Creates test cases that validate each model's ability to generate + * and return text responses. + */ + override fun modelProvider(): Stream = listOf( + ModelTest("Amazon Nova") { com.example.bedrockruntime.models.amazon.nova.text.converse() }, + ).stream() +} diff --git a/kotlin/services/bedrock-runtime/src/test/kotlin/models/TestConverseStream.kt b/kotlin/services/bedrock-runtime/src/test/kotlin/models/TestConverseStream.kt new file mode 100644 index 00000000000..91a067df459 --- /dev/null +++ b/kotlin/services/bedrock-runtime/src/test/kotlin/models/TestConverseStream.kt @@ -0,0 +1,20 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package models + +import java.util.stream.Stream + +/** + * Test class for streaming text generation on Amazon Bedrock using the ConverseStream API. + */ +class TestConverseStream : AbstractModelTest() { + /** + * Provides test configurations for Amazon Bedrock models that support streaming. + * Creates test cases that validate each model's ability to generate + * and return streaming text responses. + */ + override fun modelProvider(): Stream = listOf( + ModelTest("Amazon Nova") { com.example.bedrockruntime.models.amazon.nova.text.converseStream() }, + ).stream() +} diff --git a/kotlin/services/bedrock-runtime/src/test/kotlin/models/TestInvokeModel.kt b/kotlin/services/bedrock-runtime/src/test/kotlin/models/TestInvokeModel.kt new file mode 100644 index 00000000000..2c14291e782 --- /dev/null +++ b/kotlin/services/bedrock-runtime/src/test/kotlin/models/TestInvokeModel.kt @@ -0,0 +1,21 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0v + +package models + +import java.util.stream.Stream + +/** + * Test class for generative AI models on Amazon Bedrock using the InvokeModel API. + */ +class TestInvokeModel : AbstractModelTest() { + /** + * Provides test configurations for generative AI models on Amazon Bedrock. + * Creates test cases that validate each model's ability to generate + * and return text or byte[] responses. + */ + override fun modelProvider(): Stream = listOf( + ModelTest("Amazon Titan Text") { com.example.bedrockruntime.models.amazon.titan.text.invokeModel() }, + ModelTest("Amazon Nova Canvas") { com.example.bedrockruntime.models.amazon.nova.canvas.invokeModel() }, + ).stream() +}