Skip to content

Commit

Permalink
cursor
Browse files Browse the repository at this point in the history
  • Loading branch information
xuwei-k committed May 3, 2024
1 parent 8c7d96d commit 76872ea
Show file tree
Hide file tree
Showing 5 changed files with 471 additions and 36 deletions.
214 changes: 214 additions & 0 deletions core/src/main/scala-latest-js/scalameta_ast/MainCompat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,226 @@ import org.ekrich.config.ConfigFactory
import org.ekrich.config.ConfigRenderOptions
import org.scalafmt.config.ScalafmtConfig
import scala.annotation.nowarn
import scala.annotation.tailrec
import scala.meta.common.Convert
import scala.meta.parsers.Parse
import scala.meta.tokens.Token
import scala.scalajs.js
import scala.scalajs.js.JSON
import scala.scalajs.js.annotation._
import scala.util.control.NonFatal

case class WithPosResult(src: String, cursorValues: List[(String, Int)], tokenMap: List[WrappedToken])

case class Highlighted(prefix: String, current: String, suffix: String)

case class WrappedToken(token: Token, addedSpaceByScalafmt: Boolean) {
lazy val tokenSize: Int = token.end - token.start
}

trait MainCompat {

private def escape(html: String): String = {
html
.replaceAll("&", "&")
.replaceAll("<", "&lt;")
.replaceAll(">", "&gt;")
.replaceAll("\"", "&quot;")
.replaceAll("'", "&#039;")
}

@JSExport
@nowarn("msg=never used")
def rawWithPos(
src: String,
dialect: String,
scalafmtConfig: String,
line: Int,
column: Int,
): js.Object = {
try {
val output =
rawWithPos1(
src = src,
dialect = dialect,
scalafmtConfig = scalafmtConfig,
line = line,
column = column
) match {
case Right(x) =>
List(
"<span>",
escape(x.prefix),
"</span><span style='color: blue;'>",
escape(x.current),
"</span><span>",
escape(x.suffix),
"</span>"
).mkString("")
case Left(x) =>
x
}
new js.Object {
val ast: String = output
val astBuildMs: Double = 0.0 // TODO
}
} catch {
case e: Throwable =>
new js.Object {
val error = e
val errorString: String = e.toString
}
}
}

def rawWithPos1(
src: String,
dialect: String,
scalafmtConfig: String,
line: Int,
column: Int,
): Either[String, Highlighted] = {
val result = rawWithPos0(
src = src,
dialect = dialect,
scalafmtConfig = scalafmtConfig,
line = line,
column = column
)

result.cursorValues match {
case List((s, pos)) =>
@tailrec
def loop(n: Int, list: List[WrappedToken], acc: Int): Int = {
list match {
case x :: xs =>
if (n <= 0) {
acc
} else {
loop(if (x.addedSpaceByScalafmt) n else n - x.tokenSize, xs, acc + x.tokenSize)
}
case _ =>
sys.error(s"error ${n} ${acc}")
}
}

val newStartPos = loop(pos, result.tokenMap, 0)
val currentSizeWithSpace = loop(s.length, result.tokenMap.dropWhile(_.token.end < newStartPos), 0)

Right(
Highlighted(
result.src.take(newStartPos),
result.src.drop(newStartPos).take(currentSizeWithSpace),
result.src.drop(newStartPos + currentSizeWithSpace),
)
)
case values =>
if (values.isEmpty) {
println(s"not found")
} else {
println(s"multi values ${values}")
}
Left(result.src)
}
}

def rawWithPos0(
src: String,
dialect: String,
scalafmtConfig: String,
line: Int,
column: Int,
): WithPosResult = {
import scala.meta._
val convert = implicitly[Convert[String, Input]]
val main = new ScalametaAST
val dialects =
main.stringToDialects.getOrElse(
dialect, {
Console.err.println(s"invalid dialct ${dialect}")
main.dialectsDefault
}
)

val input = convert.apply(src)
val tree: Tree = main.loopParse(
input,
for {
x1 <- main.parsers
x2 <- dialects
} yield (x1, x2)
)
val res: String = runFormat(
source = tree.structure,
scalafmtConfig = hoconToMetaConfig(scalafmtConfig)
).result
val tokens =
implicitly[Parse[Term]].apply(Input.String(res), scala.meta.dialects.Scala3).get.tokens
val tokenMap: List[WrappedToken] = {
val head = WrappedToken(
token = tokens.head,
addedSpaceByScalafmt = false,
)

head +: tokens.lazyZip(tokens.drop(1)).map { (t1, t2) =>
WrappedToken(
token = t2,
addedSpaceByScalafmt = {
!t1.is[scala.meta.tokens.Token.Comma] && t2.is[scala.meta.tokens.Token.Whitespace]
}
)
}
}.toList
assert(tokenMap.size == tokens.size)
val cursorPos = {
if (src.isEmpty) {
Position.Range(input, 0, 0)
} else if (line >= src.linesIterator.size) {
Position.Range(input, src.length, src.length)
} else {
Position.Range(input, line, column, line, column)
}
}.start

val t1: List[Tree] = tree.collect {
case x if (x.pos.start <= cursorPos && cursorPos <= x.pos.end) && ((x.pos.end - x.pos.start) >= 1) =>
x
}

implicit class ListOps[A](xs: List[A]) {
def minValues[B: Ordering](f: A => B): List[A] = {
xs.groupBy(f).minBy(_._1)._2
}
}

val t2 = if (t1.size > 1) {
val ss: List[Tree] = t1.minValues(t => t.pos.end - t.pos.start)
if (ss.isEmpty) {
t1
} else {
if (ss.size > 1) {
ss.minValues(_.structure.length) match {
case Nil => ss
case aa => aa
}
} else {
ss
}
}
} else {
t1
}

val result: List[(String, Int)] = t2.flatMap { cursorTree =>
val current = cursorTree.structure
val currentSize = current.length
tree.structure.sliding(currentSize).zipWithIndex.filter(_._1 == current).map(_._2).map { pos =>
(current, pos)
}
}
WithPosResult(res, result, tokenMap)
}

def runFormat(source: String, scalafmtConfig: Conf): Output[String] = {
ScalametaAST.stopwatch {
runFormat(
Expand Down
8 changes: 4 additions & 4 deletions core/src/main/scala/scalameta_ast/ScalametaAST.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ object ScalametaAST {
}

class ScalametaAST {
private val dialectsDefault = List(dialects.Scala213Source3, dialects.Scala3)
private val stringToDialects: Map[String, List[Dialect]] = {
val dialectsDefault = List(dialects.Scala213Source3, dialects.Scala3)
val stringToDialects: Map[String, List[Dialect]] = {
import dialects._
Map(
"Auto" -> dialectsDefault,
Expand All @@ -38,7 +38,7 @@ class ScalametaAST {
)
.toMap
}
private val parsers: List[Parse[Tree]] = List(
val parsers: List[Parse[Tree]] = List(
implicitly[Parse[Stat]],
implicitly[Parse[Source]],
).map(_.asInstanceOf[Parse[Tree]])
Expand Down Expand Up @@ -75,7 +75,7 @@ class ScalametaAST {
).distinct.sortBy(_.getName)

@tailrec
private def loopParse(input: Input, xs: List[(Parse[Tree], Dialect)]): Tree = {
final def loopParse(input: Input, xs: List[(Parse[Tree], Dialect)]): Tree = {
(xs: @unchecked) match {
case (parse, dialect) :: t1 :: t2 =>
parse.apply(input, dialect) match {
Expand Down
125 changes: 125 additions & 0 deletions core/src/test/scala-latest-js/scalameta_ast/MainSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package scalameta_ast

import org.scalactic.source.Position
import org.scalatest.freespec.AnyFreeSpec

class MainSpec extends AnyFreeSpec {
"rawWithPos" - {
"empty" in {
val actual: WithPosResult = Main.rawWithPos0(
src = "",
dialect = "Scala3",
scalafmtConfig = "",
line = 0,
column = 0,
)
assert(actual.src == "Source(Nil)\n")
assert(actual.cursorValues == Nil)
}

"test 1" in {
def check(pos: Int): List[(String, Int)] = {
val lines = Seq(
"""class A {""",
""" def a(x: Y): Z =""",
""" b""",
"""}""",
)
val src = lines.mkString("\n")
assert(0 <= pos && pos < src.length)
if (pos == 0) {
Main.rawWithPos0(
src = src,
dialect = "Scala3",
scalafmtConfig = "",
line = 0,
column = 0,
)
} else {
val (lineSrc, sum, lineNumber) = lines.zipWithIndex.map { case (line, index) =>
(line, lines.take(index + 1).map(_.length).sum, index)
}.dropWhile(_._2 < pos).headOption.getOrElse((lines.last, src.length, lines.size))
val column = pos - sum + lineSrc.length
Main.rawWithPos0(
src = src,
dialect = "Scala3",
scalafmtConfig = "",
line = lineNumber,
column = column,
)
}
}.cursorValues

def checkClass(pos: Int)(implicit p: Position) = {
check(pos) match {
case List((t, _)) =>
assert(t.startsWith("Defn.Class(Nil"), pos)
case other =>
assert(false, other)
}
}

def checkTemplate(pos: Int)(implicit p: Position) = {
check(pos) match {
case List((t, _)) =>
assert(t.startsWith("Template(Nil, Nil, "), pos)
case other =>
assert(false, other)
}
}

def checkDef(pos: Int)(implicit p: Position) = {
check(pos) match {
case List((t, _)) =>
assert(t.startsWith("""Defn.Def(Nil, Term.Name("a"),"""), pos)
case other =>
assert(false, other)
}
}

(0 to 5).foreach { pos =>
checkClass(pos)
}
(6 to 7).foreach { pos =>
assert(check(pos) == List(("""Type.Name("A")""", 16)), pos)
}
(8 to 10).foreach { pos =>
checkTemplate(pos)
}
(11 to 14).foreach { pos =>
checkDef(pos)
}
(15 to 16).foreach { pos =>
assert(check(pos) == List(("""Term.Name("a")""", 165)), pos)
}
(17 to 18).foreach { pos =>
assert(check(pos) == List(("""Term.Name("x")""", 276)), pos)
}
assert(check(19) == List(("""Term.Param(Nil, Term.Name("x"), Some(Type.Name("Y")), None)""", 260)))
(20 to 21).foreach { pos =>
assert(check(pos) == List(("""Type.Name("Y")""", 297)), pos)
}
assert(
check(22) == List(
(
"""Term.ParamClause(List(Term.Param(Nil, Term.Name("x"), Some(Type.Name("Y")), None)), None)""",
238
)
)
)
checkDef(23)
(24 to 25).foreach { pos =>
assert(check(pos) == List(("""Type.Name("Z")""", 337)), pos)
}
(26 to 30).foreach { pos =>
checkDef(pos)
}
(31 to 32).foreach { pos =>
assert(check(pos) == List(("""Term.Name("b")""", 354)), pos)
}
(33 to 35).foreach { pos =>
checkTemplate(pos)
}
}
}
}
Loading

0 comments on commit 76872ea

Please sign in to comment.