Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ class GraphQL(config: GraphQLConfiguration) {
subscriptionHooks = config.server.subscriptions.hooks,
requestHandler = requestHandler,
initTimeoutMillis = config.server.subscriptions.connectionInitTimeout,
objectMapper = jacksonMapperBuilder().apply(config.server.jacksonConfiguration).build()
objectMapper = jacksonMapperBuilder().apply(config.server.jacksonConfiguration).build(),
subscriptionConcurrency = config.server.subscriptions.subscriptionConcurrency
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import com.expediagroup.graphql.generator.hooks.FlowSubscriptionSchemaGeneratorH
import com.expediagroup.graphql.generator.hooks.SchemaGeneratorHooks
import com.expediagroup.graphql.generator.scalars.IDValueUnboxer
import com.expediagroup.graphql.server.Schema
import com.expediagroup.graphql.server.execution.subscription.DEFAULT_WS_SUBSCRIPTION_CONCURRENCY
import com.expediagroup.graphql.server.ktor.subscriptions.DefaultKtorGraphQLSubscriptionContextFactory
import com.expediagroup.graphql.server.ktor.subscriptions.DefaultKtorGraphQLSubscriptionHooks
import com.expediagroup.graphql.server.ktor.subscriptions.DefaultKtorGraphQLSubscriptionRequestParser
Expand Down Expand Up @@ -284,6 +285,14 @@ class GraphQLConfiguration(config: ApplicationConfig) {
var hooks: KtorGraphQLSubscriptionHooks = DefaultKtorGraphQLSubscriptionHooks()
/** Server timeout between establishing web socket connection and receiving connection-init message */
var connectionInitTimeout: Long = config.tryGetString("graphql.server.subscription.connectionInitTimeout")?.toLongOrNull() ?: 60_000
/**
* Maximum number of inbound client messages processed concurrently per web socket session. Defaults to
* [DEFAULT_WS_SUBSCRIPTION_CONCURRENCY] (16). Raise this when a single session may hold more than the default
* number of simultaneous subscriptions, otherwise additional messages (including ping/complete/subscribe)
* are back-pressured until one of the in-flight messages completes.
*/
var subscriptionConcurrency: Int =
config.tryGetString("graphql.server.subscription.concurrency")?.toIntOrNull() ?: DEFAULT_WS_SUBSCRIPTION_CONCURRENCY
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.expediagroup.graphql.server.ktor.subscriptions

import com.expediagroup.graphql.server.execution.GraphQLRequestHandler
import com.expediagroup.graphql.server.execution.subscription.DEFAULT_WS_SUBSCRIPTION_CONCURRENCY
import com.expediagroup.graphql.server.execution.subscription.GraphQLWebSocketServer
import com.expediagroup.graphql.server.types.GraphQLSubscriptionStatus
import io.ktor.server.websocket.WebSocketServerSession
Expand All @@ -36,9 +37,10 @@ class KtorGraphQLWebSocketServer(
subscriptionHooks: KtorGraphQLSubscriptionHooks,
requestHandler: GraphQLRequestHandler,
initTimeoutMillis: Long,
objectMapper: ObjectMapper
objectMapper: ObjectMapper,
subscriptionConcurrency: Int = DEFAULT_WS_SUBSCRIPTION_CONCURRENCY
) : GraphQLWebSocketServer<WebSocketServerSession, Unit>(
requestParser, contextFactory, subscriptionHooks, requestHandler, initTimeoutMillis, objectMapper
requestParser, contextFactory, subscriptionHooks, requestHandler, initTimeoutMillis, objectMapper, subscriptionConcurrency
) {
override suspend fun closeSession(session: WebSocketServerSession, reason: GraphQLSubscriptionStatus) {
session.close(CloseReason(reason.code.toShort(), reason.reason))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,21 @@ import kotlin.coroutines.EmptyCoroutineContext
const val GRAPHQL_WS_PROTOCOL = "graphql-transport-ws"

/**
* GraphQL Web Socket server implementation for handling subscriptions using *graphql-transport-ws* protocol
* Default maximum number of in-flight inbound messages processed concurrently by a single web socket session.
*
* Matches the historical default of `kotlinx.coroutines.flow.flatMapMerge` (16) that was used before this value
* was made configurable. Raising the limit lets a single session process more simultaneous subscriptions at the
* cost of higher peak memory; `Int.MAX_VALUE` effectively removes the limit; `1` serializes message processing.
*/
const val DEFAULT_WS_SUBSCRIPTION_CONCURRENCY: Int = 16

/**
* GraphQL Web Socket server implementation for handling subscriptions using *graphql-transport-ws* protocol.
*
* @param subscriptionConcurrency maximum number of inbound client messages processed concurrently by the
* `flatMapMerge` operator that drives [handleSubscription]. At the default ([DEFAULT_WS_SUBSCRIPTION_CONCURRENCY])
* in-flight subscriptions can back-pressure sibling protocol messages (ping, complete, new subscribe) once the
* ceiling is hit; see issue #2018. Pass a larger value, or `Int.MAX_VALUE`, to avoid this.
*
* @see <a href="https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md">graphql-transport-ws protocol</a>
*/
Expand All @@ -67,7 +81,8 @@ abstract class GraphQLWebSocketServer<Session, Message>(
private val subscriptionHooks: GraphQLSubscriptionHooks<Session>,
private val requestHandler: GraphQLRequestHandler,
private val initTimeoutMillis: Long,
private val objectMapper: ObjectMapper = jacksonObjectMapper()
private val objectMapper: ObjectMapper = jacksonObjectMapper(),
private val subscriptionConcurrency: Int = DEFAULT_WS_SUBSCRIPTION_CONCURRENCY
) {
private val logger: Logger = LoggerFactory.getLogger(GraphQLWebSocketServer::class.java)
private val subscriptionScope = CoroutineScope(SupervisorJob())
Expand All @@ -86,7 +101,7 @@ abstract class GraphQLWebSocketServer<Session, Message>(

requestParser.parseRequestFlow(session)
.map { objectMapper.readValue<GraphQLSubscriptionMessage>(it) }
.flatMapMerge { message ->
.flatMapMerge(concurrency = subscriptionConcurrency) { message ->
channelFlow {
when (message) {
is SubscriptionMessageConnectionInit -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,60 @@ class GraphQLWebSocketServerTest {
subscriptionJob.cancelAndJoin()
}

@Test
fun `verify subscription flow honors configured concurrency`() = runTest {
// concurrency=1 serializes inbound message processing: each subscribe's channelFlow must fully complete
// (counter 1,2,3 + complete) before the next subscribe starts flowing. With the historical default (16)
// the two subscriptions would interleave. See issue #2018.
val handler = GraphQLRequestHandler(graphQL = testGraphQLEngine())
val testServer = InMemoryGraphQLSubscriptionServer(
requestHandler = handler,
subscriptionConcurrency = 1
)

val session = Channel<String>(Channel.BUFFERED)
val responseChannel = testServer.outboundChannel

val subscriptionJob = launch {
testServer.handleSubscription(session)
.collect()
}

session.send(mapper.writeValueAsString(SubscriptionMessageConnectionInit()))
val ack: GraphQLSubscriptionMessage = mapper.readValue(responseChannel.receive())
assertEquals(GRAPHQL_WS_CONNECTION_ACK, ack.type)

val firstId = UUID.randomUUID().toString()
val secondId = UUID.randomUUID().toString()
val request = GraphQLRequest(query = "subscription { counter }")
session.send(mapper.writeValueAsString(SubscriptionMessageSubscribe(id = firstId, payload = request)))
session.send(mapper.writeValueAsString(SubscriptionMessageSubscribe(id = secondId, payload = request)))

// With concurrency=1 the second subscribe is held behind the first, so every response for firstId
// (3 next + 1 complete) must arrive before any response for secondId.
val firstResponseIds = (1..4).map {
val msg: GraphQLSubscriptionMessage = mapper.readValue(responseChannel.receive())
when (msg) {
is SubscriptionMessageNext -> msg.id
is SubscriptionMessageComplete -> msg.id
else -> error("unexpected message type: $msg")
}
}
assertTrue(firstResponseIds.all { it == firstId }, "expected all first-batch ids == $firstId but got $firstResponseIds")

val secondResponseIds = (1..4).map {
val msg: GraphQLSubscriptionMessage = mapper.readValue(responseChannel.receive())
when (msg) {
is SubscriptionMessageNext -> msg.id
is SubscriptionMessageComplete -> msg.id
else -> error("unexpected message type: $msg")
}
}
assertTrue(secondResponseIds.all { it == secondId }, "expected all second-batch ids == $secondId but got $secondResponseIds")

subscriptionJob.cancelAndJoin()
}

private fun testGraphQLEngine(): GraphQL = GraphQL.newGraphQL(
toSchema(
config = SchemaGeneratorConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,15 @@ class InMemoryGraphQLSubscriptionServer(
requestParser: InMemorySubscriptionRequestParser = InMemorySubscriptionRequestParser(),
contextFactory: InMemorySubscriptionContextFactory = InMemorySubscriptionContextFactory(),
hooks: InMemorySubscriptionHooks = InMemorySubscriptionHooks(),
timeoutInMillis: Long = 1000
timeoutInMillis: Long = 1000,
subscriptionConcurrency: Int = DEFAULT_WS_SUBSCRIPTION_CONCURRENCY
) : GraphQLWebSocketServer<Channel<String>, String>(
requestParser, contextFactory, hooks, requestHandler, timeoutInMillis
requestParser = requestParser,
contextFactory = contextFactory,
subscriptionHooks = hooks,
requestHandler = requestHandler,
initTimeoutMillis = timeoutInMillis,
subscriptionConcurrency = subscriptionConcurrency
) {
val outboundChannel = Channel<String>(Channel.BUFFERED)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.expediagroup.graphql.server.spring

import com.expediagroup.graphql.server.execution.subscription.DEFAULT_WS_SUBSCRIPTION_CONCURRENCY
import org.springframework.boot.context.properties.ConfigurationProperties
import org.springframework.boot.context.properties.NestedConfigurationProperty

Expand Down Expand Up @@ -91,7 +92,14 @@ data class GraphQLConfigurationProperties(
/** Server timeout between establishing web socket connection and receiving connection-init message. */
val connectionInitTimeout: Long = 60_000,
/** WebSocket based subscription protocol */
val protocol: SubscriptionProtocol = SubscriptionProtocol.GRAPHQL_WS
val protocol: SubscriptionProtocol = SubscriptionProtocol.GRAPHQL_WS,
/**
* Maximum number of inbound client messages processed concurrently per web socket session. Defaults to
* [DEFAULT_WS_SUBSCRIPTION_CONCURRENCY] (16). Raise this when a single session may hold more than the
* default number of simultaneous subscriptions, otherwise additional messages (including ping/complete/
* subscribe) are back-pressured until one of the in-flight messages completes.
*/
val subscriptionConcurrency: Int = DEFAULT_WS_SUBSCRIPTION_CONCURRENCY
)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class SubscriptionGraphQLWsAutoConfiguration {
subscriptionHooks,
handler,
config.subscriptions.connectionInitTimeout,
objectMapper
objectMapper,
config.subscriptions.subscriptionConcurrency
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.expediagroup.graphql.server.spring.subscriptions

import com.expediagroup.graphql.server.execution.GraphQLRequestHandler
import com.expediagroup.graphql.server.execution.subscription.DEFAULT_WS_SUBSCRIPTION_CONCURRENCY
import com.expediagroup.graphql.server.execution.subscription.GRAPHQL_WS_PROTOCOL
import com.expediagroup.graphql.server.execution.subscription.GraphQLWebSocketServer
import com.expediagroup.graphql.server.types.GraphQLSubscriptionStatus
Expand All @@ -40,9 +41,10 @@ class SubscriptionWebSocketHandler(
subscriptionHooks: SpringGraphQLSubscriptionHooks,
graphqlHandler: GraphQLRequestHandler,
initTimeoutMillis: Long,
objectMapper: ObjectMapper
objectMapper: ObjectMapper,
subscriptionConcurrency: Int = DEFAULT_WS_SUBSCRIPTION_CONCURRENCY
) : WebSocketHandler, GraphQLWebSocketServer<WebSocketSession, WebSocketMessage>(
requestParser, contextFactory, subscriptionHooks, graphqlHandler, initTimeoutMillis, objectMapper
requestParser, contextFactory, subscriptionHooks, graphqlHandler, initTimeoutMillis, objectMapper, subscriptionConcurrency
) {
override fun handle(session: WebSocketSession): Mono<Void> = session.send(
flux {
Expand Down
Loading