From 4af0209323c88245b2b1ec1bf242228977ec7e39 Mon Sep 17 00:00:00 2001 From: "He-Pin(kerr)" Date: Mon, 27 May 2024 03:03:17 +0800 Subject: [PATCH] wip: Move helper methods to ThreadBlockingHandler --- .../internal/NewThreadPerTaskExecutor.scala | 14 ----- .../cask/internal/ThreadBlockingHandler.scala | 55 +++++++++++++++++++ .../VirtualThreadBlockingHandler.scala | 16 ------ .../cask/internal/VirtualThreadSupport.scala | 31 ----------- cask/src/cask/main/Main.scala | 28 ++++------ example/compress/app/src/Compress.scala | 15 ++++- 6 files changed, 78 insertions(+), 81 deletions(-) delete mode 100644 cask/src/cask/internal/NewThreadPerTaskExecutor.scala create mode 100644 cask/src/cask/internal/ThreadBlockingHandler.scala delete mode 100644 cask/src/cask/internal/VirtualThreadBlockingHandler.scala delete mode 100644 cask/src/cask/internal/VirtualThreadSupport.scala diff --git a/cask/src/cask/internal/NewThreadPerTaskExecutor.scala b/cask/src/cask/internal/NewThreadPerTaskExecutor.scala deleted file mode 100644 index a0988f47ed..0000000000 --- a/cask/src/cask/internal/NewThreadPerTaskExecutor.scala +++ /dev/null @@ -1,14 +0,0 @@ -package cask.internal - -import java.util.concurrent.{Executor, ThreadFactory} - -private[cask] final class NewThreadPerTaskExecutor(val threadFactory: ThreadFactory) - extends Executor { - override def execute(command: Runnable): Unit = { - val thread = threadFactory.newThread(command) - thread.start() - if (thread.getState eq Thread.State.TERMINATED) { - throw new IllegalStateException("Thread has already been terminated") - } - } -} diff --git a/cask/src/cask/internal/ThreadBlockingHandler.scala b/cask/src/cask/internal/ThreadBlockingHandler.scala new file mode 100644 index 0000000000..e66134da39 --- /dev/null +++ b/cask/src/cask/internal/ThreadBlockingHandler.scala @@ -0,0 +1,55 @@ +package cask.internal + +import io.undertow.server.{HttpHandler, HttpServerExchange} + +import java.lang.invoke.{MethodHandles, MethodType} +import java.util.concurrent.{Executor, ExecutorService, ThreadFactory} +import scala.util.control.NonFatal + +final class ThreadBlockingHandler(executor: Executor, + handler: HttpHandler) + extends HttpHandler { + require(executor != null, "executor should not be null") + + override def handleRequest(exchange: HttpServerExchange): Unit = { + exchange.startBlocking() + exchange.dispatch(executor, handler) + } +} + +object ThreadBlockingHandler { + private val lookup = MethodHandles.lookup + + def createNewThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = { + try { + val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors") + val newThreadPerTaskExecutorMethod = lookup.findStatic( + executorsClazz, + "newThreadPerTaskExecutor", + MethodType.methodType(classOf[ExecutorService], classOf[ThreadFactory])) + newThreadPerTaskExecutorMethod.invoke(threadFactory) + .asInstanceOf[ExecutorService] + } catch { + case NonFatal(e) => + throw new UnsupportedOperationException("Failed to create virtual thread executor", e) + } + } + + /** + * Create a virtual thread factory, returns null when failed. + */ + def createVirtualThreadFactory(prefix: String): ThreadFactory = + try { + val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder") + val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual") + val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass)) + var builder = ofVirtualMethod.invoke() + val nameMethod = lookup.findVirtual(ofVirtualClass, "name", + MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long])) + val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory])) + builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L) + factoryMethod.invoke(builder).asInstanceOf[ThreadFactory] + } catch { + case _: Throwable => null + } +} diff --git a/cask/src/cask/internal/VirtualThreadBlockingHandler.scala b/cask/src/cask/internal/VirtualThreadBlockingHandler.scala deleted file mode 100644 index 2c3ac70b22..0000000000 --- a/cask/src/cask/internal/VirtualThreadBlockingHandler.scala +++ /dev/null @@ -1,16 +0,0 @@ -package cask.internal - -import io.undertow.server.{HttpHandler, HttpServerExchange} - -private[cask] final class VirtualThreadBlockingHandler(val handler: HttpHandler) - extends HttpHandler { - override def handleRequest(exchange: HttpServerExchange): Unit = { - exchange.startBlocking() - exchange.dispatch(VirtualThreadBlockingHandler.EXECUTOR, handler) - } -} - -private[cask] object VirtualThreadBlockingHandler { - private lazy val EXECUTOR = new NewThreadPerTaskExecutor( - VirtualThreadSupport.create("cask")) -} diff --git a/cask/src/cask/internal/VirtualThreadSupport.scala b/cask/src/cask/internal/VirtualThreadSupport.scala deleted file mode 100644 index bb9ceddc08..0000000000 --- a/cask/src/cask/internal/VirtualThreadSupport.scala +++ /dev/null @@ -1,31 +0,0 @@ -package cask.internal - -import java.lang.invoke.{MethodHandles, MethodType} -import java.util.concurrent.ThreadFactory - -private[cask] object VirtualThreadSupport { - - /** - * Returns if the current Runtime supports virtual threads. - */ - lazy val isVirtualThreadSupported: Boolean = create("testIfSupported") ne null - - /** - * Create a virtual thread factory, returns null when failed. - */ - def create(prefix: String): ThreadFactory = - try { - val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder") - val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual") - val lookup = MethodHandles.lookup - val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass)) - var builder = ofVirtualMethod.invoke() - val nameMethod = lookup.findVirtual(ofVirtualClass, "name", - MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long])) - val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory])) - builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L) - factoryMethod.invoke(builder).asInstanceOf[ThreadFactory] - } catch { - case _: Throwable => null - } -} diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala index a5757804d5..6423211ce3 100644 --- a/cask/src/cask/main/Main.scala +++ b/cask/src/cask/main/Main.scala @@ -2,8 +2,8 @@ package cask.main import cask.endpoints.{WebsocketResult, WsHandler} import cask.model._ -import cask.internal.{DispatchTrie, Util, VirtualThreadBlockingHandler, VirtualThreadSupport} -import cask.model.Response.Raw +import cask.internal.{DispatchTrie, Util, ThreadBlockingHandler} +import Response.Raw import cask.router.{Decorator, EndpointMetadata, EntryPoint, Result} import cask.util.Logger import io.undertow.Undertow @@ -11,6 +11,7 @@ import io.undertow.server.{HttpHandler, HttpServerExchange} import io.undertow.server.handlers.BlockingHandler import io.undertow.util.HttpString +import java.util.concurrent.Executor import scala.annotation.nowarn import scala.concurrent.ExecutionContext import scala.util.control.NonFatal @@ -34,19 +35,6 @@ class MainRoutes extends Main with Routes { * `Virtual Threads` if the runtime supports it and the property `cask.virtualThread.enabled` is set to `true`. */ abstract class Main { - /** - * Ture and only true when the virtual thread supported by the runtime, and - * property `cask.virtualThread.enabled` is set to `true`. - * */ - private lazy val useVirtualThreadHandlerExecutor: Boolean = { - val enableVT = System.getProperty("cask.virtualThread.enabled", "false") - enableVT match { - case "true" if VirtualThreadSupport.isVirtualThreadSupported => true - case _ => false - } - - } - def mainDecorators: Seq[Decorator[_, _, _]] = Nil def allRoutes: Seq[Routes] @@ -70,11 +58,15 @@ abstract class Main { def dispatchTrie: DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]] = Main.prepareDispatchTrie(allRoutes) + protected def handlerExecutor(): Executor = { + null + } + def defaultHandler: HttpHandler = { val mainHandler = new Main.DefaultHandler(dispatchTrie, mainDecorators, debugMode, handleNotFound, handleMethodNotAllowed, handleEndpointError) - if (useVirtualThreadHandlerExecutor) { - //switch to virtual thread if possible - new VirtualThreadBlockingHandler(mainHandler) + val executor = handlerExecutor() + if (handlerExecutor ne null) { + new ThreadBlockingHandler(executor, mainHandler) } else new BlockingHandler(mainHandler) } diff --git a/example/compress/app/src/Compress.scala b/example/compress/app/src/Compress.scala index e2403a4a79..d30bfa7217 100644 --- a/example/compress/app/src/Compress.scala +++ b/example/compress/app/src/Compress.scala @@ -1,10 +1,21 @@ package app -object Compress extends cask.MainRoutes{ + +import cask.internal.ThreadBlockingHandler + +import java.util.concurrent.Executor + +object Compress extends cask.MainRoutes { + + protected override val handlerExecutor: Executor = { + if (System.getProperty("cask.virtualThread.enabled", "false").toBoolean) { + ThreadBlockingHandler.createNewThreadPerTaskExecutor( + ThreadBlockingHandler.createVirtualThreadFactory("cask-handler-executor")) + } else null + } @cask.decorators.compress @cask.get("/") def hello(): String = { - Thread.sleep(1000) // Simulate a slow endpoint "Hello World! Hello World! Hello World!" }