Skip to content

Commit

Permalink
Make use of Ktor websocket extensions and serialization
Browse files Browse the repository at this point in the history
Add Compression.
  • Loading branch information
DRSchlaubi committed Mar 29, 2023
1 parent 2a9be0f commit 7a0adff
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 316 deletions.
30 changes: 17 additions & 13 deletions core/src/main/kotlin/builder/kord/KordBuilderUtil.kt
Original file line number Diff line number Diff line change
@@ -1,41 +1,45 @@
package dev.kord.core.builder.kord

import dev.kord.common.annotation.KordUnsafe
import dev.kord.common.entity.Snowflake
import dev.kord.gateway.WebSocketCompression
import io.ktor.client.*
import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.plugins.websocket.*
import io.ktor.serialization.kotlinx.*
import io.ktor.serialization.kotlinx.json.*
import io.ktor.util.*
import kotlinx.serialization.json.Json

@OptIn(KordUnsafe::class)
internal fun HttpClientConfig<*>.defaultConfig() {
expectSuccess = false

val json = Json {
encodeDefaults = false
allowStructuredMapKeys = true
ignoreUnknownKeys = true
isLenient = true
}
install(ContentNegotiation) {
json()
json(json)
}
install(WebSockets) {
contentConverter = KotlinxWebsocketSerializationConverter(json)
extensions {
install(WebSocketCompression)
}
}
install(WebSockets)
}

internal fun HttpClient?.configure(): HttpClient {
if (this != null) return this.config {
defaultConfig()
}

val json = Json {
encodeDefaults = false
allowStructuredMapKeys = true
ignoreUnknownKeys = true
isLenient = true
}

return HttpClient(CIO) {
defaultConfig()
install(ContentNegotiation) {
json(json)
}
}
}

Expand Down
9 changes: 7 additions & 2 deletions gateway/src/main/kotlin/Command.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,31 @@ import dev.kord.common.serialization.InstantInEpochMillisecondsSerializer
import kotlinx.atomicfu.atomic
import kotlinx.datetime.Instant
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.buildClassSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.SerializationStrategy as KSerializationStrategy

@Serializable(with = Command.SerializationStrategy::class)
public sealed class Command {

public data class Heartbeat(val sequenceNumber: Int?) : Command()

public object SerializationStrategy : KSerializationStrategy<Command> {
public object SerializationStrategy : KSerializer<Command> {

override val descriptor: SerialDescriptor = buildClassSerialDescriptor("Command") {
element("op", OpCode.serializer().descriptor)
element("d", JsonElement.serializer().descriptor)
}

override fun deserialize(decoder: Decoder): Command =
TODO("Deserializing gateway commands is not supported yet")

@OptIn(PrivilegedIntent::class)
override fun serialize(encoder: Encoder, value: Command) {
val composite = encoder.beginStructure(descriptor)
Expand Down
70 changes: 70 additions & 0 deletions gateway/src/main/kotlin/Compression.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package dev.kord.gateway

import dev.kord.common.annotation.KordUnsafe
import io.ktor.util.*
import io.ktor.websocket.*
import java.io.ByteArrayOutputStream
import java.util.zip.Inflater
import java.util.zip.InflaterOutputStream

/**
* [WebSocketExtension] inflating incoming websocket requests using `zlib`.
*
* *Note:** Normally you don't need this and this is configured by Kord automatically, however, if you want to use
* a custom HTTP client, you might need to add this, don't use it if you don't use what you're doing
*/
@KordUnsafe
public class WebSocketCompression : WebSocketExtension<Unit> {
/**
* https://discord.com/developers/docs/topics/gateway#transport-compression
*
* > Every connection to the gateway should use its own unique zlib context.
*
* https://api.ktor.io/ktor-shared/ktor-websockets/io.ktor.websocket/-web-socket-extension/index.html
* > A WebSocket extension instance. This instance is created for each WebSocket request,
* for every installed extension by WebSocketExtensionFactory.
*/
private val inflater = Inflater()

override val factory: WebSocketExtensionFactory<Unit, out WebSocketExtension<Unit>>
get() = Companion
override val protocols: List<WebSocketExtensionHeader>
get() = emptyList()

override fun clientNegotiation(negotiatedProtocols: List<WebSocketExtensionHeader>): Boolean = true

override fun processIncomingFrame(frame: Frame): Frame {
return if (frame is Frame.Binary) {
frame.deflateData()
} else {
frame
}
}

// Discord doesn't support deflating of gateway commands
override fun processOutgoingFrame(frame: Frame): Frame = frame

override fun serverNegotiation(requestedProtocols: List<WebSocketExtensionHeader>): List<WebSocketExtensionHeader> =
requestedProtocols

private fun Frame.deflateData(): Frame {
val outputStream = ByteArrayOutputStream()
InflaterOutputStream(outputStream, inflater).use {
it.write(data)
}

return outputStream.use {
val raw = String(outputStream.toByteArray(), 0, outputStream.size(), Charsets.UTF_8)
Frame.Text(raw)
}
}

public companion object : WebSocketExtensionFactory<Unit, WebSocketCompression> {
override val key: AttributeKey<WebSocketCompression> = AttributeKey("WebSocketCompression")
override val rsv1: Boolean = false
override val rsv2: Boolean = false
override val rsv3: Boolean = false

override fun install(config: Unit.() -> Unit): WebSocketCompression = WebSocketCompression()
}
}
82 changes: 12 additions & 70 deletions gateway/src/main/kotlin/DefaultGateway.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@ import kotlinx.atomicfu.AtomicRef
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.update
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.serialization.json.Json
import mu.KotlinLogging
import java.io.ByteArrayOutputStream
import java.util.zip.Inflater
import java.util.zip.InflaterOutputStream
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.coroutines.CoroutineContext
Expand Down Expand Up @@ -78,13 +75,6 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {

private val handshakeHandler: HandshakeHandler

private lateinit var inflater: Inflater

private val jsonParser = Json {
ignoreUnknownKeys = true
isLenient = true
}

private val stateMutex = Mutex()

init {
Expand Down Expand Up @@ -112,14 +102,9 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
}

defaultGatewayLogger.trace { "opening gateway connection to $gatewayUrl" }
socket = data.client.webSocketSession { url(gatewayUrl) }

/**
* https://discord.com/developers/docs/topics/gateway#transport-compression
*
* > Every connection to the gateway should use its own unique zlib context.
*/
inflater = Inflater()
socket = data.client.webSocketSession {
url(gatewayUrl)
}
} catch (exception: Exception) {
defaultGatewayLogger.error(exception)
if (exception is java.nio.channels.UnresolvedAddressException) {
Expand Down Expand Up @@ -169,41 +154,12 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
}


@OptIn(ExperimentalCoroutinesApi::class)
private suspend fun readSocket() {
socket.incoming.asFlow().buffer(Channel.UNLIMITED).collect {
when (it) {
is Frame.Binary, is Frame.Text -> read(it)
else -> { /*ignore*/
}
}
}
}

private fun Frame.deflateData(): String {
val outputStream = ByteArrayOutputStream()
InflaterOutputStream(outputStream, inflater).use {
it.write(data)
}

return outputStream.use {
String(outputStream.toByteArray(), 0, outputStream.size(), Charsets.UTF_8)
}
}

private suspend fun read(frame: Frame) {
val json = when {
compression -> frame.deflateData()
else -> String(frame.data, Charsets.UTF_8)
}

try {
defaultGatewayLogger.trace { "Gateway <<< $json" }
val event = jsonParser.decodeFromString(Event.DeserializationStrategy, json) ?: return
while (!socket.incoming.isClosedForReceive) {
val event = socket.receiveDeserialized<Event>()
data.eventFlow.emit(event)
} catch (exception: Exception) {
defaultGatewayLogger.error(exception)
}

}

private suspend fun handleClose() {
Expand All @@ -221,6 +177,7 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
state.update { State.Stopped }
throw IllegalStateException("Gateway closed: ${reason.code} ${reason.message}")
}

discordReason.resetSession -> {
setStopped()
}
Expand All @@ -232,14 +189,6 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
state.update { State.Running(true) }
}

private fun <T> ReceiveChannel<T>.asFlow() = flow {
try {
for (value in this@asFlow) emit(value)
} catch (ignore: CancellationException) {
//reading was stopped from somewhere else, ignore
}
}

override suspend fun stop() {
check(state.value !is State.Detached) { "The resources of this gateway are detached, create another one" }
data.eventFlow.emit(Close.UserClose)
Expand Down Expand Up @@ -280,14 +229,7 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {

private suspend fun sendUnsafe(command: Command) {
data.sendRateLimiter.consume()
val json = Json.encodeToString(Command.SerializationStrategy, command)
if (command is Identify) {
defaultGatewayLogger.trace {
val copy = command.copy(token = "token")
"Gateway >>> ${Json.encodeToString(Command.SerializationStrategy, copy)}"
}
} else defaultGatewayLogger.trace { "Gateway >>> $json" }
socket.send(Frame.Text(json))
socket.sendSerialized(command)
}

@OptIn(ExperimentalCoroutinesApi::class)
Expand Down
15 changes: 10 additions & 5 deletions gateway/src/main/kotlin/DefaultGatewayBuilder.kt
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
package dev.kord.gateway

import dev.kord.common.KordConfiguration
import dev.kord.common.annotation.KordUnsafe
import dev.kord.common.ratelimit.IntervalRateLimiter
import dev.kord.common.ratelimit.RateLimiter
import dev.kord.gateway.ratelimit.IdentifyRateLimiter
import dev.kord.gateway.retry.LinearRetry
import dev.kord.gateway.retry.Retry
import io.ktor.client.*
import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.plugins.websocket.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.serialization.kotlinx.json.*
import io.ktor.serialization.kotlinx.*
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.serialization.json.Json
import kotlin.time.Duration.Companion.seconds

public class DefaultGatewayBuilder {
Expand All @@ -28,11 +29,15 @@ public class DefaultGatewayBuilder {
public var dispatcher: CoroutineDispatcher = Dispatchers.Default
public var eventFlow: MutableSharedFlow<Event> = MutableSharedFlow(extraBufferCapacity = Int.MAX_VALUE)

@OptIn(KordUnsafe::class)
public fun build(): DefaultGateway {
val client = client ?: HttpClient(CIO) {
install(WebSockets)
install(ContentNegotiation) {
json()
install(WebSockets) {
contentConverter = KotlinxWebsocketSerializationConverter(Json)

extensions {
install(WebSocketCompression)
}
}
}
val retry = reconnectRetry ?: LinearRetry(2.seconds, 20.seconds, 10)
Expand Down
Loading

0 comments on commit 7a0adff

Please sign in to comment.