Skip to content

Commit

Permalink
feat: Add virtual threads support.
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed May 19, 2024
1 parent 2525b6a commit f4c523c
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 22 deletions.
14 changes: 14 additions & 0 deletions cask/src/cask/internal/NewThreadPerTaskExecutor.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package cask.internal

import java.util.concurrent.{Executor, ThreadFactory}

private[cask] final class NewThreadPerTaskExecutor(val threadFactory: ThreadFactory)
extends Executor {
override def execute(command: Runnable): Unit = {
val thread = threadFactory.newThread(command)
thread.start()
if (thread.getState eq Thread.State.TERMINATED) {
throw new IllegalStateException("Thread has already been terminated")
}
}
}
16 changes: 16 additions & 0 deletions cask/src/cask/internal/VirtualThreadBlockingHandler.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package cask.internal

import io.undertow.server.{HttpHandler, HttpServerExchange}

private[cask] final class VirtualThreadBlockingHandler(val handler: HttpHandler)
extends HttpHandler {
override def handleRequest(exchange: HttpServerExchange): Unit = {
exchange.startBlocking()
exchange.dispatch(VirtualThreadBlockingHandler.EXECUTOR, handler)
}
}

private[cask] object VirtualThreadBlockingHandler {
private lazy val EXECUTOR = new NewThreadPerTaskExecutor(
VirtualThreadSupport.create("cask"))
}
31 changes: 31 additions & 0 deletions cask/src/cask/internal/VirtualThreadSupport.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package cask.internal

import java.lang.invoke.{MethodHandles, MethodType}
import java.util.concurrent.ThreadFactory

private[cask] object VirtualThreadSupport {

/**
* Returns if the current Runtime supports virtual threads.
*/
lazy val isVirtualThreadSupported: Boolean = create("testIfSupported") ne null

/**
* Create a virtual thread factory, returns null when failed.
*/
def create(prefix: String): ThreadFactory =
try {
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val lookup = MethodHandles.lookup
val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
var builder = ofVirtualMethod.invoke()
val nameMethod = lookup.findVirtual(ofVirtualClass, "name",
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long]))
val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case _: Throwable => null
}
}
74 changes: 53 additions & 21 deletions cask/src/cask/main/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package cask.main

import cask.endpoints.{WebsocketResult, WsHandler}
import cask.model._
import cask.internal.{DispatchTrie, Util}
import cask.internal.{DispatchTrie, Util, VirtualThreadBlockingHandler, VirtualThreadSupport}
import cask.model.Response.Raw
import cask.router.{Decorator, EndpointMetadata, EntryPoint, Result}
import cask.util.Logger
Expand All @@ -16,41 +16,69 @@ import scala.concurrent.ExecutionContext
import scala.util.control.NonFatal

/**
* A combination of [[cask.Main]] and [[cask.Routes]], ideal for small
* one-file web applications.
*/
class MainRoutes extends Main with Routes{
* A combination of [[cask.Main]] and [[cask.Routes]], ideal for small
* one-file web applications.
*/
class MainRoutes extends Main with Routes {
def allRoutes: Seq[Routes] = Seq(this)
}

/**
* Defines the main entrypoint and configuration of the Cask web application.
*
* You can pass in an arbitrary number of [[cask.Routes]] objects for it to
* serve, and override various properties on [[Main]] in order to configure
* application-wide properties.
*/
abstract class Main{
* Defines the main entrypoint and configuration of the Cask web application.
*
* You can pass in an arbitrary number of [[cask.Routes]] objects for it to
* serve, and override various properties on [[Main]] in order to configure
* application-wide properties.
*
* By default,the [[Routes]] running inside a worker threads, and can run with
* `Virtual Threads` if the runtime supports it and the property `cask.virtualThread.enabled` is set to `true`.
*/
abstract class Main {
/**
* Ture and only true when the virtual thread supported by the runtime, and
* property `cask.virtualThread.enabled` is set to `true`.
* */
private lazy val useVirtualThreadHandlerExecutor: Boolean = {
val enableVT = System.getProperty("cask.virtualThread.enabled", "false")
enableVT match {
case "true" if VirtualThreadSupport.isVirtualThreadSupported => true
case _ => false
}

}

def mainDecorators: Seq[Decorator[_, _, _]] = Nil

def allRoutes: Seq[Routes]

def port: Int = 8080

def host: String = "localhost"

def verbose = false

def debugMode: Boolean = true

def createExecutionContext: ExecutionContext = castor.Context.Simple.executionContext

def createActorContext = new castor.Context.Simple(executionContext, log.exception)

val executionContext = createExecutionContext
val executionContext: ExecutionContext = createExecutionContext
implicit val actorContext: castor.Context = createActorContext

implicit def log: cask.util.Logger = new cask.util.Logger.Console()

def dispatchTrie: DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]] = Main.prepareDispatchTrie(allRoutes)

def defaultHandler = new BlockingHandler(
new Main.DefaultHandler(dispatchTrie, mainDecorators, debugMode, handleNotFound, handleMethodNotAllowed, handleEndpointError)
)
def defaultHandler: HttpHandler = {
val mainHandler = new Main.DefaultHandler(dispatchTrie, mainDecorators, debugMode, handleNotFound, handleMethodNotAllowed, handleEndpointError)
if (useVirtualThreadHandlerExecutor) {
//switch to virtual thread if possible
new VirtualThreadBlockingHandler(mainHandler)
} else new BlockingHandler(mainHandler)
}

def handler(): HttpHandler = defaultHandler

def handleNotFound(): Raw = Main.defaultHandleNotFound()

Expand All @@ -73,7 +101,7 @@ abstract class Main{

}

object Main{
object Main {
class DefaultHandler(dispatchTrie: DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]],
mainDecorators: Seq[Decorator[_, _, _]],
@nowarn debugMode: Boolean,
Expand All @@ -87,7 +115,7 @@ object Main{
Tuple2(
"websocket",
(r: Any) =>
r.asInstanceOf[WebsocketResult] match{
r.asInstanceOf[WebsocketResult] match {
case l: WsHandler =>
io.undertow.Handlers.websocket(l).handleRequest(exchange)
case l: WebsocketResult.Listener =>
Expand Down Expand Up @@ -161,7 +189,7 @@ object Main{
val methodMap = methods.toMap[String, (Routes, EndpointMetadata[_])]
val subpath =
metadata.endpoint.subpath ||
metadata.entryPoint.argSignatures.exists(_.exists(_.reads.remainingPathSegments))
metadata.entryPoint.argSignatures.exists(_.exists(_.reads.remainingPathSegments))

(segments, methodMap, subpath)
}
Expand All @@ -176,10 +204,10 @@ object Main{
}

def writeResponse(exchange: HttpServerExchange, response: Response.Raw): Unit = {
response.data.headers.foreach{case (k, v) =>
response.data.headers.foreach { case (k, v) =>
exchange.getResponseHeaders.put(new HttpString(k), v)
}
response.headers.foreach{case (k, v) =>
response.headers.foreach { case (k, v) =>
exchange.getResponseHeaders.put(new HttpString(k), v)
}
response.cookies.foreach(c => exchange.setResponseCookie(Cookie.toUndertow(c)))
Expand All @@ -188,11 +216,15 @@ object Main{
val output = exchange.getOutputStream
response.data.write(new java.io.OutputStream {
def write(b: Int): Unit = output.write(b)

override def write(b: Array[Byte]): Unit = output.write(b)

override def write(b: Array[Byte], off: Int, len: Int): Unit = output.write(b, off, len)

override def close(): Unit = {
if (!exchange.isComplete) output.close()
}

override def flush(): Unit = {
if (!exchange.isComplete) output.flush()
}
Expand Down
3 changes: 2 additions & 1 deletion example/compress/app/src/Compress.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ object Compress extends cask.MainRoutes{

@cask.decorators.compress
@cask.get("/")
def hello() = {
def hello(): String = {
Thread.sleep(1000) // Simulate a slow endpoint
"Hello World! Hello World! Hello World!"
}

Expand Down

0 comments on commit f4c523c

Please sign in to comment.