From f4c523c315a18a69213ef13d7ebea40c08e838e0 Mon Sep 17 00:00:00 2001 From: "He-Pin(kerr)" Date: Mon, 20 May 2024 02:45:23 +0800 Subject: [PATCH] feat: Add virtual threads support. --- .../internal/NewThreadPerTaskExecutor.scala | 14 ++++ .../VirtualThreadBlockingHandler.scala | 16 ++++ .../cask/internal/VirtualThreadSupport.scala | 31 ++++++++ cask/src/cask/main/Main.scala | 74 +++++++++++++------ example/compress/app/src/Compress.scala | 3 +- 5 files changed, 116 insertions(+), 22 deletions(-) create mode 100644 cask/src/cask/internal/NewThreadPerTaskExecutor.scala create mode 100644 cask/src/cask/internal/VirtualThreadBlockingHandler.scala create mode 100644 cask/src/cask/internal/VirtualThreadSupport.scala diff --git a/cask/src/cask/internal/NewThreadPerTaskExecutor.scala b/cask/src/cask/internal/NewThreadPerTaskExecutor.scala new file mode 100644 index 0000000000..a0988f47ed --- /dev/null +++ b/cask/src/cask/internal/NewThreadPerTaskExecutor.scala @@ -0,0 +1,14 @@ +package cask.internal + +import java.util.concurrent.{Executor, ThreadFactory} + +private[cask] final class NewThreadPerTaskExecutor(val threadFactory: ThreadFactory) + extends Executor { + override def execute(command: Runnable): Unit = { + val thread = threadFactory.newThread(command) + thread.start() + if (thread.getState eq Thread.State.TERMINATED) { + throw new IllegalStateException("Thread has already been terminated") + } + } +} diff --git a/cask/src/cask/internal/VirtualThreadBlockingHandler.scala b/cask/src/cask/internal/VirtualThreadBlockingHandler.scala new file mode 100644 index 0000000000..2c3ac70b22 --- /dev/null +++ b/cask/src/cask/internal/VirtualThreadBlockingHandler.scala @@ -0,0 +1,16 @@ +package cask.internal + +import io.undertow.server.{HttpHandler, HttpServerExchange} + +private[cask] final class VirtualThreadBlockingHandler(val handler: HttpHandler) + extends HttpHandler { + override def handleRequest(exchange: HttpServerExchange): Unit = { + exchange.startBlocking() + exchange.dispatch(VirtualThreadBlockingHandler.EXECUTOR, handler) + } +} + +private[cask] object VirtualThreadBlockingHandler { + private lazy val EXECUTOR = new NewThreadPerTaskExecutor( + VirtualThreadSupport.create("cask")) +} diff --git a/cask/src/cask/internal/VirtualThreadSupport.scala b/cask/src/cask/internal/VirtualThreadSupport.scala new file mode 100644 index 0000000000..bb9ceddc08 --- /dev/null +++ b/cask/src/cask/internal/VirtualThreadSupport.scala @@ -0,0 +1,31 @@ +package cask.internal + +import java.lang.invoke.{MethodHandles, MethodType} +import java.util.concurrent.ThreadFactory + +private[cask] object VirtualThreadSupport { + + /** + * Returns if the current Runtime supports virtual threads. + */ + lazy val isVirtualThreadSupported: Boolean = create("testIfSupported") ne null + + /** + * Create a virtual thread factory, returns null when failed. + */ + def create(prefix: String): ThreadFactory = + try { + val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder") + val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual") + val lookup = MethodHandles.lookup + val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass)) + var builder = ofVirtualMethod.invoke() + val nameMethod = lookup.findVirtual(ofVirtualClass, "name", + MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long])) + val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory])) + builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L) + factoryMethod.invoke(builder).asInstanceOf[ThreadFactory] + } catch { + case _: Throwable => null + } +} diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala index c11a7d8094..a5757804d5 100644 --- a/cask/src/cask/main/Main.scala +++ b/cask/src/cask/main/Main.scala @@ -2,7 +2,7 @@ package cask.main import cask.endpoints.{WebsocketResult, WsHandler} import cask.model._ -import cask.internal.{DispatchTrie, Util} +import cask.internal.{DispatchTrie, Util, VirtualThreadBlockingHandler, VirtualThreadSupport} import cask.model.Response.Raw import cask.router.{Decorator, EndpointMetadata, EntryPoint, Result} import cask.util.Logger @@ -16,41 +16,69 @@ import scala.concurrent.ExecutionContext import scala.util.control.NonFatal /** - * A combination of [[cask.Main]] and [[cask.Routes]], ideal for small - * one-file web applications. - */ -class MainRoutes extends Main with Routes{ + * A combination of [[cask.Main]] and [[cask.Routes]], ideal for small + * one-file web applications. + */ +class MainRoutes extends Main with Routes { def allRoutes: Seq[Routes] = Seq(this) } /** - * Defines the main entrypoint and configuration of the Cask web application. - * - * You can pass in an arbitrary number of [[cask.Routes]] objects for it to - * serve, and override various properties on [[Main]] in order to configure - * application-wide properties. - */ -abstract class Main{ + * Defines the main entrypoint and configuration of the Cask web application. + * + * You can pass in an arbitrary number of [[cask.Routes]] objects for it to + * serve, and override various properties on [[Main]] in order to configure + * application-wide properties. + * + * By default,the [[Routes]] running inside a worker threads, and can run with + * `Virtual Threads` if the runtime supports it and the property `cask.virtualThread.enabled` is set to `true`. + */ +abstract class Main { + /** + * Ture and only true when the virtual thread supported by the runtime, and + * property `cask.virtualThread.enabled` is set to `true`. + * */ + private lazy val useVirtualThreadHandlerExecutor: Boolean = { + val enableVT = System.getProperty("cask.virtualThread.enabled", "false") + enableVT match { + case "true" if VirtualThreadSupport.isVirtualThreadSupported => true + case _ => false + } + + } + def mainDecorators: Seq[Decorator[_, _, _]] = Nil + def allRoutes: Seq[Routes] + def port: Int = 8080 + def host: String = "localhost" + def verbose = false + def debugMode: Boolean = true def createExecutionContext: ExecutionContext = castor.Context.Simple.executionContext + def createActorContext = new castor.Context.Simple(executionContext, log.exception) - val executionContext = createExecutionContext + val executionContext: ExecutionContext = createExecutionContext implicit val actorContext: castor.Context = createActorContext implicit def log: cask.util.Logger = new cask.util.Logger.Console() def dispatchTrie: DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]] = Main.prepareDispatchTrie(allRoutes) - def defaultHandler = new BlockingHandler( - new Main.DefaultHandler(dispatchTrie, mainDecorators, debugMode, handleNotFound, handleMethodNotAllowed, handleEndpointError) - ) + def defaultHandler: HttpHandler = { + val mainHandler = new Main.DefaultHandler(dispatchTrie, mainDecorators, debugMode, handleNotFound, handleMethodNotAllowed, handleEndpointError) + if (useVirtualThreadHandlerExecutor) { + //switch to virtual thread if possible + new VirtualThreadBlockingHandler(mainHandler) + } else new BlockingHandler(mainHandler) + } + + def handler(): HttpHandler = defaultHandler def handleNotFound(): Raw = Main.defaultHandleNotFound() @@ -73,7 +101,7 @@ abstract class Main{ } -object Main{ +object Main { class DefaultHandler(dispatchTrie: DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]], mainDecorators: Seq[Decorator[_, _, _]], @nowarn debugMode: Boolean, @@ -87,7 +115,7 @@ object Main{ Tuple2( "websocket", (r: Any) => - r.asInstanceOf[WebsocketResult] match{ + r.asInstanceOf[WebsocketResult] match { case l: WsHandler => io.undertow.Handlers.websocket(l).handleRequest(exchange) case l: WebsocketResult.Listener => @@ -161,7 +189,7 @@ object Main{ val methodMap = methods.toMap[String, (Routes, EndpointMetadata[_])] val subpath = metadata.endpoint.subpath || - metadata.entryPoint.argSignatures.exists(_.exists(_.reads.remainingPathSegments)) + metadata.entryPoint.argSignatures.exists(_.exists(_.reads.remainingPathSegments)) (segments, methodMap, subpath) } @@ -176,10 +204,10 @@ object Main{ } def writeResponse(exchange: HttpServerExchange, response: Response.Raw): Unit = { - response.data.headers.foreach{case (k, v) => + response.data.headers.foreach { case (k, v) => exchange.getResponseHeaders.put(new HttpString(k), v) } - response.headers.foreach{case (k, v) => + response.headers.foreach { case (k, v) => exchange.getResponseHeaders.put(new HttpString(k), v) } response.cookies.foreach(c => exchange.setResponseCookie(Cookie.toUndertow(c))) @@ -188,11 +216,15 @@ object Main{ val output = exchange.getOutputStream response.data.write(new java.io.OutputStream { def write(b: Int): Unit = output.write(b) + override def write(b: Array[Byte]): Unit = output.write(b) + override def write(b: Array[Byte], off: Int, len: Int): Unit = output.write(b, off, len) + override def close(): Unit = { if (!exchange.isComplete) output.close() } + override def flush(): Unit = { if (!exchange.isComplete) output.flush() } diff --git a/example/compress/app/src/Compress.scala b/example/compress/app/src/Compress.scala index 9c57494ad2..e2403a4a79 100644 --- a/example/compress/app/src/Compress.scala +++ b/example/compress/app/src/Compress.scala @@ -3,7 +3,8 @@ object Compress extends cask.MainRoutes{ @cask.decorators.compress @cask.get("/") - def hello() = { + def hello(): String = { + Thread.sleep(1000) // Simulate a slow endpoint "Hello World! Hello World! Hello World!" }