Skip to content

Commit

Permalink
Migrate ktor websocket extension
Browse files Browse the repository at this point in the history
  • Loading branch information
ryoii committed Oct 7, 2023
1 parent 1d8c5c7 commit e1b636a
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import kotlinx.coroutines.launch
import net.mamoe.mirai.api.http.adapter.reverse.Destination
import net.mamoe.mirai.api.http.adapter.reverse.ReverseWebsocketAdapterSetting
import net.mamoe.mirai.api.http.adapter.reverse.handleReverseWs
import net.mamoe.mirai.api.http.adapter.ws.extension.FrameLogExtension
import net.mamoe.mirai.api.http.context.MahContextHolder
import net.mamoe.mirai.utils.MiraiLogger
import net.mamoe.mirai.utils.warning
import kotlin.coroutines.CoroutineContext
Expand All @@ -35,7 +37,13 @@ class WsClient(private var log: MiraiLogger) : CoroutineScope {
var bindingSessionKey: String? = null

private val client = HttpClient {
install(WebSockets)
install(WebSockets) {
extensions {
if (MahContextHolder.debug) {
install(FrameLogExtension)
}
}
}
}

private var webSocketSession: DefaultClientWebSocketSession? = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,54 +2,41 @@ package net.mamoe.mirai.api.http.adapter.ws.extension

import io.ktor.util.*
import io.ktor.websocket.*
import net.mamoe.mirai.api.http.adapter.internal.serializer.jsonParseOrNull
import net.mamoe.mirai.api.http.adapter.ws.dto.WsIncoming
import net.mamoe.mirai.utils.MiraiLogger

class FrameLogExtension(configuration: Configuration) :
WebSocketExtension<FrameLogExtension.Configuration> {
class FrameLogExtension: WebSocketExtension<Unit> {

private val logger = configuration.logger.value
private val enable = configuration.enableAccessLog
private val logger = MiraiLogger.Factory.create(FrameLogExtension::class, "MAH Access")

override val factory = FrameLogExtension
override val protocols = emptyList<WebSocketExtensionHeader>()

override fun clientNegotiation(negotiatedProtocols: List<WebSocketExtensionHeader>): Boolean {

return true
}

override fun serverNegotiation(requestedProtocols: List<WebSocketExtensionHeader>): List<WebSocketExtensionHeader> {
return emptyList()
return listOf(WebSocketExtensionHeader("frame-log", emptyList()))
}

override fun processIncomingFrame(frame: Frame): Frame {
if (enable) {
val commandWrapper = String(frame.data).jsonParseOrNull<WsIncoming>() ?: return frame
logger.debug("[incoming] $commandWrapper")
}
logger.debug("[incoming] ${(frame as Frame.Text).readText()})")
return frame
}

override fun processOutgoingFrame(frame: Frame): Frame {
return frame
}

class Configuration {
var logger = lazy { MiraiLogger.Factory.create(FrameLogExtension::class, "MAH Access") }
var enableAccessLog = false
}

companion object : WebSocketExtensionFactory<Configuration, FrameLogExtension> {
companion object : WebSocketExtensionFactory<Unit, FrameLogExtension> {
override val key = AttributeKey<FrameLogExtension>("FRAME LOG")

override val rsv1: Boolean = false
override val rsv2: Boolean = false
override val rsv3: Boolean = false

override fun install(config: Configuration.() -> Unit): FrameLogExtension {
return FrameLogExtension(Configuration().apply(config))
override fun install(config: Unit.() -> Unit): FrameLogExtension {
return FrameLogExtension()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ import net.mamoe.mirai.api.http.context.MahContextHolder
*/
fun Application.websocketRouteModule(wsAdapter: WebsocketAdapter) {
install(WebSockets) {
extensions {
install(FrameLogExtension) { enableAccessLog = MahContextHolder.debug }
extensions {
if (MahContextHolder.debug) {
install(FrameLogExtension)
}
}
}
wsRouter(wsAdapter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

package net.mamoe.mirai.api.http.adapter.ws.router

import io.ktor.server.application.*
import io.ktor.server.routing.*
import io.ktor.server.websocket.*
import io.ktor.util.*
Expand All @@ -19,7 +18,6 @@ import net.mamoe.mirai.api.http.adapter.common.StateCode
import net.mamoe.mirai.api.http.adapter.internal.serializer.toJson
import net.mamoe.mirai.api.http.adapter.internal.serializer.toJsonElement
import net.mamoe.mirai.api.http.adapter.ws.dto.WsOutgoing
import net.mamoe.mirai.api.http.adapter.ws.extension.FrameLogExtension
import net.mamoe.mirai.api.http.context.MahContextHolder


Expand All @@ -33,9 +31,6 @@ internal inline fun Route.miraiWebsocket(
val sessionKey = call.request.headers["sessionKey"] ?: call.parameters["sessionKey"]
val qq = (call.request.headers["qq"] ?: call.parameters["qq"])?.toLongOrNull()

// 注入无协商的扩展
installExtension(FrameLogExtension)

// 校验
if (MahContextHolder.enableVerify && MahContextHolder.sessionManager.verifyKey != verifyKey) {
closeWithCode(StateCode.AuthKeyFail)
Expand Down Expand Up @@ -93,10 +88,3 @@ internal suspend fun DefaultWebSocketServerSession.closeWithCode(code: StateCode
))
close(CloseReason(CloseReason.Codes.NORMAL, code.msg))
}


internal fun <T: WebSocketExtension<*>> WebSocketServerSession.installExtension(factory: WebSocketExtensionFactory<*, T>) {
application.plugin(WebSockets).extensionsConfig.build().find { it.factory.key == factory.key }?.let {
(extensions as MutableList<WebSocketExtension<*>>).add(it)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
*
* * Copyright 2023 Mamoe Technologies and contributors.
* *
* * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
* *
* * https://github.com/mamoe/mirai/blob/master/LICENSE
*
*/

package net.mamoe.mirai.api.http.adapter.ws.extension

import io.ktor.client.plugins.websocket.*
import io.ktor.server.testing.*
import io.ktor.server.websocket.*
import io.ktor.server.websocket.WebSockets
import io.ktor.websocket.*
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotNull

class FrameLogExtensionTest {

@Test
fun testFrameLogExtension() = testApplication {
install(WebSockets) {
extensions {
install(FrameLogExtension)
}
}

routing {
webSocket("/echo") {
assertNotNull(extensionOrNull(FrameLogExtension))
for (frame in incoming) {
send(frame)
}
}
}

val wsClient = createClient { install(io.ktor.client.plugins.websocket.WebSockets) }
wsClient.ws("/echo") {
outgoing.send(Frame.Text("Hello"))

val receive = incoming.receive()
assertEquals(FrameType.TEXT, receive.frameType)
assertEquals("Hello", (receive as Frame.Text).readText())
}
}
}

0 comments on commit e1b636a

Please sign in to comment.