Skip to content

Commit

Permalink
require per-parameter unroll annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
lihaoyi committed Feb 17, 2024
1 parent 2b3925e commit 56b1757
Show file tree
Hide file tree
Showing 17 changed files with 79 additions and 79 deletions.
10 changes: 5 additions & 5 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ trait UnrollModule extends Cross.Module[String]{
"primaryConstructor",
"secondaryConstructor",
"caseclass",
"abstractTraitMethod",
"abstractClassMethod"
// "abstractTraitMethod",
// "abstractClassMethod"
)


Expand Down Expand Up @@ -146,15 +146,15 @@ trait UnrollModule extends Cross.Module[String]{
}

object jvm extends InnerScalaModule with ComparativePlatformScalaModule{
def runClasspath = Seq(upstreamTest.jvm.test.compile().classes, upstream.jvm.compile().classes)
def runClasspath = super.runClasspath() ++ Seq(upstreamTest.jvm.test.compile().classes, upstream.jvm.compile().classes)
}

object js extends InnerScalaJsModule with ComparativePlatformScalaModule{
def runClasspath = Seq(upstreamTest.js.test.compile().classes, upstream.js.compile().classes)
def runClasspath = super.runClasspath() ++ Seq(upstreamTest.js.test.compile().classes, upstream.js.compile().classes)
}

object native extends InnerScalaNativeModule with ComparativePlatformScalaModule{
def runClasspath = Seq(upstreamTest.native.test.compile().classes, upstream.native.compile().classes)
def runClasspath = super.runClasspath() ++ Seq(upstreamTest.native.test.compile().classes, upstream.native.compile().classes)
}
}

Expand Down
41 changes: 22 additions & 19 deletions unroll/plugin/src-2/UnrollPhaseScala2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
new UnrollTransformer(unit)
}

def findUnrollAnnotation(params: Seq[Symbol]): Int = {
params.toList.indexWhere(_.annotations.exists(_.tpe =:= typeOf[scala.annotation.unroll]))
def findUnrollAnnotations(params: Seq[Symbol]): Seq[Int] = {
params.toList.zipWithIndex.collect {
case (v, i) if v.annotations.exists(_.tpe =:= typeOf[scala.annotation.unroll]) => i
}
}

def copyValDef(vd: ValDef) = {
Expand Down Expand Up @@ -75,8 +77,8 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT

val forwardedValueParams = firstParamList.take(paramIndex).map(p => Ident(p.name).set(p.symbol))

val defaultCalls = for (p <- Range(paramIndex, firstParamList.size)) yield {
val mangledName = defdef.name.toString + "$default$" + (p + 1)
val defaultCalls = Range(paramIndex, firstParamList.size).map{n =>
val mangledName = defdef.name.toString + "$default$" + (n + 1)

val defaultOwner =
if (defdef.symbol.isConstructor) implDef.symbol.companionModule
Expand Down Expand Up @@ -132,17 +134,10 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
forwarderDef.substituteSymbols(fromSyms, toSyms).asInstanceOf[DefDef]
}

def generateDefForwarders(implDef: ImplDef, defdef: DefDef, startParamIndex: Int) = defdef.vparamss match {
case Nil => Nil
case firstParamList :: otherParamLists =>
for (paramIndex <- Range(startParamIndex, firstParamList.length).toList) yield {
generateSingleForwarder(implDef, defdef, paramIndex, firstParamList, otherParamLists)
}
}


class UnrollTransformer(unit: global.CompilationUnit) extends TypingTransformer(unit) {
def generateDefForwarders2(implDef: ImplDef): List[List[DefDef]] = {
def generateDefForwarders(implDef: ImplDef): List[List[DefDef]] = {
implDef.impl.body.collect{ case defdef: DefDef =>

val annotatedOpt =
Expand All @@ -160,11 +155,19 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
// do not have companion class primary constructor symbols, so we just skip them here
annotatedOpt.toList.flatMap{ annotated =>
try {
annotated.asMethod.paramss.take(1).flatMap{ firstParams =>
findUnrollAnnotation(firstParams) match {
case -1 => Nil
case n => generateDefForwarders(implDef, defdef, n)
}
defdef.vparamss match {
case Nil => Nil
case firstParamList :: otherParamLists =>
val annotations = findUnrollAnnotations(annotated.tpe.params)
for (paramIndex <- annotations) yield {
generateSingleForwarder(
implDef,
defdef,
paramIndex,
firstParamList,
otherParamLists
)
}
}
}catch{case e: Throwable =>
throw new Exception(
Expand All @@ -179,7 +182,7 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
override def transform(tree: global.Tree): global.Tree = {
tree match{
case md: ModuleDef =>
val allNewMethods = generateDefForwarders2(md).flatten
val allNewMethods = generateDefForwarders(md).flatten

val classInfoType = md.symbol.moduleClass.info.asInstanceOf[ClassInfoType]
val newClassInfoType = classInfoType.copy(decls = newScopeWith(allNewMethods.map(_.symbol) ++ classInfoType.decls:_*))
Expand All @@ -199,7 +202,7 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
)
)
case cd: ClassDef =>
val allNewMethods = generateDefForwarders2(cd).flatten
val allNewMethods = generateDefForwarders(cd).flatten
super.transform(
treeCopy.ClassDef(
cd,
Expand Down
58 changes: 31 additions & 27 deletions unroll/plugin/src-3/UnrollPhaseScala3.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class UnrollPhaseScala3() extends PluginPhase {
paramLists: List[ParamClause],
firstValueParamClauseIndex: Int,
paramIndex: Int,
isCaseApply: Boolean)(using Context) = {
isCaseApply: Boolean)
(using Context) = {

def truncateMethodType0(tpe: Type): Type = {
tpe match{
Expand All @@ -55,7 +56,6 @@ class UnrollPhaseScala3() extends PluginPhase {
}

val truncatedMethodType = truncateMethodType0(prevMethodType)

val forwarderDefSymbol = Symbols.newSymbol(
defdef.symbol.owner,
defdef.name,
Expand All @@ -71,7 +71,7 @@ class UnrollPhaseScala3() extends PluginPhase {
}
}

val defaultCalls = for (n <- Range(paramIndex, paramLists(firstValueParamClauseIndex).size)) yield {
val defaultCalls = Range(paramIndex, paramLists(firstValueParamClauseIndex).size).map(n =>
if (defdef.symbol.isConstructor) {
ref(defdef.symbol.owner.companionModule)
.select(DefaultGetterName(defdef.name, n))
Expand All @@ -82,7 +82,7 @@ class UnrollPhaseScala3() extends PluginPhase {
This(defdef.symbol.owner.asClass)
.select(DefaultGetterName(defdef.name, n))
}
}
)

val allNewParamTrees =
updated.zipWithIndex.map{case (ps, i) =>
Expand Down Expand Up @@ -115,14 +115,14 @@ class UnrollPhaseScala3() extends PluginPhase {
newDefDef
}

def generateFromProduct(startParamIndex: Int, paramCount: Int, defdef: DefDef)(using Context) = {
def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = {
cpy.DefDef(defdef)(
name = defdef.name,
paramss = defdef.paramss,
tpt = defdef.tpt,
rhs = Match(
ref(defdef.paramss.head.head.asInstanceOf[ValDef].symbol).select(termName("productArity")),
Range(startParamIndex, paramCount).toList.map { paramIndex =>
startParamIndices.map { paramIndex =>
val Apply(select, args) = defdef.rhs
CaseDef(
Literal(Constant(paramIndex)),
Expand Down Expand Up @@ -170,30 +170,34 @@ class UnrollPhaseScala3() extends PluginPhase {
if (firstValueParamClauseIndex == -1) (None, Nil)
else {
val paramCount = annotated.paramSymss(firstValueParamClauseIndex).size
annotated
val startParamIndices = annotated
.paramSymss(firstValueParamClauseIndex)
.indexWhere(_.annotations.exists(_.symbol.fullName.toString == "scala.annotation.unroll")) match{
case -1 => (None, Nil)
case startParamIndex =>
if (isCaseFromProduct) {
(Some(defdef.symbol), Seq(generateFromProduct(startParamIndex, paramCount, defdef)))
} else {
(
None,
for (paramIndex <- Range(startParamIndex, paramCount)) yield {
generateSingleForwarder(
defdef,
defdef.symbol.info,
defdef.paramss,
firstValueParamClauseIndex,
paramIndex,
isCaseApply
)
}
)
.zipWithIndex
.collect{
case (v, i) if v.annotations.exists(_.symbol.fullName.toString == "scala.annotation.unroll") =>
i
}
if (startParamIndices == Nil) (None, Nil)
else if (isCaseFromProduct) {
(Some(defdef.symbol), Seq(generateFromProduct(startParamIndices, paramCount, defdef)))
} else {
(
None,

for (paramIndex <- startParamIndices) yield {
generateSingleForwarder(
defdef,
defdef.symbol.info,
defdef.paramss,
firstValueParamClauseIndex,
paramIndex,
isCaseApply
)
}
)
}
}
}

case _ => (None, Nil)
}

Expand Down
1 change: 1 addition & 0 deletions unroll/tests/caseclass/v2/src/Unrolled.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ import scala.annotation.unroll
case class Unrolled(s: String, n: Int = 1, @unroll b: Boolean = true){
def foo = s + n + b
}

5 changes: 4 additions & 1 deletion unroll/tests/caseclass/v3/src/Unrolled.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package unroll

import scala.annotation.unroll

case class Unrolled(s: String, n: Int = 1, @unroll b: Boolean = true, l: Long = 0){
case class Unrolled(s: String, n: Int = 1, @unroll b: Boolean = true, @unroll l: Long = 0){
def foo = s + n + b + l
}



7 changes: 1 addition & 6 deletions unroll/tests/classMethod/v3/src/Unrolled.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@ package unroll
import scala.annotation.unroll

class Unrolled{
def foo(s: String, @unroll n: Int = 1, b: Boolean = true, l: Long = 0) = s + n + b + l
def foo(s: String, @unroll n: Int = 1, b: Boolean = true, @unroll l: Long = 0) = s + n + b + l
}






Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ object UnrollTestPlatformSpecific{
"hello1true0"
)

assert(
cls.getMethod("foo", classOf[String], classOf[Int]).invoke(instance, "hello", 2: Integer) ==
"hello2true0"
)
// Only generate unrolled methods for annotated params
// (b: Boolean) is not annotated so this method should not exist
assert(scala.util.Try(cls.getMethod("foo", classOf[String], classOf[Int])).isFailure)

assert(
cls.getMethod("foo", classOf[String], classOf[Int], classOf[Boolean])
.invoke(instance, "hello", 2: Integer, java.lang.Boolean.FALSE) ==
Expand Down
2 changes: 1 addition & 1 deletion unroll/tests/curriedMethod/v3/src/Unrolled.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package unroll
import scala.annotation.unroll

class Unrolled{
def foo(s: String, @unroll n: Int = 1, b: Boolean = true, l: Long = 0)(f: String => String) = f(s + n + b + l)
def foo(s: String, @unroll n: Int = 1, b: Boolean = true, @unroll l: Long = 0)(f: String => String) = f(s + n + b + l)
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ object UnrollTestPlatformSpecific{
)

assert(
cls.getMethod("foo", classOf[String], classOf[Int], classOf[String => String]).invoke(instance, "hello", 2: Integer, identity[String](_)) ==
"hello2true0"
scala.util.Try(cls.getMethod("foo", classOf[String], classOf[Int], classOf[String => String])).isFailure
)
assert(
cls.getMethod("foo", classOf[String], classOf[Int], classOf[Boolean], classOf[String => String])
Expand Down
2 changes: 1 addition & 1 deletion unroll/tests/genericMethod/v3/src/Unrolled.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package unroll
import scala.annotation.unroll

class Unrolled{
def foo[T](s: T, @unroll n: Int = 1, b: Boolean = true, l: Long = 0) = s.toString + n + b + l
def foo[T](s: T, @unroll n: Int = 1, b: Boolean = true, @unroll l: Long = 0) = s.toString + n + b + l
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ object UnrollTestPlatformSpecific{
"hello1true0"
)

assert(
cls.getMethod("foo", classOf[Object], classOf[Int]).invoke(instance, "hello", 2: Integer) ==
"hello2true0"
)
assert(scala.util.Try(cls.getMethod("foo", classOf[Object], classOf[Int])).isFailure)
assert(
cls.getMethod("foo", classOf[Object], classOf[Int], classOf[Boolean])
.invoke(instance, "hello", 2: Integer, java.lang.Boolean.FALSE) ==
Expand Down
2 changes: 1 addition & 1 deletion unroll/tests/methodWithImplicit/v3/src/Unrolled.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package unroll
import scala.annotation.unroll

class Unrolled{
def foo(s: String, @unroll n: Int = 1, b: Boolean = true, l: Long = 0)(implicit f: String => String) = f(s + n + b + l)
def foo(s: String, @unroll n: Int = 1, b: Boolean = true, @unroll l: Long = 0)(implicit f: String => String) = f(s + n + b + l)
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ object UnrollTestPlatformSpecific{
"hello1true0"
)

assert(
cls.getMethod("foo", classOf[String], classOf[Int], classOf[String => String]).invoke(instance, "hello", 2: Integer, identity[String](_)) ==
"hello2true0"
)
assert(scala.util.Try(cls.getMethod("foo", classOf[String], classOf[Int], classOf[String => String])).isFailure)

assert(
cls.getMethod("foo", classOf[String], classOf[Int], classOf[Boolean], classOf[String => String])
.invoke(instance, "hello", 2: Integer, java.lang.Boolean.FALSE, identity[String](_)) ==
Expand Down
2 changes: 1 addition & 1 deletion unroll/tests/objectMethod/v3/src/Unrolled.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package unroll
import scala.annotation.unroll

object Unrolled{
def foo(s: String, n: Int = 1, @unroll b: Boolean = true, l: Long = 0) = s + n + b + l
def foo(s: String, n: Int = 1, @unroll b: Boolean = true, @unroll l: Long = 0) = s + n + b + l
}


Expand Down
2 changes: 1 addition & 1 deletion unroll/tests/primaryConstructor/v3/src/Unrolled.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package unroll

import scala.annotation.unroll

class Unrolled(s: String, n: Int = 1, @unroll b: Boolean = true, l: Long = 0){
class Unrolled(s: String, n: Int = 1, @unroll b: Boolean = true, @unroll l: Long = 0){
def foo = s + n + b + l
}

Expand Down
2 changes: 1 addition & 1 deletion unroll/tests/secondaryConstructor/v3/src/Unrolled.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import scala.annotation.unroll
class Unrolled() {
var foo = ""

def this(s: String, n: Int = 1, @unroll b: Boolean = true, l: Long = 0) = {
def this(s: String, n: Int = 1, @unroll b: Boolean = true, @unroll l: Long = 0) = {
this()
foo = s + n + b + l
}
Expand Down
2 changes: 1 addition & 1 deletion unroll/tests/traitMethod/v3/src/Unrolled.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package unroll
import scala.annotation.unroll

trait Unrolled{
def foo(s: String, n: Int = 1, @unroll b: Boolean = true, l: Long = 0) = s + n + b + l
def foo(s: String, n: Int = 1, @unroll b: Boolean = true, @unroll l: Long = 0) = s + n + b + l
}

object Unrolled extends Unrolled

0 comments on commit 56b1757

Please sign in to comment.