From 56b1757cea21783378be536417bbf749aba207c9 Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Sat, 17 Feb 2024 10:52:14 +0800 Subject: [PATCH] require per-parameter unroll annotation --- build.sc | 10 ++-- unroll/plugin/src-2/UnrollPhaseScala2.scala | 41 +++++++------ unroll/plugin/src-3/UnrollPhaseScala3.scala | 58 ++++++++++--------- unroll/tests/caseclass/v2/src/Unrolled.scala | 1 + unroll/tests/caseclass/v3/src/Unrolled.scala | 5 +- .../tests/classMethod/v3/src/Unrolled.scala | 7 +-- .../src-jvm/UnrollTestPlatformSpecific.scala | 8 +-- .../tests/curriedMethod/v3/src/Unrolled.scala | 2 +- .../src-jvm/UnrollTestPlatformSpecific.scala | 3 +- .../tests/genericMethod/v3/src/Unrolled.scala | 2 +- .../src-jvm/UnrollTestPlatformSpecific.scala | 5 +- .../methodWithImplicit/v3/src/Unrolled.scala | 2 +- .../src-jvm/UnrollTestPlatformSpecific.scala | 6 +- .../tests/objectMethod/v3/src/Unrolled.scala | 2 +- .../primaryConstructor/v3/src/Unrolled.scala | 2 +- .../v3/src/Unrolled.scala | 2 +- .../tests/traitMethod/v3/src/Unrolled.scala | 2 +- 17 files changed, 79 insertions(+), 79 deletions(-) diff --git a/build.sc b/build.sc index 17beb4e..dc7a805 100644 --- a/build.sc +++ b/build.sc @@ -67,8 +67,8 @@ trait UnrollModule extends Cross.Module[String]{ "primaryConstructor", "secondaryConstructor", "caseclass", - "abstractTraitMethod", - "abstractClassMethod" +// "abstractTraitMethod", +// "abstractClassMethod" ) @@ -146,15 +146,15 @@ trait UnrollModule extends Cross.Module[String]{ } object jvm extends InnerScalaModule with ComparativePlatformScalaModule{ - def runClasspath = Seq(upstreamTest.jvm.test.compile().classes, upstream.jvm.compile().classes) + def runClasspath = super.runClasspath() ++ Seq(upstreamTest.jvm.test.compile().classes, upstream.jvm.compile().classes) } object js extends InnerScalaJsModule with ComparativePlatformScalaModule{ - def runClasspath = Seq(upstreamTest.js.test.compile().classes, upstream.js.compile().classes) + def runClasspath = super.runClasspath() ++ Seq(upstreamTest.js.test.compile().classes, upstream.js.compile().classes) } object native extends InnerScalaNativeModule with ComparativePlatformScalaModule{ - def runClasspath = Seq(upstreamTest.native.test.compile().classes, upstream.native.compile().classes) + def runClasspath = super.runClasspath() ++ Seq(upstreamTest.native.test.compile().classes, upstream.native.compile().classes) } } diff --git a/unroll/plugin/src-2/UnrollPhaseScala2.scala b/unroll/plugin/src-2/UnrollPhaseScala2.scala index 4bd8b02..3da182e 100644 --- a/unroll/plugin/src-2/UnrollPhaseScala2.scala +++ b/unroll/plugin/src-2/UnrollPhaseScala2.scala @@ -19,8 +19,10 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT new UnrollTransformer(unit) } - def findUnrollAnnotation(params: Seq[Symbol]): Int = { - params.toList.indexWhere(_.annotations.exists(_.tpe =:= typeOf[scala.annotation.unroll])) + def findUnrollAnnotations(params: Seq[Symbol]): Seq[Int] = { + params.toList.zipWithIndex.collect { + case (v, i) if v.annotations.exists(_.tpe =:= typeOf[scala.annotation.unroll]) => i + } } def copyValDef(vd: ValDef) = { @@ -75,8 +77,8 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT val forwardedValueParams = firstParamList.take(paramIndex).map(p => Ident(p.name).set(p.symbol)) - val defaultCalls = for (p <- Range(paramIndex, firstParamList.size)) yield { - val mangledName = defdef.name.toString + "$default$" + (p + 1) + val defaultCalls = Range(paramIndex, firstParamList.size).map{n => + val mangledName = defdef.name.toString + "$default$" + (n + 1) val defaultOwner = if (defdef.symbol.isConstructor) implDef.symbol.companionModule @@ -132,17 +134,10 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT forwarderDef.substituteSymbols(fromSyms, toSyms).asInstanceOf[DefDef] } - def generateDefForwarders(implDef: ImplDef, defdef: DefDef, startParamIndex: Int) = defdef.vparamss match { - case Nil => Nil - case firstParamList :: otherParamLists => - for (paramIndex <- Range(startParamIndex, firstParamList.length).toList) yield { - generateSingleForwarder(implDef, defdef, paramIndex, firstParamList, otherParamLists) - } - } class UnrollTransformer(unit: global.CompilationUnit) extends TypingTransformer(unit) { - def generateDefForwarders2(implDef: ImplDef): List[List[DefDef]] = { + def generateDefForwarders(implDef: ImplDef): List[List[DefDef]] = { implDef.impl.body.collect{ case defdef: DefDef => val annotatedOpt = @@ -160,11 +155,19 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT // do not have companion class primary constructor symbols, so we just skip them here annotatedOpt.toList.flatMap{ annotated => try { - annotated.asMethod.paramss.take(1).flatMap{ firstParams => - findUnrollAnnotation(firstParams) match { - case -1 => Nil - case n => generateDefForwarders(implDef, defdef, n) - } + defdef.vparamss match { + case Nil => Nil + case firstParamList :: otherParamLists => + val annotations = findUnrollAnnotations(annotated.tpe.params) + for (paramIndex <- annotations) yield { + generateSingleForwarder( + implDef, + defdef, + paramIndex, + firstParamList, + otherParamLists + ) + } } }catch{case e: Throwable => throw new Exception( @@ -179,7 +182,7 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT override def transform(tree: global.Tree): global.Tree = { tree match{ case md: ModuleDef => - val allNewMethods = generateDefForwarders2(md).flatten + val allNewMethods = generateDefForwarders(md).flatten val classInfoType = md.symbol.moduleClass.info.asInstanceOf[ClassInfoType] val newClassInfoType = classInfoType.copy(decls = newScopeWith(allNewMethods.map(_.symbol) ++ classInfoType.decls:_*)) @@ -199,7 +202,7 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT ) ) case cd: ClassDef => - val allNewMethods = generateDefForwarders2(cd).flatten + val allNewMethods = generateDefForwarders(cd).flatten super.transform( treeCopy.ClassDef( cd, diff --git a/unroll/plugin/src-3/UnrollPhaseScala3.scala b/unroll/plugin/src-3/UnrollPhaseScala3.scala index cb65f23..76acd01 100644 --- a/unroll/plugin/src-3/UnrollPhaseScala3.scala +++ b/unroll/plugin/src-3/UnrollPhaseScala3.scala @@ -45,7 +45,8 @@ class UnrollPhaseScala3() extends PluginPhase { paramLists: List[ParamClause], firstValueParamClauseIndex: Int, paramIndex: Int, - isCaseApply: Boolean)(using Context) = { + isCaseApply: Boolean) + (using Context) = { def truncateMethodType0(tpe: Type): Type = { tpe match{ @@ -55,7 +56,6 @@ class UnrollPhaseScala3() extends PluginPhase { } val truncatedMethodType = truncateMethodType0(prevMethodType) - val forwarderDefSymbol = Symbols.newSymbol( defdef.symbol.owner, defdef.name, @@ -71,7 +71,7 @@ class UnrollPhaseScala3() extends PluginPhase { } } - val defaultCalls = for (n <- Range(paramIndex, paramLists(firstValueParamClauseIndex).size)) yield { + val defaultCalls = Range(paramIndex, paramLists(firstValueParamClauseIndex).size).map(n => if (defdef.symbol.isConstructor) { ref(defdef.symbol.owner.companionModule) .select(DefaultGetterName(defdef.name, n)) @@ -82,7 +82,7 @@ class UnrollPhaseScala3() extends PluginPhase { This(defdef.symbol.owner.asClass) .select(DefaultGetterName(defdef.name, n)) } - } + ) val allNewParamTrees = updated.zipWithIndex.map{case (ps, i) => @@ -115,14 +115,14 @@ class UnrollPhaseScala3() extends PluginPhase { newDefDef } - def generateFromProduct(startParamIndex: Int, paramCount: Int, defdef: DefDef)(using Context) = { + def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = { cpy.DefDef(defdef)( name = defdef.name, paramss = defdef.paramss, tpt = defdef.tpt, rhs = Match( ref(defdef.paramss.head.head.asInstanceOf[ValDef].symbol).select(termName("productArity")), - Range(startParamIndex, paramCount).toList.map { paramIndex => + startParamIndices.map { paramIndex => val Apply(select, args) = defdef.rhs CaseDef( Literal(Constant(paramIndex)), @@ -170,30 +170,34 @@ class UnrollPhaseScala3() extends PluginPhase { if (firstValueParamClauseIndex == -1) (None, Nil) else { val paramCount = annotated.paramSymss(firstValueParamClauseIndex).size - annotated + val startParamIndices = annotated .paramSymss(firstValueParamClauseIndex) - .indexWhere(_.annotations.exists(_.symbol.fullName.toString == "scala.annotation.unroll")) match{ - case -1 => (None, Nil) - case startParamIndex => - if (isCaseFromProduct) { - (Some(defdef.symbol), Seq(generateFromProduct(startParamIndex, paramCount, defdef))) - } else { - ( - None, - for (paramIndex <- Range(startParamIndex, paramCount)) yield { - generateSingleForwarder( - defdef, - defdef.symbol.info, - defdef.paramss, - firstValueParamClauseIndex, - paramIndex, - isCaseApply - ) - } - ) + .zipWithIndex + .collect{ + case (v, i) if v.annotations.exists(_.symbol.fullName.toString == "scala.annotation.unroll") => + i } + if (startParamIndices == Nil) (None, Nil) + else if (isCaseFromProduct) { + (Some(defdef.symbol), Seq(generateFromProduct(startParamIndices, paramCount, defdef))) + } else { + ( + None, + + for (paramIndex <- startParamIndices) yield { + generateSingleForwarder( + defdef, + defdef.symbol.info, + defdef.paramss, + firstValueParamClauseIndex, + paramIndex, + isCaseApply + ) + } + ) } - } + } + case _ => (None, Nil) } diff --git a/unroll/tests/caseclass/v2/src/Unrolled.scala b/unroll/tests/caseclass/v2/src/Unrolled.scala index 916c445..28de5b0 100644 --- a/unroll/tests/caseclass/v2/src/Unrolled.scala +++ b/unroll/tests/caseclass/v2/src/Unrolled.scala @@ -5,3 +5,4 @@ import scala.annotation.unroll case class Unrolled(s: String, n: Int = 1, @unroll b: Boolean = true){ def foo = s + n + b } + diff --git a/unroll/tests/caseclass/v3/src/Unrolled.scala b/unroll/tests/caseclass/v3/src/Unrolled.scala index be6a77d..cac6656 100644 --- a/unroll/tests/caseclass/v3/src/Unrolled.scala +++ b/unroll/tests/caseclass/v3/src/Unrolled.scala @@ -2,6 +2,9 @@ package unroll import scala.annotation.unroll -case class Unrolled(s: String, n: Int = 1, @unroll b: Boolean = true, l: Long = 0){ +case class Unrolled(s: String, n: Int = 1, @unroll b: Boolean = true, @unroll l: Long = 0){ def foo = s + n + b + l } + + + diff --git a/unroll/tests/classMethod/v3/src/Unrolled.scala b/unroll/tests/classMethod/v3/src/Unrolled.scala index 429c911..01bcb4b 100644 --- a/unroll/tests/classMethod/v3/src/Unrolled.scala +++ b/unroll/tests/classMethod/v3/src/Unrolled.scala @@ -3,11 +3,6 @@ package unroll import scala.annotation.unroll class Unrolled{ - def foo(s: String, @unroll n: Int = 1, b: Boolean = true, l: Long = 0) = s + n + b + l + def foo(s: String, @unroll n: Int = 1, b: Boolean = true, @unroll l: Long = 0) = s + n + b + l } - - - - - diff --git a/unroll/tests/classMethod/v3/test/src-jvm/UnrollTestPlatformSpecific.scala b/unroll/tests/classMethod/v3/test/src-jvm/UnrollTestPlatformSpecific.scala index 3994ffa..f4968db 100644 --- a/unroll/tests/classMethod/v3/test/src-jvm/UnrollTestPlatformSpecific.scala +++ b/unroll/tests/classMethod/v3/test/src-jvm/UnrollTestPlatformSpecific.scala @@ -10,10 +10,10 @@ object UnrollTestPlatformSpecific{ "hello1true0" ) - assert( - cls.getMethod("foo", classOf[String], classOf[Int]).invoke(instance, "hello", 2: Integer) == - "hello2true0" - ) + // Only generate unrolled methods for annotated params + // (b: Boolean) is not annotated so this method should not exist + assert(scala.util.Try(cls.getMethod("foo", classOf[String], classOf[Int])).isFailure) + assert( cls.getMethod("foo", classOf[String], classOf[Int], classOf[Boolean]) .invoke(instance, "hello", 2: Integer, java.lang.Boolean.FALSE) == diff --git a/unroll/tests/curriedMethod/v3/src/Unrolled.scala b/unroll/tests/curriedMethod/v3/src/Unrolled.scala index 2721273..e79ab28 100644 --- a/unroll/tests/curriedMethod/v3/src/Unrolled.scala +++ b/unroll/tests/curriedMethod/v3/src/Unrolled.scala @@ -3,7 +3,7 @@ package unroll import scala.annotation.unroll class Unrolled{ - def foo(s: String, @unroll n: Int = 1, b: Boolean = true, l: Long = 0)(f: String => String) = f(s + n + b + l) + def foo(s: String, @unroll n: Int = 1, b: Boolean = true, @unroll l: Long = 0)(f: String => String) = f(s + n + b + l) } diff --git a/unroll/tests/curriedMethod/v3/test/src-jvm/UnrollTestPlatformSpecific.scala b/unroll/tests/curriedMethod/v3/test/src-jvm/UnrollTestPlatformSpecific.scala index 2e909c4..6a75551 100644 --- a/unroll/tests/curriedMethod/v3/test/src-jvm/UnrollTestPlatformSpecific.scala +++ b/unroll/tests/curriedMethod/v3/test/src-jvm/UnrollTestPlatformSpecific.scala @@ -11,8 +11,7 @@ object UnrollTestPlatformSpecific{ ) assert( - cls.getMethod("foo", classOf[String], classOf[Int], classOf[String => String]).invoke(instance, "hello", 2: Integer, identity[String](_)) == - "hello2true0" + scala.util.Try(cls.getMethod("foo", classOf[String], classOf[Int], classOf[String => String])).isFailure ) assert( cls.getMethod("foo", classOf[String], classOf[Int], classOf[Boolean], classOf[String => String]) diff --git a/unroll/tests/genericMethod/v3/src/Unrolled.scala b/unroll/tests/genericMethod/v3/src/Unrolled.scala index e505a74..5cd0500 100644 --- a/unroll/tests/genericMethod/v3/src/Unrolled.scala +++ b/unroll/tests/genericMethod/v3/src/Unrolled.scala @@ -3,7 +3,7 @@ package unroll import scala.annotation.unroll class Unrolled{ - def foo[T](s: T, @unroll n: Int = 1, b: Boolean = true, l: Long = 0) = s.toString + n + b + l + def foo[T](s: T, @unroll n: Int = 1, b: Boolean = true, @unroll l: Long = 0) = s.toString + n + b + l } diff --git a/unroll/tests/genericMethod/v3/test/src-jvm/UnrollTestPlatformSpecific.scala b/unroll/tests/genericMethod/v3/test/src-jvm/UnrollTestPlatformSpecific.scala index 4462f2e..ac7711c 100644 --- a/unroll/tests/genericMethod/v3/test/src-jvm/UnrollTestPlatformSpecific.scala +++ b/unroll/tests/genericMethod/v3/test/src-jvm/UnrollTestPlatformSpecific.scala @@ -10,10 +10,7 @@ object UnrollTestPlatformSpecific{ "hello1true0" ) - assert( - cls.getMethod("foo", classOf[Object], classOf[Int]).invoke(instance, "hello", 2: Integer) == - "hello2true0" - ) + assert(scala.util.Try(cls.getMethod("foo", classOf[Object], classOf[Int])).isFailure) assert( cls.getMethod("foo", classOf[Object], classOf[Int], classOf[Boolean]) .invoke(instance, "hello", 2: Integer, java.lang.Boolean.FALSE) == diff --git a/unroll/tests/methodWithImplicit/v3/src/Unrolled.scala b/unroll/tests/methodWithImplicit/v3/src/Unrolled.scala index 14a20ac..6f12381 100644 --- a/unroll/tests/methodWithImplicit/v3/src/Unrolled.scala +++ b/unroll/tests/methodWithImplicit/v3/src/Unrolled.scala @@ -3,7 +3,7 @@ package unroll import scala.annotation.unroll class Unrolled{ - def foo(s: String, @unroll n: Int = 1, b: Boolean = true, l: Long = 0)(implicit f: String => String) = f(s + n + b + l) + def foo(s: String, @unroll n: Int = 1, b: Boolean = true, @unroll l: Long = 0)(implicit f: String => String) = f(s + n + b + l) } diff --git a/unroll/tests/methodWithImplicit/v3/test/src-jvm/UnrollTestPlatformSpecific.scala b/unroll/tests/methodWithImplicit/v3/test/src-jvm/UnrollTestPlatformSpecific.scala index 2cb00b4..ffe9d75 100644 --- a/unroll/tests/methodWithImplicit/v3/test/src-jvm/UnrollTestPlatformSpecific.scala +++ b/unroll/tests/methodWithImplicit/v3/test/src-jvm/UnrollTestPlatformSpecific.scala @@ -10,10 +10,8 @@ object UnrollTestPlatformSpecific{ "hello1true0" ) - assert( - cls.getMethod("foo", classOf[String], classOf[Int], classOf[String => String]).invoke(instance, "hello", 2: Integer, identity[String](_)) == - "hello2true0" - ) + assert(scala.util.Try(cls.getMethod("foo", classOf[String], classOf[Int], classOf[String => String])).isFailure) + assert( cls.getMethod("foo", classOf[String], classOf[Int], classOf[Boolean], classOf[String => String]) .invoke(instance, "hello", 2: Integer, java.lang.Boolean.FALSE, identity[String](_)) == diff --git a/unroll/tests/objectMethod/v3/src/Unrolled.scala b/unroll/tests/objectMethod/v3/src/Unrolled.scala index b3e35c4..c77eb64 100644 --- a/unroll/tests/objectMethod/v3/src/Unrolled.scala +++ b/unroll/tests/objectMethod/v3/src/Unrolled.scala @@ -3,7 +3,7 @@ package unroll import scala.annotation.unroll object Unrolled{ - def foo(s: String, n: Int = 1, @unroll b: Boolean = true, l: Long = 0) = s + n + b + l + def foo(s: String, n: Int = 1, @unroll b: Boolean = true, @unroll l: Long = 0) = s + n + b + l } diff --git a/unroll/tests/primaryConstructor/v3/src/Unrolled.scala b/unroll/tests/primaryConstructor/v3/src/Unrolled.scala index ee8cbca..18499ea 100644 --- a/unroll/tests/primaryConstructor/v3/src/Unrolled.scala +++ b/unroll/tests/primaryConstructor/v3/src/Unrolled.scala @@ -2,7 +2,7 @@ package unroll import scala.annotation.unroll -class Unrolled(s: String, n: Int = 1, @unroll b: Boolean = true, l: Long = 0){ +class Unrolled(s: String, n: Int = 1, @unroll b: Boolean = true, @unroll l: Long = 0){ def foo = s + n + b + l } diff --git a/unroll/tests/secondaryConstructor/v3/src/Unrolled.scala b/unroll/tests/secondaryConstructor/v3/src/Unrolled.scala index 7063819..4eab130 100644 --- a/unroll/tests/secondaryConstructor/v3/src/Unrolled.scala +++ b/unroll/tests/secondaryConstructor/v3/src/Unrolled.scala @@ -5,7 +5,7 @@ import scala.annotation.unroll class Unrolled() { var foo = "" - def this(s: String, n: Int = 1, @unroll b: Boolean = true, l: Long = 0) = { + def this(s: String, n: Int = 1, @unroll b: Boolean = true, @unroll l: Long = 0) = { this() foo = s + n + b + l } diff --git a/unroll/tests/traitMethod/v3/src/Unrolled.scala b/unroll/tests/traitMethod/v3/src/Unrolled.scala index a80af93..085ecea 100644 --- a/unroll/tests/traitMethod/v3/src/Unrolled.scala +++ b/unroll/tests/traitMethod/v3/src/Unrolled.scala @@ -3,7 +3,7 @@ package unroll import scala.annotation.unroll trait Unrolled{ - def foo(s: String, n: Int = 1, @unroll b: Boolean = true, l: Long = 0) = s + n + b + l + def foo(s: String, n: Int = 1, @unroll b: Boolean = true, @unroll l: Long = 0) = s + n + b + l } object Unrolled extends Unrolled \ No newline at end of file