Skip to content

Commit

Permalink
Overhaul and simplify TPrint implementation (#72)
Browse files Browse the repository at this point in the history
Fixes com-lihaoyi/Ammonite#221 
Fixes com-lihaoyi/Ammonite#629 
Fixes com-lihaoyi/Ammonite#670 
Fixes #45 
Fixes #44

The old TPrint implementation did a clever thing where it allowed a user to over-ride the TPrinting of a given type by providing an appropriate implicit. While that worked in most cases, it was fiendishly complex, and the intricate nesting of implicit resolution and macro resolution ended up providing and endless source of hard to resolve bugs.

This new implementation is much simpler and less flexible: we simply walk the type data structure in the macro, and spit out a colored `fansi.Str` with the type names hard-coded to `fansi.Green`. The only runtime support necessary is in the `def recolor` function, which parses the incoming `fansi.Str` and replaces the hardcoded `fansi.Green` colors with whatever is specified by the implicit `TPrintColors`. As implicits cannot be used to override tprinting anymore, we now have hardcoded support for tprinting functions and tuples.

While the old macro generated a complex tree of Scala function calls that is evaluated to generate the output `fansi.Str` at runtime, the new macro simply spits out a single `fansi.Str` that is serialized into a `java.lang.String` and deserialized back into a `fansi.Str` for usage at runtime. We propagate a `WrapType` enumeration up the recursion, to help the callers decide if they need to wrap things in parens or not.

This gives up a bit of flexibility, but AFAIK nobody was really using that flexibility anyway. In exchange, we fix a whole bunch of long-standing bugs, and have a drastically simpler implementation.

The fixed bugs are covered by regression unit tests added to `TPrintTests.scala`. All existing tests also pass, so hopefully that'll catch any potential regressions. There's probably more bugs where we're not properly setting or handling the `WrapType`, but exhaustively testing/surfacing/fixing all of those is beyond the scope of this PR. For now, I just kept the current set of tests passing.

Managed to get the Scala3 side working. I didn't realize how half-baked the Scala3 implementation of TPrint is; so much of the Scala2 functionality just isn't implemented and doesn't work. Nevertheless, fixing that is beyond the scope of this PR. I just kept it green with the existing set of green tests passing (except for the custom tprinter test, which is no longer applicable)

Review by @lolgab.
  • Loading branch information
lihaoyi authored Dec 6, 2021
1 parent 6a270d1 commit 7fead8d
Show file tree
Hide file tree
Showing 7 changed files with 337 additions and 327 deletions.
35 changes: 1 addition & 34 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.github.lolgab.mill.mima._

val dottyVersions = sys.props.get("dottyVersion").toList

val scalaVersions = "2.12.13" :: "2.13.4" :: "2.11.12" :: "3.0.0" :: dottyVersions
val scalaVersions = "2.12.13" :: "2.13.4" :: "2.11.12" :: "3.0.2" :: dottyVersions
val scala2Versions = scalaVersions.filter(_.startsWith("2."))

val scalaJSVersions = for {
Expand Down Expand Up @@ -64,39 +64,6 @@ trait PPrintMainModule extends CrossScalaModule {
)
)
)
def generatedSources = T{
val dir = T.ctx().dest
val file = dir/"pprint"/"TPrintGen.scala"

val typeGen = for(i <- 2 to 22) yield {
val ts = (1 to i).map("T" + _).mkString(", ")
val tsBounded = (1 to i).map("T" + _ + ": Type").mkString(", ")
val tsGet = (1 to i).map("get[T" + _ + "](cfg)").mkString(" + \", \" + ")
s"""
implicit def F${i}TPrint[$tsBounded, R: Type]: Type[($ts) => R] = make[($ts) => R](cfg =>
"(" + $tsGet + ") => " + get[R](cfg)
)
implicit def T${i}TPrint[$tsBounded]: Type[($ts)] = make[($ts)](cfg =>
"(" + $tsGet + ")"
)
"""
}
val output = s"""
package pprint
trait TPrintGen[Type[_], Cfg]{
def make[T](f: Cfg => String): Type[T]
def get[T: Type](cfg: Cfg): String
implicit def F0TPrint[R: Type]: Type[() => R] = make[() => R](cfg => "() => " + get[R](cfg))
implicit def F1TPrint[T1: Type, R: Type]: Type[T1 => R] = {
make[T1 => R](cfg => get[T1](cfg) + " => " + get[R](cfg))
}
${typeGen.mkString("\n")}
}
""".stripMargin
os.write(file, output, createFolders = true)
Seq(PathRef(file))
}

}


Expand Down
219 changes: 120 additions & 99 deletions pprint/src-2/TPrintImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,45 @@ trait TPrintLowPri{
implicit def default[T]: TPrint[T] = macro TPrintLowPri.typePrintImpl[T]
}
object TPrintLowPri{
sealed trait WrapType
object WrapType{
case object NoWrap extends WrapType
case object Infix extends WrapType
case object Tuple extends WrapType
}
def typePrintImpl[T: c.WeakTypeTag](c: Context): c.Expr[TPrint[T]] = {
// Used to provide "empty string" values in quasiquotes

import c.universe._
val s = ""
val tpe = weakTypeOf[T]
val rendered = typePrintImplRec(c)(tpe, rightMost = true).render
val res = c.Expr[TPrint[T]](
q"_root_.pprint.TPrint.recolor(_root_.fansi.Str($rendered))"
)
res
}

val functionTypes = Range.inclusive(0, 22).map(i => s"scala.Function$i").toSet
val tupleTypes = Range.inclusive(0, 22).map(i => s"scala.Tuple$i").toSet

def typePrintImplRec[T](c: Context)(tpe: c.Type, rightMost: Boolean): fansi.Str = {
typePrintImplRec0(c)(tpe, rightMost)._1
}
def typePrintImplRec0[T](c: Context)(tpe: c.Type, rightMost: Boolean): (fansi.Str, WrapType) = {
import c.universe._
def printSymString(s: Symbol) =
if (s.name.decodedName.toString.startsWith("_$")) "_"
else s.name.decodedName.toString.stripSuffix(".type")

def literalColor(s: Tree) = {
q"$cfgSym.typeColor($s).render"
}
def printSym(s: Symbol): Tree = {
literalColor(q"${printSymString(s)}")
}
def literalColor(s: fansi.Str): fansi.Str = fansi.Color.Green(s)
def printSym(s: Symbol): fansi.Str = literalColor(printSymString(s))

def printSymFull(s: Symbol): Tree = {
if (lookup(s)) printSym(s)
else q"""${printSymFull(s.owner)} + "." + ${printSym(s)}"""

def printSymFull(s: Symbol): fansi.Str = {
if (lookup(s)) printSym(s)
else printSymFull(s.owner) ++ "." ++ printSym(s)
}

/**
* Looks up a symbol in the enclosing scope and returns
* whether it exists in scope by the same name
Expand Down Expand Up @@ -55,64 +73,43 @@ object TPrintLowPri{
)}
}

def prefixFor(pre: Type, sym: Symbol): Tree = {
def prefixFor(pre: Type, sym: Symbol): fansi.Str = {
// Depending on what the prefix is, you may use `#`, `.`
// or even need to wrap the prefix in parentheses
val sep = pre match{
case x if x.toString.endsWith(".type") => q""" ${rec0(pre)} + "." """
case x: TypeRef => q""" ${literalColor(implicitRec(pre))} + "#" """
case x: SingleType => q""" ${literalColor(rec0(pre))} + "." """
case x: ThisType => q""" ${literalColor(rec0(pre))} + "." """
case x => q""" "(" + ${implicitRec(pre)} + ")#" """
case x if x.toString.endsWith(".type") => typePrintImplRec(c)(pre, false) ++ "."
case x: TypeRef => literalColor(typePrintImplRec(c)(pre, true)) ++ "#"
case x: SingleType => literalColor(typePrintImplRec(c)(pre, false)) ++ "."
case x: ThisType => literalColor(typePrintImplRec(c)(pre, false)) ++ "."
case x => fansi.Str("(") ++ typePrintImplRec(c)(pre, true) ++ ")#"
}

val prefix = if (!lookup(sym)) sep else q"$s"
q"$prefix + ${printSym(sym)}"
val prefix = if (!lookup(sym)) sep else fansi.Str("")
prefix ++ printSym(sym)
}


def printArgSyms(args: List[Symbol]): Tree = {
def added = args.map{x =>
val TypeBounds(lo, hi) = x.info
q""" ${printSym(x)} + ${printBounds(lo, hi)}"""
}.reduceLeft[Tree]((l, r) => q"""$l + ", " + $r""")
if (args == Nil) q"$s" else q""" "[" + $added + "]" """
}
def printArgs(args: List[Type]): Tree = {
def added = args.map(implicitRec(_))
.reduceLeft[Tree]((l, r) => q"""$l + ", " + $r""")
def printArgSyms(args: List[Symbol]): fansi.Str = {
def added = args
.map{x =>
val TypeBounds(lo, hi) = x.info
printSym(x) ++ printBounds(lo, hi)
}
.reduceLeft[fansi.Str]((l, r) => l ++ ", " ++ r)

if (args == Nil) q"$s" else q""" "[" + $added + "]" """
if (args == Nil) fansi.Str("") else fansi.Str("[") ++ added ++ "]"
}
def printArgs(args: List[Type]): fansi.Str = {
def added = args.map(typePrintImplRec(c)(_, true))
.reduceLeft[fansi.Str]((l, r) => l ++ ", " ++ r)


def implicitRec(tpe: Type) = {
val byName = (tpe: Type) match{
case t: TypeRef if t.toString.startsWith("=> ") => Some(t.args(0))
case _ => None
}

try {
// Make sure the type isn't higher-kinded or some other weird
// thing, and actually can fit inside the square brackets

byName match{
case Some(t) =>
c.typecheck(q"null.asInstanceOf[$tpe]")
q""" "=> " + _root_.pprint.TPrint.implicitly[$t].render($cfgSym) """
case _ =>
c.typecheck(q"null.asInstanceOf[$tpe]")
q""" _root_.pprint.TPrint.implicitly[$tpe].render($cfgSym) """
}

}catch{case e: TypecheckException =>
rec0(tpe)
}
if (args == Nil) fansi.Str("") else fansi.Str("[") ++ added ++ "]"
}

def printBounds(lo: Type, hi: Type) = {
val loTree = if (lo =:= typeOf[Nothing]) q"$s" else q""" " >: " + ${implicitRec(lo)} """
val hiTree = if (hi =:= typeOf[Any]) q"$s" else q""" " <: " + ${implicitRec(hi)} """
q"$loTree + $hiTree"
val loTree = if (lo =:= typeOf[Nothing]) fansi.Str("") else fansi.Str(" >: ") ++ typePrintImplRec(c)(lo, true)
val hiTree = if (hi =:= typeOf[Any]) fansi.Str("") else fansi.Str(" <: ") ++ typePrintImplRec(c)(hi, true)
loTree ++ hiTree
}

def showRefinement(quantified: List[Symbol]) = {
Expand All @@ -122,10 +119,10 @@ object TPrintLowPri{
case PolyType(typeParams, resultType) =>
val paramTree = printArgSyms(t.asInstanceOf[TypeSymbol].typeParams)
val resultBounds =
if (resultType =:= typeOf[Any]) q"$s"
else q""" " <: " + ${implicitRec(resultType)} """
if (resultType =:= typeOf[Any]) fansi.Str("")
else fansi.Str(" <: ") ++ typePrintImplRec(c)(resultType, true)

Some(q""" $paramTree + $resultBounds""")
Some(paramTree ++ resultBounds)
case TypeBounds(lo, hi)
if t.toString.contains("$") && lo =:= typeOf[Nothing] && hi =:= typeOf[Any] =>
None
Expand All @@ -141,67 +138,91 @@ object TPrintLowPri{
defs
)

q"""
"val " + ${literalColor(q"${t.name.toString.stripSuffix(".type")}")} +
": " + ${implicitRec(filtered)}
"""
fansi.Str("val ") ++ literalColor(t.name.toString.stripSuffix(".type")) ++
": " ++ typePrintImplRec(c)(filtered, true)
}else {
q""" "type " + ${printSym(t)} + $suffix """
fansi.Str("type ") ++ printSym(t) ++ suffix
}
}
if (stmts.length == 0) None
else Some(stmts.reduceLeft((l, r) => q""" $l + "; " + $r """))
else Some(stmts.reduceLeft((l, r) => l + "; " + r))
}
/**
* Decide how to pretty-print, based on the type.
*
* This is recursive, but we only rarely use direct recursion: more
* often, we'll use `implicitRec`, which goes through the normal
* implicit search channel and can thus
*/
def rec0(tpe: Type, end: Boolean = false): Tree = tpe match {

tpe match {
case TypeBounds(lo, hi) =>
val res = printBounds(lo, hi)
q""" "_" + $res """
(fansi.Str("_") ++ res, WrapType.NoWrap)
case ThisType(sym) =>
q"${printSymFull(sym)} + ${if(sym.isPackage || sym.isModuleClass) "" else ".this.type"}"
(printSymFull(sym) + (if(sym.isPackage || sym.isModuleClass) "" else ".this.type"), WrapType.NoWrap)

case SingleType(NoPrefix, sym) => q"${printSym(sym)} + ${if (end) ".type" else ""}"
case SingleType(pre, sym) => q"${prefixFor(pre, sym)} + ${if (end) ".type" else ""}"
case SingleType(NoPrefix, sym) => (printSym(sym) ++ (if (rightMost) ".type" else ""), WrapType.NoWrap)
case SingleType(pre, sym) => (prefixFor(pre, sym) ++ (if (rightMost) ".type" else ""), WrapType.NoWrap)
// Special-case operator two-parameter types as infix
case TypeRef(pre, sym, List(left, right))
if lookup(sym) && sym.name.encodedName.toString != sym.name.decodedName.toString =>

q"""${implicitRec(left)} + " " + ${printSym(sym)} + " " +${implicitRec(right)}"""
(
typePrintImplRec(c)(left, true) ++ " " ++ printSym(sym) ++ " " ++ typePrintImplRec(c)(right, true),
WrapType.Infix
)

case TypeRef(NoPrefix, sym, args) => q"${printSym(sym)} + ${printArgs(args)}"
case TypeRef(pre, sym, args) => q"${prefixFor(pre, sym)} + ${printArgs(args)}"
case et @ ExistentialType(quantified, underlying) =>
showRefinement(quantified) match{
case None => implicitRec(underlying)
case Some(block) => q"""${implicitRec(underlying)} + " forSome { " + $block + " }" """
case TypeRef(pre, sym, args) if functionTypes.contains(sym.fullName) =>
args match{
case Seq(r) => (fansi.Str("() => ") ++ typePrintImplRec(c)(r, true), WrapType.Infix)

case many =>
val (left, leftWrap) = typePrintImplRec0(c)(many.head, true)

if (many.size == 2 && leftWrap == WrapType.NoWrap){
(left ++ " => " ++ typePrintImplRec(c)(many(1), true), WrapType.Infix)
}else (
fansi.Str("(") ++
fansi.Str.join(
(left +: many.init.tail.map(typePrintImplRec(c)(_, true)))
.flatMap(Seq(_, fansi.Str(", "))).init:_*
) ++
") => " ++ typePrintImplRec(c)(many.last, true),
WrapType.Infix
)
}
case TypeRef(pre, sym, args) if tupleTypes.contains(sym.fullName) =>
(
fansi.Str("(") ++
fansi.Str.join(args.map(typePrintImplRec(c)(_, true)).flatMap(Seq(_, fansi.Str(", "))).init:_*) ++
")",
WrapType.Tuple
)

case TypeRef(NoPrefix, sym, args) => (printSym(sym) ++ printArgs(args), WrapType.NoWrap)
case TypeRef(pre, sym, args) =>
if (sym.fullName == "scala.<byname>") (fansi.Str("=> ") ++ typePrintImplRec(c)(args(0), true), WrapType.Infix)
else (prefixFor(pre, sym) ++ printArgs(args), WrapType.NoWrap)
case et @ ExistentialType(quantified, underlying) =>
(
showRefinement(quantified) match{
case None => typePrintImplRec(c)(underlying, true)
case Some(block) => typePrintImplRec(c)(underlying, true) ++ " forSome { " ++ block ++ " }"
},
WrapType.NoWrap
)
case AnnotatedType(annots, tp) =>
val mapped = annots.map(x => q""" " @" + ${implicitRec(x.tpe)}""")
.reduceLeft((x, y) => q"$x + $y")
q"${implicitRec(tp)} + $mapped"
val mapped = annots.map(x => " @" + typePrintImplRec(c)(x.tpe, true))
.reduceLeft((x, y) => x + y)

(
typePrintImplRec(c)(tp, true) + mapped,
WrapType.NoWrap
)
case RefinedType(parents, defs) =>
val pre =
if (parents.forall(_ =:= typeOf[AnyRef])) q""" "" """
else parents.map(implicitRec(_)).reduceLeft[Tree]((l, r) => q"""$l + " with " + $r""")
q"$pre + ${
if (defs.isEmpty) "" else "{" + defs.mkString(";") + "}"
}"
if (parents.forall(_ =:= typeOf[AnyRef])) ""
else parents
.map(typePrintImplRec(c)(_, true))
.reduceLeft[fansi.Str]((l, r) => l ++ " with " ++ r)
(pre + (if (defs.isEmpty) "" else "{" ++ defs.mkString(";") ++ "}"), WrapType.NoWrap)
case ConstantType(value) =>
q"$value.toString"
(value.toString, WrapType.NoWrap)
}
lazy val cfgSym = c.freshName[TermName](TermName("cfg"))
val res = c.Expr[TPrint[T]](q"""_root_.pprint.TPrint.lambda{
($cfgSym: _root_.pprint.TPrintColors) =>
${rec0(tpe, end = true)}
}""")
// println("RES " + res)
res
}

}
Loading

0 comments on commit 7fead8d

Please sign in to comment.