Skip to content

Commit d8ac73a

Browse files
committed
Treeshake s2c messages to Kotlin SDK
Implement coordinated component and message pruning via withPruning, propagating allowedMessages from A2uiSchemaManager down to Catalog. Add robust automated unit tests and enable full conformance verification. Port of Python SDK commit 0fd7240
1 parent 263339c commit d8ac73a

7 files changed

Lines changed: 396 additions & 51 deletions

File tree

agent_sdks/kotlin/src/main/kotlin/com/google/a2ui/core/InferenceStrategy.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ interface InferenceStrategy {
3030
* @param uiDescription Optional UI context or descriptive instruction.
3131
* @param clientUiCapabilities Capabilities reported by the client for targeted schema pruning.
3232
* @param allowedComponents A specific list of component IDs allowed for rendering.
33+
* @param allowedMessages A specific list of message IDs allowed for rendering.
3334
* @param includeSchema Whether to embed the A2UI JSON schema directly in the instructions.
3435
* @param includeExamples Whether to embed few-shot examples in the instructions.
3536
* @param validateExamples Whether to preemptively validate loaded examples against the schema.
@@ -41,6 +42,7 @@ interface InferenceStrategy {
4142
uiDescription: String = "",
4243
clientUiCapabilities: kotlinx.serialization.json.JsonObject? = null,
4344
allowedComponents: List<String> = emptyList(),
45+
allowedMessages: List<String> = emptyList(),
4446
includeSchema: Boolean = false,
4547
includeExamples: Boolean = false,
4648
validateExamples: Boolean = false,

agent_sdks/kotlin/src/main/kotlin/com/google/a2ui/core/parser/StreamingParser.kt

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,9 @@ abstract class StreamingParser(protected val catalog: A2uiCatalog?) {
387387
}
388388

389389
protected fun processJsonChunk(chunk: String, messages: MutableList<ResponsePart>) {
390+
if (jsonBuffer.length + chunk.length > MAX_JSON_BUFFER_SIZE) {
391+
throw IllegalArgumentException("A2UI JSON buffer exceeded maximum size limit.")
392+
}
390393
for (i in chunk.indices) {
391394
val char = chunk[i]
392395
var charHandled = false
@@ -504,7 +507,7 @@ abstract class StreamingParser(protected val catalog: A2uiCatalog?) {
504507
}
505508
}
506509

507-
if (braceCount > 0 && char in listOf('"', ':', ',', '}', ']')) {
510+
if (braceCount > 0 && (char == '"' || char == ':' || char == ',' || char == '}' || char == ']')) {
508511
sniffMetadata()
509512
}
510513
}
@@ -566,6 +569,7 @@ abstract class StreamingParser(protected val catalog: A2uiCatalog?) {
566569
obj != null && obj["id"]?.jsonPrimitive?.content != null && obj.containsKey("component")
567570
) {
568571
handlePartialComponent(obj, messages)
572+
break
569573
}
570574
} catch (e: Exception) {
571575
logger.warning { e.message }
@@ -591,9 +595,9 @@ abstract class StreamingParser(protected val catalog: A2uiCatalog?) {
591595
try {
592596
obj = Json.parseToJsonElement(fixedFragment) as? JsonObject
593597
} catch (_: Exception) {
594-
var trimmed = rawFragment
595-
while ("," in trimmed) {
596-
trimmed = trimmed.substringBeforeLast(",")
598+
var commaIdx = rawFragment.lastIndexOf(',')
599+
while (commaIdx != -1) {
600+
val trimmed = rawFragment.substring(0, commaIdx)
597601
try {
598602
val fixedTrimmed = fixJson(trimmed)
599603
if (fixedTrimmed.isNotEmpty()) {
@@ -602,8 +606,8 @@ abstract class StreamingParser(protected val catalog: A2uiCatalog?) {
602606
}
603607
} catch (ex: Exception) {
604608
logger.warning { ex.message }
605-
continue
606609
}
610+
commaIdx = rawFragment.lastIndexOf(',', commaIdx - 1)
607611
}
608612
}
609613

@@ -1054,7 +1058,7 @@ abstract class StreamingParser(protected val catalog: A2uiCatalog?) {
10541058
private val PREV_KEY_MATCHES_REGEX = Regex("\"key\"\\s*:\\s*\"([^\"]+)\"")
10551059
private val SURFACE_ID_REGEX = Regex("\"surfaceId\"\\s*:\\s*\"([^\"]+)\"")
10561060
private val ROOT_ID_REGEX = Regex("\"root\"\\s*:\\s*\"([^\"]+)\"")
1057-
internal val JSON_NON_PRETTY = Json { prettyPrint = false }
1061+
private const val MAX_JSON_BUFFER_SIZE = 5 * 1024 * 1024
10581062

10591063
/** Factory method returning a version-specific parser instance. */
10601064
fun create(catalog: A2uiCatalog? = null): StreamingParser {

agent_sdks/kotlin/src/main/kotlin/com/google/a2ui/core/schema/A2uiSchemaManager.kt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,15 @@ constructor(
176176

177177
/**
178178
* Resolves the desired catalog based on the client capabilities, returning it with pruned unused
179-
* components.
179+
* components and messages.
180180
*/
181181
@JvmOverloads
182182
fun getSelectedCatalog(
183183
clientUiCapabilities: JsonObject? = null,
184184
allowedComponents: List<String> = emptyList(),
185-
): A2uiCatalog = selectCatalog(clientUiCapabilities).withPrunedComponents(allowedComponents)
185+
allowedMessages: List<String> = emptyList(),
186+
): A2uiCatalog =
187+
selectCatalog(clientUiCapabilities).withPruning(allowedComponents, allowedMessages)
186188

187189
/** Renders LLM examples for a given catalog, loaded from its configured examples path. */
188190
@JvmOverloads
@@ -197,6 +199,7 @@ constructor(
197199
uiDescription: String,
198200
clientUiCapabilities: JsonObject?,
199201
allowedComponents: List<String>,
202+
allowedMessages: List<String>,
200203
includeSchema: Boolean,
201204
includeExamples: Boolean,
202205
validateExamples: Boolean,
@@ -212,7 +215,8 @@ constructor(
212215
parts.add("## UI Description:\n$uiDescription")
213216
}
214217

215-
val selectedCatalog = getSelectedCatalog(clientUiCapabilities, allowedComponents)
218+
val selectedCatalog =
219+
getSelectedCatalog(clientUiCapabilities, allowedComponents, allowedMessages)
216220

217221
if (includeSchema) {
218222
parts.add(selectedCatalog.renderAsLlmInstructions())

agent_sdks/kotlin/src/main/kotlin/com/google/a2ui/core/schema/Catalog.kt

Lines changed: 90 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,28 @@ data class A2uiCatalog(
7272
}
7373

7474
/**
75-
* Returns a new catalog with only allowed components.
75+
* Returns a new catalog with pruned components and messages.
7676
*
7777
* @param allowedComponents List of component names to include.
78-
* @return A copy of the catalog with only allowed components.
78+
* @param allowedMessages List of message names to include in serverToClientSchema.
79+
* @return A copy of the catalog with pruned components and messages.
7980
*/
80-
fun withPrunedComponents(allowedComponents: List<String>): A2uiCatalog {
81-
if (allowedComponents.isEmpty()) return this.withPrunedCommonTypes()
81+
fun withPruning(
82+
allowedComponents: List<String>? = null,
83+
allowedMessages: List<String>? = null,
84+
): A2uiCatalog {
85+
var catalog = this
86+
if (allowedComponents != null) {
87+
catalog = catalog.withPrunedComponentsInternal(allowedComponents)
88+
}
89+
if (allowedMessages != null) {
90+
catalog = catalog.withPrunedMessages(allowedMessages)
91+
}
92+
return catalog.withPrunedCommonTypes()
93+
}
94+
95+
private fun withPrunedComponentsInternal(allowedComponents: List<String>): A2uiCatalog {
96+
if (allowedComponents.isEmpty()) return this
8297

8398
val schemaCopy = catalogSchema.toMutableMap()
8499

@@ -97,68 +112,114 @@ data class A2uiCatalog(
97112
}
98113
}
99114

100-
return copy(catalogSchema = JsonObject(schemaCopy)).withPrunedCommonTypes()
115+
return copy(catalogSchema = JsonObject(schemaCopy))
116+
}
117+
118+
private fun withPrunedMessages(allowedMessages: List<String>): A2uiCatalog {
119+
if (allowedMessages.isEmpty()) return this
120+
121+
val s2cCopy = serverToClientSchema.toMutableMap()
122+
123+
if (version == A2uiVersion.VERSION_0_8) {
124+
(s2cCopy["properties"] as? JsonObject)?.let { props ->
125+
s2cCopy["properties"] =
126+
pruneDefsByReachability(
127+
defs = props,
128+
rootDefNames = allowedMessages,
129+
internalRefPrefix = "#/properties/",
130+
)
131+
}
132+
} else {
133+
(s2cCopy["oneOf"] as? JsonArray)?.let { oneOf ->
134+
val filteredOneOf =
135+
oneOf.filter { item ->
136+
val ref = (item as? JsonObject)?.get("\$ref")?.jsonPrimitive?.content
137+
ref != null && ref.startsWith("#/\$defs/") && ref.split("/").last() in allowedMessages
138+
}
139+
s2cCopy["oneOf"] = JsonArray(filteredOneOf)
140+
}
141+
142+
(s2cCopy["\$defs"] as? JsonObject)?.let { defs ->
143+
s2cCopy["\$defs"] =
144+
pruneDefsByReachability(
145+
defs = defs,
146+
rootDefNames = allowedMessages,
147+
internalRefPrefix = "#/\$defs/",
148+
)
149+
}
150+
}
151+
152+
return copy(serverToClientSchema = JsonObject(s2cCopy))
101153
}
102154

103155
/** Returns a new catalog with unused common types pruned from the schema. */
104156
fun withPrunedCommonTypes(): A2uiCatalog {
105157
val defs = commonTypesSchema["\$defs"] as? JsonObject ?: return this
106158
if (defs.isEmpty()) return this
107159

108-
fun collectRefs(element: JsonElement, refs: MutableSet<String>) {
109-
when (element) {
160+
val externalRefs = mutableSetOf<String>()
161+
collectRefs(catalogSchema, externalRefs)
162+
collectRefs(serverToClientSchema, externalRefs)
163+
164+
val prefix = "common_types.json#/\$defs/"
165+
val rootDefs =
166+
externalRefs.mapNotNull { if (it.startsWith(prefix)) it.substring(prefix.length) else null }
167+
168+
val newDefs = pruneDefsByReachability(defs, rootDefs)
169+
val newCommonTypes =
170+
JsonObject(commonTypesSchema.toMutableMap().apply { put("\$defs", newDefs) })
171+
172+
return copy(commonTypesSchema = newCommonTypes)
173+
}
174+
175+
private fun collectRefs(rootElement: JsonElement, refs: MutableSet<String>) {
176+
val stack = ArrayDeque<JsonElement>()
177+
stack.addLast(rootElement)
178+
179+
while (stack.isNotEmpty()) {
180+
when (val element = stack.removeLast()) {
110181
is JsonObject -> {
111182
for ((k, v) in element) {
112183
if (k == "\$ref" && v is JsonPrimitive && v.isString) {
113184
refs.add(v.content)
114185
} else {
115-
collectRefs(v, refs)
186+
stack.addLast(v)
116187
}
117188
}
118189
}
119190
is JsonArray -> {
120191
for (item in element) {
121-
collectRefs(item, refs)
192+
stack.addLast(item)
122193
}
123194
}
124195
else -> {}
125196
}
126197
}
198+
}
127199

200+
private fun pruneDefsByReachability(
201+
defs: JsonObject,
202+
rootDefNames: List<String>,
203+
internalRefPrefix: String = "#/\$defs/",
204+
): JsonObject {
128205
val visitedDefs = mutableSetOf<String>()
129-
val queue = ArrayDeque<String>()
130-
131-
val externalRefs = mutableSetOf<String>()
132-
collectRefs(catalogSchema, externalRefs)
133-
collectRefs(serverToClientSchema, externalRefs)
134-
135-
val prefix = "common_types.json#/\$defs/"
136-
for (ref in externalRefs) {
137-
if (ref.startsWith(prefix)) {
138-
queue.add(ref.substring(prefix.length))
139-
}
140-
}
206+
val queue = ArrayDeque(rootDefNames)
141207

142208
while (queue.isNotEmpty()) {
143209
val defName = queue.removeFirst()
144210
if (defs.containsKey(defName) && visitedDefs.add(defName)) {
145211
val defElement = defs[defName]!!
146212
val internalRefs = mutableSetOf<String>()
147213
collectRefs(defElement, internalRefs)
148-
val internalPrefix = "#/\$defs/"
149214
for (ref in internalRefs) {
150-
if (ref.startsWith(internalPrefix)) {
151-
queue.add(ref.substring(internalPrefix.length))
215+
if (ref.startsWith(internalRefPrefix)) {
216+
queue.add(ref.substring(internalRefPrefix.length))
152217
}
153218
}
154219
}
155220
}
156221

157-
val newDefs = JsonObject(defs.filterKeys { it in visitedDefs })
158-
val newCommonTypes =
159-
JsonObject(commonTypesSchema.toMutableMap().apply { put("\$defs", newDefs) })
160-
161-
return copy(commonTypesSchema = newCommonTypes)
222+
return JsonObject(defs.filterKeys { it in visitedDefs })
162223
}
163224

164225
private fun pruneAnyComponentOneOf(

agent_sdks/kotlin/src/test/kotlin/com/google/a2ui/conformance/ConformanceTest.kt

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,16 @@ class ConformanceTest {
119119
val version =
120120
if (versionStr == VERSION_0_8_STR) A2uiVersion.VERSION_0_8 else A2uiVersion.VERSION_0_9
121121

122-
val s2cSchemaFile = catalogMap["s2c_schema"] as? String
122+
val s2cSchemaObj = catalogMap["s2c_schema"]
123123
val s2cSchema =
124-
s2cSchemaFile?.let { loadJsonFile(File(conformanceDir, it)) } ?: JsonObject(emptyMap())
124+
if (s2cSchemaObj is String) {
125+
loadJsonFile(File(conformanceDir, s2cSchemaObj))
126+
} else if (s2cSchemaObj is Map<*, *>) {
127+
val jsonStr = jsonMapper.writeValueAsString(s2cSchemaObj)
128+
Json.parseToJsonElement(jsonStr) as JsonObject
129+
} else {
130+
JsonObject(emptyMap())
131+
}
125132

126133
val catalogSchemaObj = catalogMap["catalog_schema"]
127134
val schemaMappings = HashMap(baseSchemaMappings)
@@ -152,9 +159,16 @@ class ConformanceTest {
152159
)
153160
}
154161

155-
val commonTypesFile = catalogMap["common_types_schema"] as? String
162+
val commonTypesObj = catalogMap["common_types_schema"]
156163
val commonTypesSchema =
157-
commonTypesFile?.let { loadJsonFile(File(conformanceDir, it)) } ?: JsonObject(emptyMap())
164+
if (commonTypesObj is String) {
165+
loadJsonFile(File(conformanceDir, commonTypesObj))
166+
} else if (commonTypesObj is Map<*, *>) {
167+
val jsonStr = jsonMapper.writeValueAsString(commonTypesObj)
168+
Json.parseToJsonElement(jsonStr) as JsonObject
169+
} else {
170+
JsonObject(emptyMap())
171+
}
158172

159173
val catalog =
160174
A2uiCatalog(
@@ -227,12 +241,6 @@ class ConformanceTest {
227241
val args = case[ConformanceTestHelper.KEY_ARGS] as? Map<*, *> ?: emptyMap<Any, Any>()
228242

229243
// Filter out non-conformant tests for Kotlin
230-
if (
231-
action == "prune" && (args.containsKey("allowed_messages") || name.contains("common_types"))
232-
) {
233-
println("Skipping non-conformant test (prune messages/common_types): $name")
234-
return@mapNotNull null
235-
}
236244
if (
237245
action == "load" &&
238246
(args[KEY_PATH] as? String)?.let {
@@ -257,13 +265,22 @@ class ConformanceTest {
257265

258266
when (action) {
259267
"prune" -> {
260-
val allowedComponents = args[KEY_ALLOWED_COMPONENTS] as? List<String> ?: emptyList()
261-
val pruned = catalog!!.withPrunedComponents(allowedComponents)
268+
val allowedComponents = args[KEY_ALLOWED_COMPONENTS] as? List<String>
269+
val allowedMessages = args["allowed_messages"] as? List<String>
270+
val pruned = catalog!!.withPruning(allowedComponents, allowedMessages)
262271
val expect = case[ConformanceTestHelper.KEY_EXPECT] as Map<*, *>
263272
if (expect.containsKey(KEY_CATALOG_SCHEMA)) {
264273
val expectSchema = jsonMapper.writeValueAsString(expect[KEY_CATALOG_SCHEMA])
265274
assertEquals(Json.parseToJsonElement(expectSchema), pruned.catalogSchema)
266275
}
276+
if (expect.containsKey("s2c_schema")) {
277+
val expectSchema = jsonMapper.writeValueAsString(expect["s2c_schema"])
278+
assertEquals(Json.parseToJsonElement(expectSchema), pruned.serverToClientSchema)
279+
}
280+
if (expect.containsKey("common_types_schema")) {
281+
val expectSchema = jsonMapper.writeValueAsString(expect["common_types_schema"])
282+
assertEquals(Json.parseToJsonElement(expectSchema), pruned.commonTypesSchema)
283+
}
267284
}
268285
"load" -> {
269286
val path = args[KEY_PATH] as? String

agent_sdks/kotlin/src/test/kotlin/com/google/a2ui/core/parser/StreamingParserTest.kt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package com.google.a2ui.core.parser
1919
import com.google.a2ui.core.schema.A2uiConstants
2020
import kotlin.test.Test
2121
import kotlin.test.assertEquals
22+
import kotlin.test.assertFailsWith
2223
import kotlin.test.assertNotNull
2324
import kotlin.test.assertTrue
2425
import kotlinx.serialization.json.JsonArray
@@ -159,4 +160,14 @@ class StreamingParserTest {
159160

160161
assertEquals("/absolute/path", pathStr)
161162
}
163+
164+
@Test
165+
fun throwsExceptionWhenJsonBufferExceedsMaxSizeLimit() {
166+
val parser = StreamingParser.create(null)
167+
parser.processChunk(A2uiConstants.A2UI_OPEN_TAG)
168+
val hugeChunk = String(CharArray(5 * 1024 * 1024 + 1) { ' ' })
169+
assertFailsWith<IllegalArgumentException> {
170+
parser.processChunk(hugeChunk)
171+
}
172+
}
162173
}

0 commit comments

Comments
 (0)