Skip to content

Commit

Permalink
wip: Move helper methods to ThreadBlockingHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed May 26, 2024
1 parent 9d4ffd6 commit dc26b6f
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 101 deletions.
14 changes: 0 additions & 14 deletions cask/src/cask/internal/NewThreadPerTaskExecutor.scala

This file was deleted.

16 changes: 16 additions & 0 deletions cask/src/cask/internal/ThreadBlockingHandler.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package cask.internal

import io.undertow.server.{HttpHandler, HttpServerExchange}

import java.util.concurrent.Executor

final class ThreadBlockingHandler(executor: Executor,
handler: HttpHandler)
extends HttpHandler {
require(executor != null, "executor should not be null")

override def handleRequest(exchange: HttpServerExchange): Unit = {
exchange.startBlocking()
exchange.dispatch(executor, handler)
}
}
80 changes: 60 additions & 20 deletions cask/src/cask/internal/Util.scala
Original file line number Diff line number Diff line change
@@ -1,24 +1,61 @@
package cask.internal

import java.io.{InputStream, PrintWriter, StringWriter}

import scala.collection.generic.CanBuildFrom
import scala.collection.mutable
import java.io.OutputStream

import java.lang.invoke.{MethodHandles, MethodType}
import java.util.concurrent.{ExecutorService, ThreadFactory}
import scala.annotation.switch
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.util.control.NonFatal

object Util {
private val lookup = MethodHandles.lookup

def createNewThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = {
try {
val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors")
val newThreadPerTaskExecutorMethod = lookup.findStatic(
executorsClazz,
"newThreadPerTaskExecutor",
MethodType.methodType(classOf[ExecutorService], classOf[ThreadFactory]))
newThreadPerTaskExecutorMethod.invoke(threadFactory)
.asInstanceOf[ExecutorService]
} catch {
case NonFatal(e) =>
throw new UnsupportedOperationException("Failed to create virtual thread executor", e)
}
}

/**
* Create a virtual thread factory, returns null when failed.
*/
def createVirtualThreadFactory(prefix: String): ThreadFactory =
try {
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
var builder = ofVirtualMethod.invoke()
val nameMethod = lookup.findVirtual(ofVirtualClass, "name",
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long]))
val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case _: Throwable => null
}

def firstFutureOf[T](futures: Seq[Future[T]])(implicit ec: ExecutionContext) = {
val p = Promise[T]
futures.foreach(_.foreach(p.trySuccess))
p.future
}

/**
* Convert a string to a C&P-able literal. Basically
* copied verbatim from the uPickle source code.
*/
* Convert a string to a C&P-able literal. Basically
* copied verbatim from the uPickle source code.
*/
def literalize(s: IndexedSeq[Char], unicode: Boolean = true) = {
val sb = new StringBuilder
sb.append('"')
Expand Down Expand Up @@ -47,29 +84,30 @@ object Util {
def transferTo(in: InputStream, out: OutputStream) = {
val buffer = new Array[Byte](8192)

while ({
in.read(buffer) match{
while ( {
in.read(buffer) match {
case -1 => false
case n =>
out.write(buffer, 0, n)
true
}
}) ()
}

def pluralize(s: String, n: Int) = {
if (n == 1) s else s + "s"
}

/**
* Splits a string into path segments; automatically removes all
* leading/trailing slashes, and ignores empty path segments.
*
* Written imperatively for performance since it's used all over the place.
*/
* Splits a string into path segments; automatically removes all
* leading/trailing slashes, and ignores empty path segments.
*
* Written imperatively for performance since it's used all over the place.
*/
def splitPath(p: String): collection.IndexedSeq[String] = {
val pLength = p.length
var i = 0
while(i < pLength && p(i) == '/') i += 1
while (i < pLength && p(i) == '/') i += 1
var segmentStart = i
val out = mutable.ArrayBuffer.empty[String]

Expand All @@ -81,7 +119,7 @@ object Util {
segmentStart = i + 1
}

while(i < pLength){
while (i < pLength) {
if (p(i) == '/') complete()
i += 1
}
Expand All @@ -96,33 +134,35 @@ object Util {
pw.flush()
trace.toString
}

def softWrap(s: String, leftOffset: Int, maxWidth: Int) = {
val oneLine = s.linesIterator.mkString(" ").split(' ')

lazy val indent = " " * leftOffset

val output = new StringBuilder(oneLine.head)
var currentLineWidth = oneLine.head.length
for(chunk <- oneLine.tail){
for (chunk <- oneLine.tail) {
val addedWidth = currentLineWidth + chunk.length + 1
if (addedWidth > maxWidth){
if (addedWidth > maxWidth) {
output.append("\n" + indent)
output.append(chunk)
currentLineWidth = chunk.length
} else{
} else {
currentLineWidth = addedWidth
output.append(' ')
output.append(chunk)
}
}
output.mkString
}

def sequenceEither[A, B, M[X] <: TraversableOnce[X]](in: M[Either[A, B]])(
implicit cbf: CanBuildFrom[M[Either[A, B]], B, M[B]]): Either[A, M[B]] = {
in.foldLeft[Either[A, mutable.Builder[B, M[B]]]](Right(cbf(in))) {
case (acc, el) =>
for (a <- acc; e <- el) yield a += e
}
case (acc, el) =>
for (a <- acc; e <- el) yield a += e
}
.map(_.result())
}
}
16 changes: 0 additions & 16 deletions cask/src/cask/internal/VirtualThreadBlockingHandler.scala

This file was deleted.

31 changes: 0 additions & 31 deletions cask/src/cask/internal/VirtualThreadSupport.scala

This file was deleted.

28 changes: 10 additions & 18 deletions cask/src/cask/main/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package cask.main

import cask.endpoints.{WebsocketResult, WsHandler}
import cask.model._
import cask.internal.{DispatchTrie, Util, VirtualThreadBlockingHandler, VirtualThreadSupport}
import cask.model.Response.Raw
import cask.internal.{DispatchTrie, Util, ThreadBlockingHandler}
import Response.Raw
import cask.router.{Decorator, EndpointMetadata, EntryPoint, Result}
import cask.util.Logger
import io.undertow.Undertow
import io.undertow.server.{HttpHandler, HttpServerExchange}
import io.undertow.server.handlers.BlockingHandler
import io.undertow.util.HttpString

import java.util.concurrent.Executor
import scala.annotation.nowarn
import scala.concurrent.ExecutionContext
import scala.util.control.NonFatal
Expand All @@ -34,19 +35,6 @@ class MainRoutes extends Main with Routes {
* `Virtual Threads` if the runtime supports it and the property `cask.virtualThread.enabled` is set to `true`.
*/
abstract class Main {
/**
* Ture and only true when the virtual thread supported by the runtime, and
* property `cask.virtualThread.enabled` is set to `true`.
* */
private lazy val useVirtualThreadHandlerExecutor: Boolean = {
val enableVT = System.getProperty("cask.virtualThread.enabled", "false")
enableVT match {
case "true" if VirtualThreadSupport.isVirtualThreadSupported => true
case _ => false
}

}

def mainDecorators: Seq[Decorator[_, _, _]] = Nil

def allRoutes: Seq[Routes]
Expand All @@ -70,11 +58,15 @@ abstract class Main {

def dispatchTrie: DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]] = Main.prepareDispatchTrie(allRoutes)

protected def handlerExecutor(): Executor = {
null
}

def defaultHandler: HttpHandler = {
val mainHandler = new Main.DefaultHandler(dispatchTrie, mainDecorators, debugMode, handleNotFound, handleMethodNotAllowed, handleEndpointError)
if (useVirtualThreadHandlerExecutor) {
//switch to virtual thread if possible
new VirtualThreadBlockingHandler(mainHandler)
val executor = handlerExecutor()
if (handlerExecutor ne null) {
new ThreadBlockingHandler(executor, mainHandler)
} else new BlockingHandler(mainHandler)
}

Expand Down
15 changes: 13 additions & 2 deletions example/compress/app/src/Compress.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
package app
object Compress extends cask.MainRoutes{

import cask.internal.{ThreadBlockingHandler, Util}

import java.util.concurrent.Executor

object Compress extends cask.MainRoutes {

protected override val handlerExecutor: Executor = {
if (System.getProperty("cask.virtualThread.enabled", "false").toBoolean) {
Util.createNewThreadPerTaskExecutor(
Util.createVirtualThreadFactory("cask-handler-executor"))
} else null
}

@cask.decorators.compress
@cask.get("/")
def hello(): String = {
Thread.sleep(1000) // Simulate a slow endpoint
"Hello World! Hello World! Hello World!"
}

Expand Down

0 comments on commit dc26b6f

Please sign in to comment.