Skip to content

Commit f2c22af

Browse files
authored
Add an encoded discriminator value attribute for coproducts, use it to render const constraints (#3955)
1 parent 4822f59 commit f2c22af

13 files changed

Lines changed: 209 additions & 21 deletions

File tree

core/src/main/scala/sttp/tapir/Schema.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,26 @@ object Schema extends LowPrioritySchema with SchemaCompanionMacros {
361361
val Attribute: AttributeKey[Tuple] = new AttributeKey[Tuple]("sttp.tapir.Schema.Tuple")
362362
}
363363

364+
/** For coproduct schemas, when there's a discriminator field, used to attach the encoded value of the discriminator field. Such value is
365+
* added to the discriminator field schemas in each of the coproduct's subtypes. When rendering OpenAPI/JSON schema, these values are
366+
* converted to `const` constraints on fields.
367+
*/
368+
case class EncodedDiscriminatorValue(v: String)
369+
object EncodedDiscriminatorValue {
370+
/*
371+
Implementation note: the discriminator value constraint is in fact an enum validator with a single possible enum value. Hence an
372+
alternative design would be to add such validators to discriminator fields, instead of an attribute. However, this has two drawbacks:
373+
1. when adding discriminator fields using `addDiscriminatorField`, we don't have access to the decoded discriminator value - only
374+
to the encoded one, via reverse mapping lookup
375+
2. the validator doesn't necessarily make sense, as it can't be used to validate the deserialiszd object. Usually the discriminator
376+
fields don't even exist on the high-level representations.
377+
That's why instead of re-using the validators, we decided to use a specialised attribute.
378+
*/
379+
380+
val Attribute: AttributeKey[EncodedDiscriminatorValue] =
381+
new AttributeKey[EncodedDiscriminatorValue]("sttp.tapir.Schema.EncodedDiscriminatorValue")
382+
}
383+
364384
/** @param typeParameterShortNames
365385
* full name of type parameters, name is legacy and kept only for backward compatibility
366386
*/

core/src/main/scala/sttp/tapir/SchemaType.scala

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,39 @@ object SchemaType {
129129
discriminatorSchema: Schema[D] = Schema.string,
130130
discriminatorMapping: Map[String, SRef[_]] = Map.empty
131131
): SCoproduct[T] = {
132+
// used to add encoded discriminator value attributes
133+
val reverseDiscriminatorByNameMapping: Map[SName, String] = discriminatorMapping.toList.map { case (v, ref) => (ref.name, v) }.toMap
134+
132135
SCoproduct(
133136
subtypes.map {
134-
case s @ Schema(st: SchemaType.SProduct[Any @unchecked], _, _, _, _, _, _, _, _, _, _)
135-
if st.fields.forall(_.name != discriminatorName) =>
136-
s.copy(schemaType = st.copy(fields = st.fields :+ SProductField[Any, D](discriminatorName, discriminatorSchema, _ => None)))
137+
case s @ Schema(st: SchemaType.SProduct[Any @unchecked], _, _, _, _, _, _, _, _, _, _) =>
138+
// first, ensuring that the discriminator field is added to the schema type - it might already be present
139+
var targetSt =
140+
if (st.fields.forall(_.name != discriminatorName))
141+
st.copy(fields = st.fields :+ SProductField[Any, D](discriminatorName, discriminatorSchema, _ => None))
142+
else st
143+
144+
// next, modifying the discriminator field, by adding the value attribute (if a value can be found)
145+
targetSt = targetSt.copy(fields = targetSt.fields.map { field =>
146+
if (field.name == discriminatorName) {
147+
val discriminatorValue = s.name.flatMap { subtypeName =>
148+
reverseDiscriminatorByNameMapping.get(subtypeName)
149+
}
150+
151+
discriminatorValue match {
152+
case Some(v) =>
153+
SProductField(
154+
field.name,
155+
field.schema.attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue(v)),
156+
field.get
157+
)
158+
case None => field
159+
}
160+
161+
} else field
162+
})
163+
164+
s.copy(schemaType = targetSt)
137165
case s => s
138166
},
139167
Some(SDiscriminator(discriminatorName, discriminatorMapping))

core/src/test/scala/sttp/tapir/SchemaMacroTest.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,14 @@ class SchemaMacroTest extends AnyFlatSpec with Matchers with TableDrivenProperty
304304

305305
schemaType.subtypes.foreach { childSchema =>
306306
val childProduct = childSchema.schemaType.asInstanceOf[SProduct[_]]
307-
childProduct.fields.find(_.name.name == "kind") shouldBe Some(SProductField(FieldName("kind"), Schema.string, (_: Any) => None))
307+
val discValue = if (childSchema.name.get.fullName == "sttp.tapir.SchemaMacroTestData.User") "user" else "org"
308+
childProduct.fields.find(_.name.name == "kind") shouldBe Some(
309+
SProductField(
310+
FieldName("kind"),
311+
Schema.string.attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue(discValue)),
312+
(_: Any) => None
313+
)
314+
)
308315
}
309316
}
310317

core/src/test/scala/sttp/tapir/generic/SchemaGenericAutoTest.scala

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,13 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers {
245245
schemaType.asInstanceOf[SCoproduct[Entity]].subtypes should contain theSameElementsAs List(
246246
Schema(
247247
SProduct[Organization](
248-
List(field(FieldName("name"), Schema(SString())), field(FieldName("who_am_i"), Schema(SString())))
248+
List(
249+
field(FieldName("name"), Schema(SString())),
250+
field(
251+
FieldName("who_am_i"),
252+
Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("Organization"))
253+
)
254+
)
249255
),
250256
Some(SName("sttp.tapir.generic.Organization"))
251257
),
@@ -254,15 +260,21 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers {
254260
List(
255261
field(FieldName("first"), Schema(SString())),
256262
field(FieldName("age"), Schema(SInteger(), format = Some("int32"))),
257-
field(FieldName("who_am_i"), Schema(SString()))
263+
field(
264+
FieldName("who_am_i"),
265+
Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("Person"))
266+
)
258267
)
259268
),
260269
Some(SName("sttp.tapir.generic.Person"))
261270
),
262271
Schema(
263272
SProduct[UnknownEntity.type](
264273
List(
265-
field(FieldName("who_am_i"), Schema(SString()))
274+
field(
275+
FieldName("who_am_i"),
276+
Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("UnknownEntity"))
277+
)
266278
)
267279
),
268280
Some(SName("sttp.tapir.generic.UnknownEntity"))

docs/apispec-docs/src/main/scala/sttp/tapir/docs/apispec/schema/TSchemaToASchema.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,13 @@ private[docs] class TSchemaToASchema(
112112
// The primary motivation for using schema name as fallback title is to improve Swagger UX with
113113
// `oneOf` schemas in OpenAPI 3.1. See https://github.com/softwaremill/tapir/issues/3447 for details.
114114
def fallbackTitle = tschema.name.map(fallbackSchemaTitle)
115+
116+
val const = tschema.attribute(TSchema.EncodedDiscriminatorValue.Attribute).map(_.v).map(v => ExampleSingleValue(v))
117+
115118
oschema
116-
.copy(title = titleFromAttr orElse fallbackTitle)
119+
.copy(title = titleFromAttr.orElse(fallbackTitle))
117120
.copy(uniqueItems = tschema.attribute(UniqueItems.Attribute).map(_.uniqueItems))
121+
.copy(const = const)
118122
}
119123

120124
private def addMetadata(oschema: ASchema, tschema: TSchema[_]): ASchema = {
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
asyncapi: 2.6.0
2+
info:
3+
title: discriminator
4+
version: '1.0'
5+
channels:
6+
/animals:
7+
subscribe:
8+
operationId: onAnimals
9+
message:
10+
$ref: '#/components/messages/Animal'
11+
publish:
12+
operationId: sendAnimals
13+
message:
14+
$ref: '#/components/messages/GetAnimal'
15+
bindings:
16+
ws:
17+
method: GET
18+
components:
19+
schemas:
20+
GetAnimal:
21+
title: GetAnimal
22+
type: object
23+
required:
24+
- name
25+
properties:
26+
name:
27+
type: string
28+
Animal:
29+
title: Animal
30+
oneOf:
31+
- $ref: '#/components/schemas/Cat'
32+
- $ref: '#/components/schemas/Dog'
33+
discriminator: pet
34+
Cat:
35+
title: Cat
36+
type: object
37+
required:
38+
- name
39+
- pet
40+
properties:
41+
name:
42+
type: string
43+
pet:
44+
type: string
45+
const: Cat
46+
Dog:
47+
title: Dog
48+
type: object
49+
required:
50+
- name
51+
- breed
52+
- pet
53+
properties:
54+
name:
55+
type: string
56+
breed:
57+
type: string
58+
pet:
59+
type: string
60+
const: Dog
61+
messages:
62+
GetAnimal:
63+
payload:
64+
$ref: '#/components/schemas/GetAnimal'
65+
contentType: application/json
66+
Animal:
67+
payload:
68+
$ref: '#/components/schemas/Animal'
69+
contentType: application/json

docs/asyncapi-docs/src/test/scala/sttp/tapir/docs/asyncapi/VerifyAsyncAPIYamlTest.scala

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,11 @@ class VerifyAsyncAPIYamlTest extends AnyFunSuite with Matchers {
133133
.out(
134134
webSocketBody[Fruit, CodecFormat.Json, Int, CodecFormat.Json](AkkaStreams)
135135
// TODO: missing `RequestInfo.example(example: EndpointIO.Example)` and friends
136-
.pipe(e => e.copy(requestsInfo = e.requestsInfo.example(Example.of(Fruit("apple")).name("Apple").summary("Sample representation of apple"))))
136+
.pipe(e =>
137+
e.copy(requestsInfo =
138+
e.requestsInfo.example(Example.of(Fruit("apple")).name("Apple").summary("Sample representation of apple"))
139+
)
140+
)
137141
)
138142

139143
val expectedYaml = loadYaml("expected_json_example_name_summary.yml")
@@ -232,6 +236,22 @@ class VerifyAsyncAPIYamlTest extends AnyFunSuite with Matchers {
232236
noIndentation(yaml) shouldBe loadYaml("expected_flags_header.yml")
233237
}
234238

239+
test("should work with discriminators") {
240+
case class GetAnimal(name: String)
241+
sealed trait Animal
242+
case class Cat(name: String) extends Animal
243+
case class Dog(name: String, breed: String) extends Animal
244+
implicit val configuration: sttp.tapir.generic.Configuration = sttp.tapir.generic.Configuration.default.withDiscriminator("pet")
245+
246+
val animalEndpoint = endpoint.get
247+
.in("animals")
248+
.out(webSocketBody[GetAnimal, CodecFormat.Json, Animal, CodecFormat.Json](AkkaStreams))
249+
250+
val yaml = AsyncAPIInterpreter().toAsyncAPI(animalEndpoint, "discriminator", "1.0").toYaml
251+
252+
noIndentation(yaml) shouldBe loadYaml("expected_coproduct_with_discriminator.yml")
253+
}
254+
235255
private def loadYaml(fileName: String): String = {
236256
noIndentation(Source.fromInputStream(getClass.getResourceAsStream(s"/$fileName")).getLines().mkString("\n"))
237257
}

docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_discriminator.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ components:
3333
properties:
3434
name:
3535
type: string
36+
const: sml
3637
Person:
3738
title: Person
3839
type: object
@@ -42,6 +43,7 @@ components:
4243
properties:
4344
name:
4445
type: string
46+
const: john
4547
age:
4648
type: integer
4749
format: int32

docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_discriminator_nested.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ components:
4141
properties:
4242
name:
4343
type: string
44+
const: sml
4445
Person:
4546
title: Person
4647
type: object
@@ -50,6 +51,7 @@ components:
5051
properties:
5152
name:
5253
type: string
54+
const: john
5355
age:
5456
type: integer
5557
format: int32

docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_discriminator_with_enum_circe.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@ components:
3737
- red
3838
shapeType:
3939
type: string
40+
const: Square

0 commit comments

Comments
 (0)