diff --git a/cask/src/cask/internal/Util.scala b/cask/src/cask/internal/Util.scala index e550e3b1b5..1511591610 100644 --- a/cask/src/cask/internal/Util.scala +++ b/cask/src/cask/internal/Util.scala @@ -5,14 +5,27 @@ import scala.collection.generic.CanBuildFrom import scala.collection.mutable import java.io.OutputStream import java.lang.invoke.{MethodHandles, MethodType} -import java.util.concurrent.{ExecutorService, ThreadFactory} +import java.util.concurrent.{Executor, ExecutorService, ThreadFactory} import scala.annotation.switch import scala.concurrent.{ExecutionContext, Future, Promise} +import scala.util.Try import scala.util.control.NonFatal object Util { private val lookup = MethodHandles.lookup + import cask.util.Logger.Console.globalLogger + + /** + * Create a virtual thread executor with the given executor as the scheduler. + * */ + def createVirtualThreadExecutor(executor: Executor): Option[Executor] = { + (for { + factory <- Try(createVirtualThreadFactory("cask-handler-executor", executor)) + executor <- Try(createNewThreadPerTaskExecutor(factory)) + } yield executor).toOption + } + def createNewThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = { try { val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors") @@ -24,26 +37,46 @@ object Util { .asInstanceOf[ExecutorService] } catch { case NonFatal(e) => - throw new UnsupportedOperationException("Failed to create virtual thread executor", e) + globalLogger.exception(e) + throw new UnsupportedOperationException("Failed to create newThreadPerTaskExecutor.", e) } } /** - * Create a virtual thread factory, returns null when failed. + * Create a virtual thread factory with a executor, the executor will be used as the scheduler of + * virtual thread. + * + * The executor should run task on platform threads. + * + * returns null if not supported. */ - def createVirtualThreadFactory(prefix: String): ThreadFactory = + def createVirtualThreadFactory(prefix: String, + executor: Executor): ThreadFactory = try { val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder") val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual") val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass)) var builder = ofVirtualMethod.invoke() + if (executor != null) { + val clazz = builder.getClass + val privateLookup = MethodHandles.privateLookupIn( + clazz, + lookup + ) + val schedulerFieldSetter = privateLookup + .findSetter(clazz, "scheduler", classOf[Executor]) + schedulerFieldSetter.invoke(builder, executor) + } val nameMethod = lookup.findVirtual(ofVirtualClass, "name", MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long])) val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory])) builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L) factoryMethod.invoke(builder).asInstanceOf[ThreadFactory] } catch { - case _: Throwable => null + case NonFatal(e) => + globalLogger.exception(e) + //--add-opens java.base/java.lang=ALL-UNNAMED + throw new UnsupportedOperationException("Failed to create virtual thread factory", e) } def firstFutureOf[T](futures: Seq[Future[T]])(implicit ec: ExecutionContext) = { diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala index 6423211ce3..a8a0b81ddc 100644 --- a/cask/src/cask/main/Main.scala +++ b/cask/src/cask/main/Main.scala @@ -2,7 +2,7 @@ package cask.main import cask.endpoints.{WebsocketResult, WsHandler} import cask.model._ -import cask.internal.{DispatchTrie, Util, ThreadBlockingHandler} +import cask.internal.{DispatchTrie, ThreadBlockingHandler, Util} import Response.Raw import cask.router.{Decorator, EndpointMetadata, EntryPoint, Result} import cask.util.Logger @@ -62,9 +62,16 @@ abstract class Main { null } + private def screenExecutor(executor: Executor): Executor = { + if (executor eq null) executor + else if (System.getProperty("cask.virtualThread.enabled", "true").toBoolean) { + Util.createVirtualThreadExecutor(executor).getOrElse(executor) + } else executor + } + def defaultHandler: HttpHandler = { val mainHandler = new Main.DefaultHandler(dispatchTrie, mainDecorators, debugMode, handleNotFound, handleMethodNotAllowed, handleEndpointError) - val executor = handlerExecutor() + val executor = screenExecutor(handlerExecutor()) if (handlerExecutor ne null) { new ThreadBlockingHandler(executor, mainHandler) } else new BlockingHandler(mainHandler) diff --git a/example/compress/app/src/Compress.scala b/example/compress/app/src/Compress.scala index d701aaa4f8..eb95e17101 100644 --- a/example/compress/app/src/Compress.scala +++ b/example/compress/app/src/Compress.scala @@ -2,17 +2,10 @@ package app import cask.internal.{ThreadBlockingHandler, Util} -import java.util.concurrent.Executor +import java.util.concurrent.{Executor, Executors} object Compress extends cask.MainRoutes { - protected override val handlerExecutor: Executor = { - if (System.getProperty("cask.virtualThread.enabled", "false").toBoolean) { - Util.createNewThreadPerTaskExecutor( - Util.createVirtualThreadFactory("cask-handler-executor")) - } else null - } - @cask.decorators.compress @cask.get("/") def hello(): String = {