Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/src/main/kotlin/my/restate/sdk/examples/CounterKt.kt
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class CounterKt {

@Handler
@Shared
suspend fun get(ctx: SharedObjectContext): Long {
return ctx.get(TOTAL) ?: 0L
suspend fun get(ctx: SharedObjectContext): Long? {
return ctx.get(TOTAL)
}

@Handler
Expand Down
19 changes: 19 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,24 @@
[libraries.victools-jsonschema-module-jackson.version]
ref = 'victools-json-schema'

[libraries.schema-kenerator-core]
module = 'io.github.smiley4:schema-kenerator-core'

[libraries.schema-kenerator-core.version]
ref = 'schema-kenerator'

[libraries.schema-kenerator-serialization]
module = 'io.github.smiley4:schema-kenerator-serialization'

[libraries.schema-kenerator-serialization.version]
ref = 'schema-kenerator'

[libraries.schema-kenerator-jsonschema]
module = 'io.github.smiley4:schema-kenerator-jsonschema'

[libraries.schema-kenerator-jsonschema.version]
ref = 'schema-kenerator'

[plugins]
aggregate-javadoc = 'io.freefair.aggregate-javadoc:8.6'
dependency-license-report = 'com.github.jk1.dependency-license-report:2.0'
Expand Down Expand Up @@ -213,3 +231,4 @@
spring-boot = '3.4.4'
vertx = '4.5.11'
victools-json-schema = '4.37.0'
schema-kenerator = '2.1.2'
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,21 @@
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.node.ObjectNode;
import dev.restate.serde.Serde;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

class JacksonSerdesTest {

record Recursive(String value, Recursive rec) {}

public static class Person {

private final String name;
Expand Down Expand Up @@ -75,4 +79,11 @@ private static Stream<Arguments> roundtripTestCases() {
<T> void roundtrip(T value, Serde<T> serde) {
assertThat(serde.deserialize(serde.serialize(value))).isEqualTo(value);
}

@Test
void schemaGenWorksWithRecursion() {
ObjectNode node =
(ObjectNode) ((Serde.JsonSchema) JacksonSerdes.of(Recursive.class).jsonSchema()).schema();
assertThat(node.at("/properties/rec/$ref").textValue()).isEqualTo("#");
}
}
6 changes: 6 additions & 0 deletions sdk-serde-kotlinx/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ description = "Restate SDK Kotlinx Serialization integration"
dependencies {
api(libs.kotlinx.serialization.json)
implementation(libs.kotlinx.serialization.core)
implementation(libs.schema.kenerator.core)
implementation(libs.schema.kenerator.serialization)
implementation(libs.schema.kenerator.jsonschema)

implementation(project(":common"))

testImplementation(libs.junit.jupiter)
testImplementation(libs.assertj)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
//
// This file is part of the Restate Java SDK,
// which is released under the MIT license.
//
// You can find a copy of the license in file LICENSE in the root
// directory of this repository or package, or at
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.serde.kotlinx

import dev.restate.serde.Serde
import io.github.smiley4.schemakenerator.jsonschema.JsonSchemaSteps
import io.github.smiley4.schemakenerator.jsonschema.JsonSchemaSteps.compileReferencing
import io.github.smiley4.schemakenerator.jsonschema.JsonSchemaSteps.generateJsonSchema
import io.github.smiley4.schemakenerator.jsonschema.TitleBuilder
import io.github.smiley4.schemakenerator.jsonschema.data.IntermediateJsonSchemaData
import io.github.smiley4.schemakenerator.jsonschema.data.RefType
import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonArray
import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonNode
import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonObject
import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonTextValue
import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.array
import io.github.smiley4.schemakenerator.serialization.SerializationSteps.analyzeTypeUsingKotlinxSerialization
import io.github.smiley4.schemakenerator.serialization.SerializationSteps.initial
import io.github.smiley4.schemakenerator.serialization.SerializationSteps.renameMembers
import kotlin.collections.set
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.json.Json

object DefaultJsonSchemaFactory : KotlinSerializationSerdeFactory.JsonSchemaFactory {
@OptIn(ExperimentalSerializationApi::class)
override fun generateSchema(json: Json, serializer: KSerializer<*>) =
Serde.StringifiedJsonSchema(
runCatching {
var initialStep =
initial(serializer.descriptor).analyzeTypeUsingKotlinxSerialization {
serializersModule = json.serializersModule
}

if (json.configuration.namingStrategy != null) {
initialStep = initialStep.renameMembers(json.configuration.namingStrategy!!)
}

val intermediateStep =
initialStep.generateJsonSchema {
optionalHandling = JsonSchemaSteps.OptionalHandling.NON_REQUIRED
}
intermediateStep.writeTitles()
val compiledSchema = intermediateStep.compileReferencing(RefType.SIMPLE)

// In case of nested schemas, compileReferencing also contains self schema...
val rootSchemaName =
TitleBuilder.BUILDER_SIMPLE(
compiledSchema.typeData, intermediateStep.typeDataById)

// If schema is not json object, then it's boolean, so we're good no need for
// additional manipulation
if (compiledSchema.json !is JsonObject) {
return@runCatching compiledSchema.json
}

// Assemble the final schema now
val rootNode = compiledSchema.json as JsonObject
// Add $schema
rootNode.properties.put(
"\$schema", JsonTextValue("https://json-schema.org/draft/2020-12/schema"))
// Add $defs
val definitions =
compiledSchema.definitions.filter { it.key != rootSchemaName }.toMutableMap()
if (definitions.isNotEmpty()) {
rootNode.properties.put("\$defs", JsonObject(definitions))
}
// Replace all $refs
rootNode.fixRefsPrefix("#/definitions/$rootSchemaName")
// If the root type is nullable, it should be in the schema too
if (serializer.descriptor.isNullable) {
val oldTypeProperty = rootNode.properties["type"]
if (oldTypeProperty is JsonTextValue) {
rootNode.properties["type"] = array {
item(oldTypeProperty.value)
item(JsonTextValue("null"))
}
} else if (oldTypeProperty is JsonArray) {
oldTypeProperty.items.add(JsonTextValue("null"))
}
}

return@runCatching rootNode
}
.getOrDefault(JsonObject(mutableMapOf()))
.prettyPrint())

private fun IntermediateJsonSchemaData.writeTitles() {
this.entries.forEach { schema ->
if (schema.json is JsonObject) {
if ((schema.typeData.isMap ||
schema.typeData.isCollection ||
schema.typeData.isEnum ||
schema.typeData.isInlineValue ||
schema.typeData.typeParameters.isNotEmpty() ||
schema.typeData.members.isNotEmpty()) &&
(schema.json as JsonObject).properties["title"] == null) {
(schema.json as JsonObject).properties["title"] =
JsonTextValue(TitleBuilder.BUILDER_SIMPLE(schema.typeData, this.typeDataById))
}
}
}
}

private fun JsonNode.fixRefsPrefix(rootDefinition: String) {
when (this) {
is JsonArray -> this.items.forEach { it.fixRefsPrefix(rootDefinition) }
is JsonObject -> this.fixRefsPrefix(rootDefinition)
else -> {}
}
}

private fun JsonObject.fixRefsPrefix(rootDefinition: String) {
this.properties.computeIfPresent("\$ref") { key, node ->
if (node is JsonTextValue) {
if (node.value.startsWith(rootDefinition)) {
JsonTextValue("#/" + node.value.removePrefix(rootDefinition))
} else {
JsonTextValue("#/\$defs/" + node.value.removePrefix("#/definitions/"))
}
} else {
node
}
}
this.properties.values.forEach { it.fixRefsPrefix(rootDefinition) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,17 @@ package dev.restate.serde.kotlinx

import dev.restate.common.Slice
import dev.restate.serde.Serde
import dev.restate.serde.Serde.Schema
import dev.restate.serde.SerdeFactory
import dev.restate.serde.TypeRef
import dev.restate.serde.TypeTag
import java.nio.charset.StandardCharsets
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlinx.serialization.*
import kotlinx.serialization.builtins.*
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.encodeToString
import kotlinx.serialization.builtins.nullable
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonTransformingSerializer
import kotlinx.serialization.modules.SerializersModule

/**
Expand All @@ -38,7 +32,22 @@ import kotlinx.serialization.modules.SerializersModule
*/
open class KotlinSerializationSerdeFactory
@JvmOverloads
constructor(private val json: Json = Json.Default) : SerdeFactory {
constructor(
private val json: Json = Json.Default,
private val jsonSchemaFactory: JsonSchemaFactory = DefaultJsonSchemaFactory
) : SerdeFactory {

/** Factory to generate json schemas. */
interface JsonSchemaFactory {
fun generateSchema(json: Json, serializer: KSerializer<*>): Schema?

companion object {
val NOOP =
object : JsonSchemaFactory {
override fun generateSchema(json: Json, serializer: KSerializer<*>): Schema? = null
}
}
}

@PublishedApi
internal class KtTypeTag<T>(
Expand All @@ -61,7 +70,7 @@ constructor(private val json: Json = Json.Default) : SerdeFactory {
}
val serializer: KSerializer<T> =
json.serializersModule.serializer(typeRef.type) as KSerializer<T>
return jsonSerde(json, serializer)
return jsonSerde(json, jsonSchemaFactory, serializer)
}

@Suppress("UNCHECKED_CAST")
Expand All @@ -70,7 +79,7 @@ constructor(private val json: Json = Json.Default) : SerdeFactory {
return UNIT as Serde<T>
}
val serializer: KSerializer<T> = json.serializersModule.serializer(clazz) as KSerializer<T>
return jsonSerde(json, serializer)
return jsonSerde(json, jsonSchemaFactory, serializer)
}

@Suppress("UNCHECKED_CAST")
Expand All @@ -81,7 +90,7 @@ constructor(private val json: Json = Json.Default) : SerdeFactory {
}
val serializer: KSerializer<T> =
json.serializersModule.serializerForKtTypeInfo(ktSerdeInfo) as KSerializer<T>
return jsonSerde(json, serializer)
return jsonSerde(json, jsonSchemaFactory, serializer)
}

companion object {
Expand All @@ -103,7 +112,13 @@ constructor(private val json: Json = Json.Default) : SerdeFactory {
}

/** Creates a [Serde] implementation using the `kotlinx.serialization` json module. */
fun <T : Any?> jsonSerde(json: Json = Json.Default, serializer: KSerializer<T>): Serde<T> {
fun <T : Any?> jsonSerde(
json: Json = Json.Default,
jsonSchemaFactory: JsonSchemaFactory = DefaultJsonSchemaFactory,
serializer: KSerializer<T>
): Serde<T> {
val schema = jsonSchemaFactory.generateSchema(json, serializer)

return object : Serde<T> {
@Suppress("WRONG_NULLABILITY_FOR_JAVA_OVERRIDE")
override fun serialize(value: T?): Slice {
Expand All @@ -123,77 +138,11 @@ constructor(private val json: Json = Json.Default) : SerdeFactory {
return "application/json"
}

override fun jsonSchema(): Serde.Schema {
val schema: JsonSchema = serializer.descriptor.jsonSchema()
return Serde.StringifiedJsonSchema(Json.encodeToString(schema))
override fun jsonSchema(): Schema? {
return schema
}
}
}

@Serializable
@PublishedApi
internal data class JsonSchema(
@Serializable(with = StringListSerializer::class) val type: List<String>? = null,
val format: String? = null,
) {
companion object {
val INT = JsonSchema(type = listOf("number"), format = "int32")

val LONG = JsonSchema(type = listOf("number"), format = "int64")

val DOUBLE = JsonSchema(type = listOf("number"), format = "double")

val FLOAT = JsonSchema(type = listOf("number"), format = "float")

val STRING = JsonSchema(type = listOf("string"))

val BOOLEAN = JsonSchema(type = listOf("boolean"))

val OBJECT = JsonSchema(type = listOf("object"))

val LIST = JsonSchema(type = listOf("array"))

val ANY = JsonSchema()
}
}

object StringListSerializer :
JsonTransformingSerializer<List<String>>(ListSerializer(String.Companion.serializer())) {
override fun transformSerialize(element: JsonElement): JsonElement {
require(element is JsonArray)
return element.singleOrNull() ?: element
}
}

/**
* Super simplistic json schema generation. We should replace this with an appropriate library.
*/
@OptIn(ExperimentalSerializationApi::class)
@PublishedApi
internal fun SerialDescriptor.jsonSchema(): JsonSchema {
var schema =
when (this.kind) {
PrimitiveKind.BOOLEAN -> JsonSchema.BOOLEAN
PrimitiveKind.BYTE -> JsonSchema.INT
PrimitiveKind.CHAR -> JsonSchema.STRING
PrimitiveKind.DOUBLE -> JsonSchema.DOUBLE
PrimitiveKind.FLOAT -> JsonSchema.FLOAT
PrimitiveKind.INT -> JsonSchema.INT
PrimitiveKind.LONG -> JsonSchema.LONG
PrimitiveKind.SHORT -> JsonSchema.INT
PrimitiveKind.STRING -> JsonSchema.STRING
StructureKind.LIST -> JsonSchema.LIST
StructureKind.MAP -> JsonSchema.OBJECT
else -> JsonSchema.ANY
}

// Add nullability constraint
if (this.isNullable && schema.type != null) {
schema = schema.copy(type = schema.type.plus("null"))
}

return schema
}
}

@InternalSerializationApi
Expand Down
Loading