Skip to content

Commit

Permalink
wip: Move helper methods to ThreadBlockingHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed May 26, 2024
1 parent 9d4ffd6 commit 4af0209
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 81 deletions.
14 changes: 0 additions & 14 deletions cask/src/cask/internal/NewThreadPerTaskExecutor.scala

This file was deleted.

55 changes: 55 additions & 0 deletions cask/src/cask/internal/ThreadBlockingHandler.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package cask.internal

import io.undertow.server.{HttpHandler, HttpServerExchange}

import java.lang.invoke.{MethodHandles, MethodType}
import java.util.concurrent.{Executor, ExecutorService, ThreadFactory}
import scala.util.control.NonFatal

final class ThreadBlockingHandler(executor: Executor,
handler: HttpHandler)
extends HttpHandler {
require(executor != null, "executor should not be null")

override def handleRequest(exchange: HttpServerExchange): Unit = {
exchange.startBlocking()
exchange.dispatch(executor, handler)
}
}

object ThreadBlockingHandler {
private val lookup = MethodHandles.lookup

def createNewThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = {
try {
val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors")
val newThreadPerTaskExecutorMethod = lookup.findStatic(
executorsClazz,
"newThreadPerTaskExecutor",
MethodType.methodType(classOf[ExecutorService], classOf[ThreadFactory]))
newThreadPerTaskExecutorMethod.invoke(threadFactory)
.asInstanceOf[ExecutorService]
} catch {
case NonFatal(e) =>
throw new UnsupportedOperationException("Failed to create virtual thread executor", e)
}
}

/**
* Create a virtual thread factory, returns null when failed.
*/
def createVirtualThreadFactory(prefix: String): ThreadFactory =
try {
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
var builder = ofVirtualMethod.invoke()
val nameMethod = lookup.findVirtual(ofVirtualClass, "name",
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long]))
val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case _: Throwable => null
}
}
16 changes: 0 additions & 16 deletions cask/src/cask/internal/VirtualThreadBlockingHandler.scala

This file was deleted.

31 changes: 0 additions & 31 deletions cask/src/cask/internal/VirtualThreadSupport.scala

This file was deleted.

28 changes: 10 additions & 18 deletions cask/src/cask/main/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package cask.main

import cask.endpoints.{WebsocketResult, WsHandler}
import cask.model._
import cask.internal.{DispatchTrie, Util, VirtualThreadBlockingHandler, VirtualThreadSupport}
import cask.model.Response.Raw
import cask.internal.{DispatchTrie, Util, ThreadBlockingHandler}
import Response.Raw
import cask.router.{Decorator, EndpointMetadata, EntryPoint, Result}
import cask.util.Logger
import io.undertow.Undertow
import io.undertow.server.{HttpHandler, HttpServerExchange}
import io.undertow.server.handlers.BlockingHandler
import io.undertow.util.HttpString

import java.util.concurrent.Executor
import scala.annotation.nowarn
import scala.concurrent.ExecutionContext
import scala.util.control.NonFatal
Expand All @@ -34,19 +35,6 @@ class MainRoutes extends Main with Routes {
* `Virtual Threads` if the runtime supports it and the property `cask.virtualThread.enabled` is set to `true`.
*/
abstract class Main {
/**
* Ture and only true when the virtual thread supported by the runtime, and
* property `cask.virtualThread.enabled` is set to `true`.
* */
private lazy val useVirtualThreadHandlerExecutor: Boolean = {
val enableVT = System.getProperty("cask.virtualThread.enabled", "false")
enableVT match {
case "true" if VirtualThreadSupport.isVirtualThreadSupported => true
case _ => false
}

}

def mainDecorators: Seq[Decorator[_, _, _]] = Nil

def allRoutes: Seq[Routes]
Expand All @@ -70,11 +58,15 @@ abstract class Main {

def dispatchTrie: DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]] = Main.prepareDispatchTrie(allRoutes)

protected def handlerExecutor(): Executor = {
null
}

def defaultHandler: HttpHandler = {
val mainHandler = new Main.DefaultHandler(dispatchTrie, mainDecorators, debugMode, handleNotFound, handleMethodNotAllowed, handleEndpointError)
if (useVirtualThreadHandlerExecutor) {
//switch to virtual thread if possible
new VirtualThreadBlockingHandler(mainHandler)
val executor = handlerExecutor()
if (handlerExecutor ne null) {
new ThreadBlockingHandler(executor, mainHandler)
} else new BlockingHandler(mainHandler)
}

Expand Down
15 changes: 13 additions & 2 deletions example/compress/app/src/Compress.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
package app
object Compress extends cask.MainRoutes{

import cask.internal.ThreadBlockingHandler

import java.util.concurrent.Executor

object Compress extends cask.MainRoutes {

protected override val handlerExecutor: Executor = {
if (System.getProperty("cask.virtualThread.enabled", "false").toBoolean) {
ThreadBlockingHandler.createNewThreadPerTaskExecutor(
ThreadBlockingHandler.createVirtualThreadFactory("cask-handler-executor"))
} else null
}

@cask.decorators.compress
@cask.get("/")
def hello(): String = {
Thread.sleep(1000) // Simulate a slow endpoint
"Hello World! Hello World! Hello World!"
}

Expand Down

0 comments on commit 4af0209

Please sign in to comment.