Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Virtual Thread support #125

Closed
wants to merge 13 commits into from
3 changes: 1 addition & 2 deletions .mill-version
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
0.11.6

0.11.8
36 changes: 18 additions & 18 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ import $file.example.websockets3.build
import $file.example.websockets4.build
import $file.ci.upload
import $ivy.`de.tototec::de.tobiasroeser.mill.vcs.version::0.4.0`
import $ivy.`com.github.lolgab::mill-mima::0.0.23`
import $ivy.`com.github.lolgab::mill-mima::0.1.0`
import de.tobiasroeser.mill.vcs.version.VcsVersion

val scala213 = "2.13.10"
val scala212 = "2.12.17"
val scala3 = "3.2.2"
val scalaJS = "1.13.0"
val scala213 = "2.13.14"
val scala212 = "2.12.19"
val scala3 = "3.3.3"
val scalaJS = "1.16.0"
val communityBuildDottyVersion = sys.props.get("dottyVersion").toList

val scalaVersions = List(scala212, scala213, scala3) ++ communityBuildDottyVersion
Expand All @@ -59,20 +59,20 @@ trait CaskModule extends CrossScalaModule with PublishModule{
trait CaskMainModule extends CaskModule {
def ivyDeps = T{
Agg(
ivy"io.undertow:undertow-core:2.3.10.Final",
ivy"com.lihaoyi::upickle:3.0.0"
ivy"io.undertow:undertow-core:2.3.13.Final",
ivy"com.lihaoyi::upickle:3.3.1"
) ++
Agg.when(!isScala3)(ivy"org.scala-lang:scala-reflect:$crossScalaVersion")
}

def compileIvyDeps = Agg.when(!isScala3)(ivy"com.lihaoyi:::acyclic:0.3.6")
def compileIvyDeps = Agg.when(!isScala3)(ivy"com.lihaoyi:::acyclic:0.3.12")
def scalacOptions = Agg.when(!isScala3)("-P:acyclic:force").toSeq
def scalacPluginIvyDeps = Agg.when(!isScala3)(ivy"com.lihaoyi:::acyclic:0.3.6")
def scalacPluginIvyDeps = Agg.when(!isScala3)(ivy"com.lihaoyi:::acyclic:0.3.12")

object test extends ScalaTests with TestModule.Utest{
def ivyDeps = Agg(
ivy"com.lihaoyi::utest::0.8.1",
ivy"com.lihaoyi::requests::0.8.0"
ivy"com.lihaoyi::utest::0.8.3",
ivy"com.lihaoyi::requests::0.8.2"
)
}
def moduleDeps = Seq(cask.util.jvm(crossScalaVersion))
Expand All @@ -82,17 +82,17 @@ object cask extends Cross[CaskMainModule](scalaVersions) {
object util extends Module {
trait UtilModule extends CaskModule with PlatformScalaModule{
def ivyDeps = Agg(
ivy"com.lihaoyi::sourcecode:0.3.0",
ivy"com.lihaoyi::pprint:0.8.1",
ivy"com.lihaoyi::geny:1.0.0"
ivy"com.lihaoyi::sourcecode:0.4.1",
ivy"com.lihaoyi::pprint:0.9.0",
ivy"com.lihaoyi::geny:1.1.0"
)
}

object jvm extends Cross[UtilJvmModule](scalaVersions)
trait UtilJvmModule extends UtilModule {
def ivyDeps = super.ivyDeps() ++ Agg(
ivy"com.lihaoyi::castor::0.3.0",
ivy"org.java-websocket:Java-WebSocket:1.5.3"
ivy"org.java-websocket:Java-WebSocket:1.5.6"
)
}

Expand All @@ -101,7 +101,7 @@ object cask extends Cross[CaskMainModule](scalaVersions) {
def scalaJSVersion = scalaJS
def ivyDeps = super.ivyDeps() ++ Agg(
ivy"com.lihaoyi::castor::0.3.0",
ivy"org.scala-js::scalajs-dom::2.4.0"
ivy"org.scala-js::scalajs-dom::2.8.0"
)
}
}
Expand Down Expand Up @@ -168,7 +168,7 @@ object example extends Module{
object todoDb extends Cross[TodoDbModule](scala213) // uses quill, can't enable for Dotty yet

trait TwirlModule extends millbuild.example.twirl.build.AppModule with LocalModule
object twirl extends Cross[TwirlModule](scalaVersions)
// object twirl extends Cross[TwirlModule](scalaVersions)

trait VariableRoutesModule extends millbuild.example.variableRoutes.build.AppModule with LocalModule
object variableRoutes extends Cross[VariableRoutesModule](scalaVersions)
Expand Down Expand Up @@ -230,7 +230,7 @@ def uploadToGithub() = T.command{
millbuild.example.todo.build.millSourcePath,
millbuild.example.todoApi.build.millSourcePath,
millbuild.example.todoDb.build.millSourcePath,
millbuild.example.twirl.build.millSourcePath,
// millbuild.example.twirl.build.millSourcePath,
millbuild.example.variableRoutes.build.millSourcePath,
millbuild.example.queryParams.build.millSourcePath,
millbuild.example.websockets.build.millSourcePath,
Expand Down
4 changes: 2 additions & 2 deletions cask/src-2/cask/main/Routes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import language.experimental.macros
trait Routes{

def decorators = Seq.empty[cask.router.Decorator[_, _, _]]
private[this] var metadata0: RoutesEndpointsMetadata[this.type] = null
def caskMetadata =
private[this] var metadata0: RoutesEndpointsMetadata[this.type] = _
def caskMetadata: RoutesEndpointsMetadata[Routes.this.type] =
if (metadata0 != null) metadata0
else throw new Exception("Routes not yet initialized")

Expand Down
4 changes: 2 additions & 2 deletions cask/src-3/cask/main/Routes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import language.experimental.macros
trait Routes{

def decorators = Seq.empty[cask.router.Decorator[_, _, _]]
private[this] var metadata0: RoutesEndpointsMetadata[this.type] = null
def caskMetadata =
private[this] var metadata0: RoutesEndpointsMetadata[this.type] = _
def caskMetadata: RoutesEndpointsMetadata[Routes.this.type] =
if (metadata0 != null) metadata0
else throw new Exception("Routes not yet initialized")

Expand Down
16 changes: 16 additions & 0 deletions cask/src/cask/internal/ThreadBlockingHandler.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package cask.internal

import io.undertow.server.{HttpHandler, HttpServerExchange}

import java.util.concurrent.Executor

final class ThreadBlockingHandler(executor: Executor,
handler: HttpHandler)
extends HttpHandler {
require(executor != null, "executor should not be null")

override def handleRequest(exchange: HttpServerExchange): Unit = {
exchange.startBlocking()
exchange.dispatch(executor, handler)
}
}
113 changes: 93 additions & 20 deletions cask/src/cask/internal/Util.scala
Original file line number Diff line number Diff line change
@@ -1,24 +1,94 @@
package cask.internal

import java.io.{InputStream, PrintWriter, StringWriter}

import scala.collection.generic.CanBuildFrom
import scala.collection.mutable
import java.io.OutputStream

import java.lang.invoke.{MethodHandles, MethodType}
import java.util.concurrent.{Executor, ExecutorService, ThreadFactory}
import scala.annotation.switch
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.util.Try
import scala.util.control.NonFatal

object Util {
private val lookup = MethodHandles.lookup

import cask.util.Logger.Console.globalLogger

/**
* Create a virtual thread executor with the given executor as the scheduler.
* */
def createVirtualThreadExecutor(executor: Executor): Option[Executor] = {
(for {
factory <- Try(createVirtualThreadFactory("cask-handler-executor", executor))
executor <- Try(createNewThreadPerTaskExecutor(factory))
} yield executor).toOption
}

def createNewThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = {
try {
val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors")
val newThreadPerTaskExecutorMethod = lookup.findStatic(
executorsClazz,
"newThreadPerTaskExecutor",
MethodType.methodType(classOf[ExecutorService], classOf[ThreadFactory]))
newThreadPerTaskExecutorMethod.invoke(threadFactory)
.asInstanceOf[ExecutorService]
} catch {
case NonFatal(e) =>
globalLogger.exception(e)
throw new UnsupportedOperationException("Failed to create newThreadPerTaskExecutor.", e)
}
}

/**
* Create a virtual thread factory with a executor, the executor will be used as the scheduler of
* virtual thread.
*
* The executor should run task on platform threads.
*
* returns null if not supported.
*/
def createVirtualThreadFactory(prefix: String,
executor: Executor): ThreadFactory =
try {
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
var builder = ofVirtualMethod.invoke()
if (executor != null) {
val clazz = builder.getClass
val privateLookup = MethodHandles.privateLookupIn(
clazz,
lookup
)
val schedulerFieldSetter = privateLookup
.findSetter(clazz, "scheduler", classOf[Executor])
schedulerFieldSetter.invoke(builder, executor)
}
val nameMethod = lookup.findVirtual(ofVirtualClass, "name",
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long]))
val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case NonFatal(e) =>
globalLogger.exception(e)
//--add-opens java.base/java.lang=ALL-UNNAMED
throw new UnsupportedOperationException("Failed to create virtual thread factory", e)
}

def firstFutureOf[T](futures: Seq[Future[T]])(implicit ec: ExecutionContext) = {
val p = Promise[T]
futures.foreach(_.foreach(p.trySuccess))
p.future
}

/**
* Convert a string to a C&P-able literal. Basically
* copied verbatim from the uPickle source code.
*/
* Convert a string to a C&P-able literal. Basically
* copied verbatim from the uPickle source code.
*/
def literalize(s: IndexedSeq[Char], unicode: Boolean = true) = {
val sb = new StringBuilder
sb.append('"')
Expand Down Expand Up @@ -47,29 +117,30 @@ object Util {
def transferTo(in: InputStream, out: OutputStream) = {
val buffer = new Array[Byte](8192)

while ({
in.read(buffer) match{
while ( {
in.read(buffer) match {
case -1 => false
case n =>
out.write(buffer, 0, n)
true
}
}) ()
}

def pluralize(s: String, n: Int) = {
if (n == 1) s else s + "s"
}

/**
* Splits a string into path segments; automatically removes all
* leading/trailing slashes, and ignores empty path segments.
*
* Written imperatively for performance since it's used all over the place.
*/
* Splits a string into path segments; automatically removes all
* leading/trailing slashes, and ignores empty path segments.
*
* Written imperatively for performance since it's used all over the place.
*/
def splitPath(p: String): collection.IndexedSeq[String] = {
val pLength = p.length
var i = 0
while(i < pLength && p(i) == '/') i += 1
while (i < pLength && p(i) == '/') i += 1
var segmentStart = i
val out = mutable.ArrayBuffer.empty[String]

Expand All @@ -81,7 +152,7 @@ object Util {
segmentStart = i + 1
}

while(i < pLength){
while (i < pLength) {
if (p(i) == '/') complete()
i += 1
}
Expand All @@ -96,33 +167,35 @@ object Util {
pw.flush()
trace.toString
}

def softWrap(s: String, leftOffset: Int, maxWidth: Int) = {
val oneLine = s.linesIterator.mkString(" ").split(' ')

lazy val indent = " " * leftOffset

val output = new StringBuilder(oneLine.head)
var currentLineWidth = oneLine.head.length
for(chunk <- oneLine.tail){
for (chunk <- oneLine.tail) {
val addedWidth = currentLineWidth + chunk.length + 1
if (addedWidth > maxWidth){
if (addedWidth > maxWidth) {
output.append("\n" + indent)
output.append(chunk)
currentLineWidth = chunk.length
} else{
} else {
currentLineWidth = addedWidth
output.append(' ')
output.append(chunk)
}
}
output.mkString
}

def sequenceEither[A, B, M[X] <: TraversableOnce[X]](in: M[Either[A, B]])(
implicit cbf: CanBuildFrom[M[Either[A, B]], B, M[B]]): Either[A, M[B]] = {
in.foldLeft[Either[A, mutable.Builder[B, M[B]]]](Right(cbf(in))) {
case (acc, el) =>
for (a <- acc; e <- el) yield a += e
}
case (acc, el) =>
for (a <- acc; e <- el) yield a += e
}
.map(_.result())
}
}
Loading
Loading