From 88ae477d1bf319d06f8f68b9fa2b81350a95d63e Mon Sep 17 00:00:00 2001 From: Amir Shaikhha Date: Fri, 30 Aug 2024 04:25:53 +0100 Subject: [PATCH 1/5] README includes papers --- README.md | 99 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/README.md b/README.md index a7d07b05..197d8333 100644 --- a/README.md +++ b/README.md @@ -124,3 +124,102 @@ run interpret progs/tpch-interpreter q6.sdql ``` Or as a one-liner: `sbt "run interpret progs/tpch-interpreter q6.sdql"` + + +## Citing SDQL + +To cite SDQL, use the following BibTex: + +``` +@article{DBLP:journals/pacmpl/ShaikhhaHSO22, + author = {Amir Shaikhha and + Mathieu Huot and + Jaclyn Smith and + Dan Olteanu}, + title = {Functional collection programming with semi-ring dictionaries}, + journal = {Proc. {ACM} Program. Lang.}, + volume = {6}, + number = {{OOPSLA1}}, + pages = {1--33}, + year = {2022}, + url = {https://doi.org/10.1145/3527333}, + doi = {10.1145/3527333}, + timestamp = {Tue, 10 Jan 2023 16:19:51 +0100}, + biburl = {https://dblp.org/rec/journals/pacmpl/ShaikhhaHSO22.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +``` + +Depending on your usecase, the following papers are also relevant: + +* [SDQLpy](https://github.com/edin-dal/sdqlpy), a python embedding of SDQL for query processing + +``` +@inproceedings{DBLP:conf/cc/ShahrokhiS23, + author = {Hesam Shahrokhi and + Amir Shaikhha}, + editor = {Clark Verbrugge and + Ondrej Lhot{\'{a}}k and + Xipeng Shen}, + title = {Building a Compiled Query Engine in Python}, + booktitle = {Proceedings of the 32nd {ACM} {SIGPLAN} International Conference on + Compiler Construction, {CC} 2023, Montr{\'{e}}al, QC, Canada, + February 25-26, 2023}, + pages = {180--190}, + publisher = {{ACM}}, + year = {2023}, + url = {https://doi.org/10.1145/3578360.3580264}, + doi = {10.1145/3578360.3580264}, + timestamp = {Mon, 20 Feb 2023 14:39:08 +0100}, + biburl = {https://dblp.org/rec/conf/cc/ShahrokhiS23.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +``` + +* SDQLite, a subset of SDQL for (sparse) tensor algebra + +``` +@article{DBLP:journals/pacmmod/SchleichSS23, + author = {Maximilian Schleich and + Amir Shaikhha and + Dan Suciu}, + title = {Optimizing Tensor Programs on Flexible Storage}, + journal = {Proc. {ACM} Manag. Data}, + volume = {1}, + number = {1}, + pages = {37:1--37:27}, + year = {2023}, + url = {https://doi.org/10.1145/3588717}, + doi = {10.1145/3588717}, + timestamp = {Thu, 15 Jun 2023 21:57:49 +0200}, + biburl = {https://dblp.org/rec/journals/pacmmod/SchleichSS23.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +``` + +* Forward-mode Automatic Differentiation for SDQLite + +``` +@inproceedings{DBLP:conf/cgo/ShaikhhaHH24, + author = {Amir Shaikhha and + Mathieu Huot and + Shideh Hashemian}, + editor = {Tobias Grosser and + Christophe Dubach and + Michel Steuwer and + Jingling Xue and + Guilherme Ottoni and + ernando Magno Quint{\~{a}}o Pereira}, + title = {A Tensor Algebra Compiler for Sparse Differentiation}, + booktitle = {{IEEE/ACM} International Symposium on Code Generation and Optimization, + {CGO} 2024, Edinburgh, United Kingdom, March 2-6, 2024}, + pages = {1--12}, + publisher = {{IEEE}}, + year = {2024}, + url = {https://doi.org/10.1109/CGO57630.2024.10444787}, + doi = {10.1109/CGO57630.2024.10444787}, + timestamp = {Mon, 11 Mar 2024 13:45:28 +0100}, + biburl = {https://dblp.org/rec/conf/cgo/ShaikhhaHH24.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +``` From 1e0c1c9ed7fa8d558d6daf1e791766c65517eca3 Mon Sep 17 00:00:00 2001 From: Amir Shaikhha Date: Fri, 30 Aug 2024 05:21:13 +0100 Subject: [PATCH 2/5] Scala reformat sbt plugin added --- .scalafmt.conf | 12 +- project/plugins.sbt | 1 + .../scala/sdql/analysis/TypeInference.scala | 80 +++--- src/main/scala/sdql/backend/CppCodegen.scala | 242 +++++++++--------- src/main/scala/sdql/backend/CppCompile.scala | 8 +- src/main/scala/sdql/backend/Interpreter.scala | 211 ++++++++------- src/main/scala/sdql/driver/Main.scala | 4 +- src/main/scala/sdql/frontend/Parser.scala | 179 ++++++------- src/main/scala/sdql/frontend/SourceCode.scala | 7 +- src/main/scala/sdql/frontend/package.scala | 2 +- src/main/scala/sdql/ir/Exp.scala | 115 +++++---- .../scala/sdql/ir/ExternalFunctions.scala | 6 +- src/main/scala/sdql/ir/SemiRing.scala | 26 +- src/main/scala/sdql/ir/Type.scala | 28 +- src/main/scala/sdql/ir/Value.scala | 16 +- src/main/scala/sdql/storage/FastScanner.scala | 13 +- src/main/scala/sdql/storage/Loader.scala | 18 +- .../scala/sdql/transformations/Rewriter.scala | 110 ++++---- 18 files changed, 543 insertions(+), 535 deletions(-) create mode 100644 project/plugins.sbt diff --git a/.scalafmt.conf b/.scalafmt.conf index 2e4c63e7..bb29ae25 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,12 +1,18 @@ maxColumn = 120 -align = most +preset = default +indent.defnSite = 2 +optIn.configStyleArguments = false +align.preset = most continuationIndent.defnSite = 2 assumeStandardLibraryStripMargin = true -docstrings = JavaDoc +docstrings.style = Asterisk lineEndings = preserve includeCurlyBraceInSelectChains = false -danglingParentheses = true +danglingParentheses.preset = true spaces { inImportCurlyBraces = true } optIn.annotationNewlines = true +runner.dialect = scala213source3 rewrite.rules = [SortImports, RedundantBraces] + +version=3.7.13 \ No newline at end of file diff --git a/project/plugins.sbt b/project/plugins.sbt new file mode 100644 index 00000000..6b6db485 --- /dev/null +++ b/project/plugins.sbt @@ -0,0 +1 @@ +addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2") \ No newline at end of file diff --git a/src/main/scala/sdql/analysis/TypeInference.scala b/src/main/scala/sdql/analysis/TypeInference.scala index 6e7678ad..190bd3f5 100644 --- a/src/main/scala/sdql/analysis/TypeInference.scala +++ b/src/main/scala/sdql/analysis/TypeInference.scala @@ -16,10 +16,12 @@ object TypeInference { def run(e: Exp)(implicit ctx: Ctx): Type = e match { case Sum(k, v, e1, e2) => sumInferTypeAndCtx(k, v, e1, e2)._1 - case IfThenElse(a, Const(false), Const(true)) => run(a) - case IfThenElse(_, - DictNode(Nil, _) | Update(DictNode(Nil, _), _, _), - DictNode(Nil, _) | Update(DictNode(Nil, _), _, _)) => + case IfThenElse(a, Const(false), Const(true)) => run(a) + case IfThenElse( + _, + DictNode(Nil, _) | Update(DictNode(Nil, _), _, _), + DictNode(Nil, _) | Update(DictNode(Nil, _), _, _) + ) => raise("both branches empty") case IfThenElse(_, DictNode(Nil, _) | Update(DictNode(Nil, _), _, _), e2) => run(e2) case IfThenElse(_, e1, DictNode(Nil, _) | Update(DictNode(Nil, _), _, _)) => run(e1) @@ -37,7 +39,7 @@ object TypeInference { case None => raise(s"unknown name: $name") } - case DictNode(Nil, _) => raise("Type inference needs backtracking to infer empty type { }") + case DictNode(Nil, _) => raise("Type inference needs backtracking to infer empty type { }") case DictNode(seq, hint) => DictType(seq.map(_._1).map(run).reduce(promote), seq.map(_._2).map(run).reduce(promote), hint) @@ -52,7 +54,7 @@ object TypeInference { case Some(idx) => attrs(idx).tpe case None => raise(attrs.map(_.name).mkString(s"$field not in: ", ", ", ".")) } - case tpe => raise(s"unexpected type: ${tpe.prettyPrint} in\n${e.prettyPrint}") + case tpe => raise(s"unexpected type: ${tpe.prettyPrint} in\n${e.prettyPrint}") } case Const(v) => @@ -67,60 +69,60 @@ object TypeInference { case Get(e1, e2) => run(e1) match { - case RecordType(attrs) => + case RecordType(attrs) => run(e2) match { case IntType => e2 match { case Const(v: Int) => attrs(v).tpe - case tpe => + case tpe => raise(s"expected ${Const.getClass.getSimpleName.init}, not ${tpe.simpleName}") } - case tpe => raise(s"expected ${IntType.getClass.getSimpleName.init}, not ${tpe.simpleName}") + case tpe => raise(s"expected ${IntType.getClass.getSimpleName.init}, not ${tpe.simpleName}") } case DictType(kType, vType, _) => run(e2) match { case tpe if tpe == kType => vType - case tpe => + case tpe => raise(s"can't index with ${tpe.simpleName} from ${DictType.getClass.getSimpleName.init}") } - case tpe => + case tpe => raise( s"expected ${RecordType.getClass.getSimpleName.init} or " + s"${DictType.getClass.getSimpleName.init}, not ${tpe.simpleName}" ) } - case External(ConstantString.SYMBOL, args) => + case External(ConstantString.SYMBOL, args) => val (str, maxLen) = args match { case Seq(Const(str: String), Const(maxLen: Int)) => (str, maxLen) } assert(maxLen == str.length + 1) StringType(Some(str.length)) case External(StrContains.SYMBOL | StrStartsWith.SYMBOL | StrEndsWith.SYMBOL | StrContainsN.SYMBOL, _) => BoolType - case External(SubString.SYMBOL, args) => + case External(SubString.SYMBOL, args) => val (str, start, end) = args match { case Seq(str, Const(start: Int), Const(end: Int)) => (str, start, end) } TypeInference.run(str) match { case StringType(None) => StringType(None) case StringType(Some(_)) => StringType(Some(end - start)) case t => raise(s"unexpected: ${t.prettyPrint}") } - case External(StrIndexOf.SYMBOL | FirstIndex.SYMBOL | LastIndex.SYMBOL | Year.SYMBOL, _) => IntType - case External(ParseDate.SYMBOL, _) => DateType - case External(Inv.SYMBOL, args) => + case External(StrIndexOf.SYMBOL | FirstIndex.SYMBOL | LastIndex.SYMBOL | Year.SYMBOL, _) => IntType + case External(ParseDate.SYMBOL, _) => DateType + case External(Inv.SYMBOL, args) => val arg = args match { case Seq(e) => e } run(arg) - case External(name @ Size.SYMBOL, args) => + case External(name @ Size.SYMBOL, args) => val arg = args match { case Seq(e) => e } run(arg) match { case DictType(_, vt, _) => vt - case tpe => + case tpe => raise(s"$name expect arg ${DictType.getClass.getSimpleName.init}, not ${tpe.simpleName}") } - case External(TopN.SYMBOL, _) => raise(s"unimplemented function name: ${TopN.SYMBOL}") - case External(CStore.SYMBOL, _) => raise(s"unimplemented function name: ${CStore.SYMBOL}") - case External(Log.SYMBOL, _) => raise(s"unimplemented function name: ${Log.SYMBOL}") - case External(name, _) => raise(s"unknown function name: $name") + case External(TopN.SYMBOL, _) => raise(s"unimplemented function name: ${TopN.SYMBOL}") + case External(CStore.SYMBOL, _) => raise(s"unimplemented function name: ${CStore.SYMBOL}") + case External(Log.SYMBOL, _) => raise(s"unimplemented function name: ${Log.SYMBOL}") + case External(name, _) => raise(s"unknown function name: $name") case LetBinding(Sym(name), e1, DictNode(Nil, _)) if name == resultName => run(e1) - case LetBinding(x, e1, e2) => + case LetBinding(x, e1, e2) => val t1 = TypeInference.run(e1) TypeInference.run(e2)(ctx ++ Map(x -> t1)) @@ -129,13 +131,13 @@ object TypeInference { case Load(_, rt: RecordType, skipCols) if isColumnStore(rt) && skipCols.isSetNode => val set = skipCols.toSkipColsSet RecordType(rt.attrs.filter(attr => !set.contains(attr.name))) - case Load(_, tp, _) => raise(s"unexpected: ${tp.prettyPrint}") + case Load(_, tp, _) => raise(s"unexpected: ${tp.prettyPrint}") } case Concat(e1, e2) => (run(e1), run(e2)) match { case (t1: RecordType, t2: RecordType) => t1.concat(t2) - case (v1, v2) => + case (v1, v2) => raise(s"`concat($v1,$v2)` needs records, but given `${v1.prettyPrint}`, `${v2.prettyPrint}`") } @@ -158,11 +160,11 @@ object TypeInference { // from e1 infer types of k, v val localCtx = ctx ++ (run(e1) match { case DictType(kType, vType, _) => Map(k -> kType, v -> vType) - case tpe => + case tpe => raise(s"assignment should be from ${DictType.getClass.getSimpleName.init} not ${tpe.simpleName}") }) // from types of k, v infer type of e2 - val tpe = run(e2)(localCtx) + val tpe = run(e2)(localCtx) (tpe, localCtx) } @@ -181,29 +183,29 @@ object TypeInference { val (e1, e2) = e match { case IfThenElse(a, b, Const(false)) => (a, b) // and case case IfThenElse(a, Const(true), b) => (a, b) // or case - case IfThenElse(cond, e1, e2) => + case IfThenElse(cond, e1, e2) => assert(run(cond) == BoolType) (e1, e2) - case Add(e1, e2) => (e1, e2) - case Mult(e1, e2) => (e1, e2) - case _ => raise(s"unhandled class: ${e.simpleName}") + case Add(e1, e2) => (e1, e2) + case Mult(e1, e2) => (e1, e2) + case _ => raise(s"unhandled class: ${e.simpleName}") } - val t1 = run(e1) - val t2 = run(e2) + val t1 = run(e1) + val t2 = run(e2) promote(t1, t2) } private def promote(t1: Type, t2: Type): Type = (t1, t2) match { - case (IntType, DateType) | (DateType, IntType) => IntType - case (IntType, RealType) | (RealType, IntType) => RealType + case (IntType, DateType) | (DateType, IntType) => IntType + case (IntType, RealType) | (RealType, IntType) => RealType case (DictType(kt1, vt1, hint1), DictType(kt2, vt2, hint2)) => assert(hint1 == hint2) DictType(promote(kt1, kt2), promote(vt1, vt2)) - case (DictType(kt, vt, hint), t) if t.isScalar => DictType(kt, promote(vt, t), hint) - case (t, DictType(kt, vt, hint)) if t.isScalar => DictType(kt, promote(vt, t), hint) - case (t1, t2) if t1 == t2 => t1 - case (t1, t2) => + case (DictType(kt, vt, hint), t) if t.isScalar => DictType(kt, promote(vt, t), hint) + case (t, DictType(kt, vt, hint)) if t.isScalar => DictType(kt, promote(vt, t), hint) + case (t1, t2) if t1 == t2 => t1 + case (t1, t2) => raise(s"can't promote types: ${t1.simpleName} ≠ ${t2.simpleName}") } } diff --git a/src/main/scala/sdql/backend/CppCodegen.scala b/src/main/scala/sdql/backend/CppCodegen.scala index 129bb202..1532712f 100644 --- a/src/main/scala/sdql/backend/CppCodegen.scala +++ b/src/main/scala/sdql/backend/CppCodegen.scala @@ -12,8 +12,8 @@ object CppCodegen { /** Generates C++ from an expression transformed to LLQL */ def apply(e: Exp, benchmarkRuns: Int = 0): String = { - val csvBody = cppCsvs(e) - val queryBody = run(e)(Map(), isTernary = false) + val csvBody = cppCsvs(e) + val queryBody = run(e)(Map(), isTernary = false) val benchStart = if (benchmarkRuns == 0) "" else @@ -21,7 +21,7 @@ object CppCodegen { |for (${cppType(IntType)} iter = 1; iter <= $benchmarkRuns; iter++) { |timer.Reset(); |""".stripMargin - val benchStop = + val benchStop = if (benchmarkRuns == 0) cppPrintResult(TypeInference(e)) else s"""timer.StoreElapsedTime(0); @@ -43,20 +43,20 @@ object CppCodegen { e match { case LetBinding(x @ Sym(name), e1, e2) => val isTernary = !cond(e1) { case _: Sum | _: Initialise => true } - val e1Cpp = e1 match { + val e1Cpp = e1 match { // codegen for loads was handled in a separate tree traversal - case _: Load => "" + case _: Load => "" case e1 @ External(ConstantString.SYMBOL, _) => s"const auto $name = ${run(e1)(typesCtx, isTernary)};" - case e1: Const => s"constexpr auto $name = ${run(e1)(Map(), isTernary = false)};" - case _ => + case e1: Const => s"constexpr auto $name = ${run(e1)(Map(), isTernary = false)};" + case _ => val isRetrieval = cond(e1) { case _: FieldNode | _: Get => true } def isDict = cond(TypeInference.run(e1)) { case _: DictType => true } val cppName = if (isRetrieval && isDict) s"&$name" else name val semicolon = if (cond(e1) { case _: Initialise => true }) "" else ";" s"auto $cppName = ${run(e1)(typesCtx, isTernary)}$semicolon" } - val e2Cpp = e2 match { + val e2Cpp = e2 match { case DictNode(Nil, _) => "" case _ => run(e2)(typesCtx ++ Map(x -> TypeInference.run(e1)), isTernary = false) } @@ -65,9 +65,9 @@ object CppCodegen { case Sum(k, v, e1, e2) => val (_, typesLocal) = TypeInference.sumInferTypeAndCtx(k, v, e1, e2) val body = CppCodegen.run(e2)(typesLocal, isTernary) - val head = e1 match { + val head = e1 match { case _: RangeNode => s"${cppType(IntType)} ${k.name} = 0; ${k.name} < ${CppCodegen.run(e1)}; ${k.name}++" - case _ => + case _ => val lhs = TypeInference.run(e1)(typesLocal) match { case DictType(_, _, _: PHmap) => s"&[${k.name}, ${v.name}]" case DictType(_, _, _: SmallVecDict) => s"&${k.name}" @@ -85,10 +85,10 @@ object CppCodegen { case e: IfThenElse if isTernary => ternary(e) case e: IfThenElse => default(e) - case Cmp(e1, e2: Sym, "∈") => + case Cmp(e1, e2: Sym, "∈") => TypeInference.run(e2) match { case _: DictType => dictCmpNil(e2, e1) - case tpe => + case tpe => raise( s"expression ${e2.simpleName} should be of type " + s"${DictType.getClass.getSimpleName.init} not ${tpe.prettyPrint}" @@ -100,7 +100,7 @@ object CppCodegen { case FieldNode(e1, field) => val tpe = (TypeInference.run(e1): @unchecked) match { case rt: RecordType => rt } - val idx = (tpe.indexOf(field): @unchecked) match { case Some(idx) => idx } + val idx = (tpe.indexOf(field): @unchecked) match { case Some(idx) => idx } s" /* $field */ std::get<$idx>(${run(e1)})" case Add(Promote(tp1, e1), Promote(tp2, e2)) => @@ -117,28 +117,26 @@ object CppCodegen { case Mult(_: Promote, _) | Mult(_, _: Promote) => raise(s"promotion not supported for ${Mult.getClass.getSimpleName.init}") - case Mult(e1, External(Inv.SYMBOL, Seq(e2))) => s"(${run(e1)} / ${run(e2)})" - case Mult(e1, e2) => s"(${run(e1)} * ${run(e2)})" + case Mult(e1, External(Inv.SYMBOL, Seq(e2))) => s"(${run(e1)} / ${run(e2)})" + case Mult(e1, e2) => s"(${run(e1)} * ${run(e2)})" case e: Neg => s"-${run(e)}" case Const(DateValue(v)) => val yyyymmdd = "^(\\d{4})(\\d{2})(\\d{2})$".r.findAllIn(v.toString).matchData.next() s"${yyyymmdd.group(1)}${yyyymmdd.group(2)}${yyyymmdd.group(3)}" - case Const(v: String) => s""""$v"""" - case Const(v) => v.toString + case Const(v: String) => s""""$v"""" + case Const(v) => v.toString case Sym(name) => name case DictNode(Nil, _) => "" case DictNode(seq, _) => - seq - .map({ - case (e1, e2) => - val e1Cpp = run(e1)(typesCtx, isTernary = true) - val e2Cpp = run(e2)(typesCtx, isTernary = true) - s"{$e1Cpp, $e2Cpp}" - }) + seq.map { case (e1, e2) => + val e1Cpp = run(e1)(typesCtx, isTernary = true) + val e2Cpp = run(e2)(typesCtx, isTernary = true) + s"{$e1Cpp, $e2Cpp}" + } .mkString(s"${cppType(TypeInference.run(e))}({", ", ", "})") case RecNode(values) => @@ -154,53 +152,53 @@ object CppCodegen { case External(ConstantString.SYMBOL, Seq(Const(str: String), Const(maxLen: Int))) => assert(maxLen == str.length + 1) s"""ConstantString("$str", $maxLen)""" - case External(StrContains.SYMBOL, Seq(str, subStr)) => + case External(StrContains.SYMBOL, Seq(str, subStr)) => val func = ((TypeInference.run(str), TypeInference.run(subStr)): @unchecked) match { - case (StringType(None), StringType(None)) => "find" - case (StringType(Some(_)), StringType(Some(_))) => "contains" + case (StringType(None), StringType(None)) => "find" + case (StringType(Some(_)), StringType(Some(_))) => "contains" case (StringType(None), StringType(Some(_))) | (StringType(Some(_)), StringType(None)) => raise(s"${StrContains.SYMBOL} doesn't support fixed and variable length strings together") } s"${CppCodegen.run(str)}.$func(${CppCodegen.run(subStr)})" - case External(StrStartsWith.SYMBOL, Seq(str, prefix)) => + case External(StrStartsWith.SYMBOL, Seq(str, prefix)) => val startsWith = (TypeInference.run(str): @unchecked) match { case StringType(None) => "starts_with" case StringType(Some(_)) => "startsWith" } s"${CppCodegen.run(str)}.$startsWith(${CppCodegen.run(prefix)})" - case External(StrEndsWith.SYMBOL, Seq(str, suffix)) => + case External(StrEndsWith.SYMBOL, Seq(str, suffix)) => val endsWith = (TypeInference.run(str): @unchecked) match { case StringType(None) => "ends_with" case StringType(Some(_)) => "endsWith" } s"${CppCodegen.run(str)}.$endsWith(${CppCodegen.run(suffix)})" - case External(SubString.SYMBOL, Seq(str, Const(start: Int), Const(end: Int))) => + case External(SubString.SYMBOL, Seq(str, Const(start: Int), Const(end: Int))) => val subStr = (TypeInference.run(str): @unchecked) match { case StringType(None) => "substr" case StringType(Some(_)) => s"substr<${end - start}>" } s"${CppCodegen.run(str)}.$subStr($start, $end)" - case External(StrIndexOf.SYMBOL, Seq(field: FieldNode, elem, from)) => + case External(StrIndexOf.SYMBOL, Seq(field: FieldNode, elem, from)) => assert(cond(TypeInference.run(field)) { case StringType(None) => true }) s"${CppCodegen.run(field)}.find(${CppCodegen.run(elem)}, ${CppCodegen.run(from)})" - case External(FirstIndex.SYMBOL, Seq(on, patt)) => + case External(FirstIndex.SYMBOL, Seq(on, patt)) => s"${CppCodegen.run(on)}.firstIndex(${CppCodegen.run(patt)})" - case External(LastIndex.SYMBOL, Seq(on, patt)) => + case External(LastIndex.SYMBOL, Seq(on, patt)) => s"${CppCodegen.run(on)}.lastIndex(${CppCodegen.run(patt)})" - case External(name @ Inv.SYMBOL, _) => + case External(name @ Inv.SYMBOL, _) => raise(s"$name should have been handled by ${Mult.getClass.getSimpleName.init}") - case External(Size.SYMBOL, Seq(arg)) => + case External(Size.SYMBOL, Seq(arg)) => TypeInference.run(arg) match { case _: DictType => s"${CppCodegen.run(arg)}.size()" case t => raise(s"unexpected: ${t.prettyPrint}") } - case External(name, _) => raise(s"unhandled function name: $name") + case External(name, _) => raise(s"unhandled function name: $name") case Concat(e1: RecNode, e2: RecNode) => run(e1.concat(e2)) case Concat(e1: Sym, e2: Sym) => s"std::tuple_cat(${run(e1)}, ${run(e2)})" case Concat(e1: Sym, e2: RecNode) => s"std::tuple_cat(${run(e1)}, ${run(e2)})" case Concat(e1: RecNode, e2: Sym) => s"std::tuple_cat(${run(e1)}, ${run(e2)})" - case Concat(e1, e2) => + case Concat(e1, e2) => val _ = TypeInference.run(e) raise( s"${Concat.getClass.getSimpleName} requires arguments " + @@ -239,7 +237,7 @@ object CppCodegen { case _ => raise(f"unhandled ${e.simpleName} in\n${e.prettyPrint}") } - private def ternary(e: IfThenElse)(implicit typesCtx: TypesCtx) = e match { + private def ternary(e: IfThenElse)(implicit typesCtx: TypesCtx) = e match { case IfThenElse(cond, e1, e2) => val condBody = run(cond)(typesCtx, isTernary = true) val ifBody = run(e1)(typesCtx, isTernary = true) @@ -267,7 +265,7 @@ object CppCodegen { case _ => s"${run(e1)}.contains(${run(e2)})" } - private def cppLhsRhs(e: Exp, destination: Sym)(implicit typesCtx: TypesCtx) = { + private def cppLhsRhs(e: Exp, destination: Sym)(implicit typesCtx: TypesCtx) = { val (accessors, inner) = splitNested(e) val bracketed = cppAccessors(accessors)(typesCtx, isTernary = true) val lhs = s"${destination.name}$bracketed" @@ -275,79 +273,79 @@ object CppCodegen { (lhs, rhs) } private def cppAccessors(exps: Iterable[Exp])(implicit typesCtx: TypesCtx, isTernary: Boolean) = - exps.map(e => { s"[${CppCodegen.run(e)}]" }).mkString("") - private def splitNested(e: Exp): (Seq[Exp], Exp) = e match { + exps.map(e => s"[${CppCodegen.run(e)}]").mkString("") + private def splitNested(e: Exp): (Seq[Exp], Exp) = e match { case DictNode(Seq((k, v @ DictNode(_, _: PHmap | _: SmallVecDict | _: SmallVecDicts))), _) => val (lhs, rhs) = splitNested(v) (Seq(k) ++ lhs, rhs) - case DictNode(Seq((k, DictNode(Seq((rhs, Const(1))), _: Vec))), _) => (Seq(k), rhs) - case DictNode(Seq((k, rhs)), _) => (Seq(k), rhs) - case DictNode(map, _) if map.length != 1 => raise(s"unsupported: $e") - case _ => (Seq(), e) + case DictNode(Seq((k, DictNode(Seq((rhs, Const(1))), _: Vec))), _) => (Seq(k), rhs) + case DictNode(Seq((k, rhs)), _) => (Seq(k), rhs) + case DictNode(map, _) if map.length != 1 => raise(s"unsupported: $e") + case _ => (Seq(), e) } private def initialise(tpe: Type)(implicit agg: Aggregation, typesCtx: TypesCtx, isTernary: Boolean): String = tpe match { case DictType(_, _, PHmap(Some(e))) => CppCodegen.run(e) case DictType(_, _, PHmap(None) | _: SmallVecDict) => "{}" - case DictType(_, _, Vec(size)) => + case DictType(_, _, Vec(size)) => size match { case None => "" case Some(size) => (size + 1).toString } - case DictType(_, _, _: SmallVecDicts) => "" - case RecordType(attrs) => attrs.map(_.tpe).map(initialise).mkString(", ") - case BoolType => + case DictType(_, _, _: SmallVecDicts) => "" + case RecordType(attrs) => attrs.map(_.tpe).map(initialise).mkString(", ") + case BoolType => agg match { case SumAgg | MaxAgg => "false" case ProdAgg | MinAgg => "true" } - case RealType => + case RealType => agg match { case SumAgg | MaxAgg => "0.0" case ProdAgg => "1.0" case MinAgg => s"std::numeric_limits<${cppType(RealType)}>::max()" } - case IntType | DateType => + case IntType | DateType => agg match { case SumAgg | MaxAgg => "0" case ProdAgg => "1" case MinAgg => s"std::numeric_limits<${cppType(IntType)}>::max()" } - case StringType(None) => + case StringType(None) => agg match { case SumAgg | MaxAgg => "\"\"" case ProdAgg => raise("undefined") case MinAgg => s"MAX_STRING" } - case StringType(Some(_)) => raise("initialising VarChars shouldn't be needed") - case tpe => raise(s"unimplemented type: $tpe") + case StringType(Some(_)) => raise("initialising VarChars shouldn't be needed") + case tpe => raise(s"unimplemented type: $tpe") } private def cppType(tpe: Type, noTemplate: Boolean = false): String = tpe match { - case DictType(kt, vt, _: PHmap) => + case DictType(kt, vt, _: PHmap) => val template = if (noTemplate) "" else s"<${cppType(kt)}, ${cppType(vt)}>" s"phmap::flat_hash_map$template" - case DictType(kt, IntType, SmallVecDict(size)) => + case DictType(kt, IntType, SmallVecDict(size)) => val template = if (noTemplate) "" else s"<${cppType(kt)}, $size>" s"smallvecdict$template" case DictType(rt: RecordType, IntType, SmallVecDicts(size)) => val template = if (noTemplate) "" else s"<$size, ${recordParams(rt)}>" s"smallvecdicts$template" - case DictType(IntType, vt, _: Vec) => + case DictType(IntType, vt, _: Vec) => val template = if (noTemplate) "" else s"<${cppType(vt)}>" s"std::vector$template" - case rt: RecordType => + case rt: RecordType => val template = if (noTemplate) "" else s"<${recordParams(rt)}>" s"std::tuple$template" - case BoolType => "bool" - case RealType => "double" - case IntType | DateType => "int" - case StringType(None) => "std::string" - case StringType(Some(maxLen)) => s"VarChar<$maxLen>" - case tpe => raise(s"unimplemented type: $tpe") + case BoolType => "bool" + case RealType => "double" + case IntType | DateType => "int" + case StringType(None) => "std::string" + case StringType(Some(maxLen)) => s"VarChar<$maxLen>" + case tpe => raise(s"unimplemented type: $tpe") } - private def recordParams(rt: RecordType) = rt match { + private def recordParams(rt: RecordType) = rt match { case RecordType(attrs) => attrs.map(_.tpe).map(cppType(_)).mkString(", ") } @@ -359,86 +357,83 @@ object CppCodegen { // let same_varname = load[...]("foo.csv") // else // let same_varname = load[...]("bar.csv") - private def cppCsvs(e: Exp): String = { + private def cppCsvs(e: Exp): String = { val pathNameTypeSkip = iterExps(e).flatMap(extract).toSeq.distinct.sortBy(_._2) - val csvConsts = - pathNameTypeSkip.map({ case (path, name, _, _) => makeCsvConst(name, path) }).mkString("\n", "\n", "\n") - val tuples = pathNameTypeSkip - .map({ - case (_, name, recordType, skipCols) => - val init = makeTupleInit(name, recordType, skipCols) - s"auto ${name.toLowerCase} = ${cppType(recordType, noTemplate = true)}($init);\n" - }) + val csvConsts = + pathNameTypeSkip.map { case (path, name, _, _) => makeCsvConst(name, path) }.mkString("\n", "\n", "\n") + val tuples = pathNameTypeSkip.map { case (_, name, recordType, skipCols) => + val init = makeTupleInit(name, recordType, skipCols) + s"auto ${name.toLowerCase} = ${cppType(recordType, noTemplate = true)}($init);\n" + } .mkString("\n") Seq(csvConsts, tuples).mkString("\n") } - private def extract(e: Exp) = condOpt(e) { + private def extract(e: Exp) = condOpt(e) { case LetBinding(Sym(name), load @ Load(path, tp: RecordType, _), _) if TypeInference.isColumnStore(tp) => - val recordType = (load: @unchecked) match { case Load(_, recordType: RecordType, _) => recordType } + val recordType = (load: @unchecked) match { case Load(_, recordType: RecordType, _) => recordType } val skipCols: Set[String] = (load: @unchecked) match { case Load(_, _, skipCols) => skipCols.toSkipColsSet } (path, name, recordType, skipCols) } - private def makeCsvConst(name: String, path: String) = + private def makeCsvConst(name: String, path: String) = s"""const rapidcsv::Document ${name.toUpperCase}_CSV("../$path", NO_HEADERS, SEPARATOR);""" private def makeTupleInit(name: String, recordType: RecordType, skipCols: Set[String]) = { assert(recordType.attrs.last.name == "size") - val attrs = recordType.attrs + val attrs = recordType.attrs .dropRight(1) .map(attr => (attr.tpe: @unchecked) match { case DictType(IntType, vt, Vec(None)) => Attribute(attr.name, vt) }) - val readCols = attrs.zipWithIndex.filter { case (attr, _) => !skipCols.contains(attr.name) } - .map({ - case (Attribute(attr_name, tpe), i) => - s"/* $attr_name */" ++ (tpe match { - case DateType => - s"dates_to_numerics(" + s"${name.toUpperCase}_CSV.GetColumn<${cppType(StringType())}>($i)" + ")" - case StringType(Some(maxLen)) => - s"strings_to_varchars<$maxLen>(" + s"${name.toUpperCase}_CSV.GetColumn<${cppType(StringType())}>($i)" + ")" - case _ => - s"${name.toUpperCase}_CSV.GetColumn<${cppType(tpe)}>($i)" - }) - }) + val readCols = attrs.zipWithIndex.filter { case (attr, _) => !skipCols.contains(attr.name) }.map { + case (Attribute(attr_name, tpe), i) => + s"/* $attr_name */" ++ (tpe match { + case DateType => + s"dates_to_numerics(" + s"${name.toUpperCase}_CSV.GetColumn<${cppType(StringType())}>($i)" + ")" + case StringType(Some(maxLen)) => + s"strings_to_varchars<$maxLen>(" + s"${name.toUpperCase}_CSV.GetColumn<${cppType(StringType())}>($i)" + ")" + case _ => + s"${name.toUpperCase}_CSV.GetColumn<${cppType(tpe)}>($i)" + }) + } val readSize = if (skipCols.contains("size")) Seq() else Seq(s"/* size */static_cast<${cppType(IntType)}>(${name.toUpperCase}_CSV.GetRowCount())") (readCols ++ readSize).mkString(",\n") } - private def iterExps(e: Exp): Iterator[Exp] = + private def iterExps(e: Exp): Iterator[Exp] = Iterator(e) ++ ( e match { // 0-ary - case _: Sym | _: Const | _: Load => Iterator() + case _: Sym | _: Const | _: Load => Iterator() // 1-ary - case Neg(e) => iterExps(e) - case FieldNode(e, _) => iterExps(e) - case Promote(_, e) => iterExps(e) - case RangeNode(e) => iterExps(e) - case Unique(e) => iterExps(e) + case Neg(e) => iterExps(e) + case FieldNode(e, _) => iterExps(e) + case Promote(_, e) => iterExps(e) + case RangeNode(e) => iterExps(e) + case Unique(e) => iterExps(e) // 2-ary - case Add(e1, e2) => iterExps(e1) ++ iterExps(e2) - case Mult(e1, e2) => iterExps(e1) ++ iterExps(e2) - case Cmp(e1, e2, _) => iterExps(e1) ++ iterExps(e2) - case Sum(_, _, e1, e2) => iterExps(e1) ++ iterExps(e2) - case Get(e1, e2) => iterExps(e1) ++ iterExps(e2) - case Concat(e1, e2) => iterExps(e1) ++ iterExps(e2) - case LetBinding(_, e1, e2) => iterExps(e1) ++ iterExps(e2) + case Add(e1, e2) => iterExps(e1) ++ iterExps(e2) + case Mult(e1, e2) => iterExps(e1) ++ iterExps(e2) + case Cmp(e1, e2, _) => iterExps(e1) ++ iterExps(e2) + case Sum(_, _, e1, e2) => iterExps(e1) ++ iterExps(e2) + case Get(e1, e2) => iterExps(e1) ++ iterExps(e2) + case Concat(e1, e2) => iterExps(e1) ++ iterExps(e2) + case LetBinding(_, e1, e2) => iterExps(e1) ++ iterExps(e2) // 3-ary - case IfThenElse(e1, e2, e3) => iterExps(e1) ++ iterExps(e2) ++ iterExps(e3) + case IfThenElse(e1, e2, e3) => iterExps(e1) ++ iterExps(e2) ++ iterExps(e3) // n-ary case RecNode(values) => values.map(_._2).flatMap(iterExps) case DictNode(map, PHmap(Some(e))) => map.flatMap(x => iterExps(x._1) ++ iterExps(x._2)) ++ iterExps(e) case DictNode(map, _) => map.flatMap(x => iterExps(x._1) ++ iterExps(x._2)) case External(_, args) => args.flatMap(iterExps) // LLQL - case Initialise(_, e) => iterExps(e) - case Update(e, _, _) => iterExps(e) - case Modify(e, _) => iterExps(e) - case _ => raise(f"unhandled ${e.simpleName} in\n${e.prettyPrint}") + case Initialise(_, e) => iterExps(e) + case Update(e, _, _) => iterExps(e) + case Modify(e, _) => iterExps(e) + case _ => raise(f"unhandled ${e.simpleName} in\n${e.prettyPrint}") } ) - private def cppPrintResult(tpe: Type): String = tpe match { - case DictType(kt, vt, _: PHmap) => + private def cppPrintResult(tpe: Type): String = tpe match { + case DictType(kt, vt, _: PHmap) => s"""for (const auto &[key, val] : $resultName) { |$stdCout << ${_cppPrintResult(kt, "key")} << ":" << ${_cppPrintResult(vt, "val")} << std::endl; |}""".stripMargin @@ -447,7 +442,7 @@ object CppCodegen { val cond = vt match { case _: DictType => s"!$resultName[i].empty()" case _: RecordType => raise(s"Print not implemented for ${vt.simpleName} inside vector") - case t => + case t => assert(t.isScalar) s"$resultName[i] != 0" } @@ -456,24 +451,21 @@ object CppCodegen { |$stdCout << ${_cppPrintResult(kt, "i")} << ":" << ${_cppPrintResult(vt, s"$resultName[i]")} << std::endl; |} |}""".stripMargin - case _ => s"$stdCout << ${_cppPrintResult(tpe, resultName)} << std::endl;" + case _ => s"$stdCout << ${_cppPrintResult(tpe, resultName)} << std::endl;" } private def _cppPrintResult(tpe: Type, name: String) = tpe match { - case _: DictType => name // we currently don't pretty print inside nested dicts - case RecordType(Nil) => name + case _: DictType => name // we currently don't pretty print inside nested dicts + case RecordType(Nil) => name case RecordType(attrs) => - attrs.zipWithIndex - .map( - { - case (Attribute(_, tpe: RecordType), _) => raise(s"Nested ${tpe.simpleName} not supported") - case (Attribute(_, DateType), i) => s"print_date(std::get<$i>($name))" - case (Attribute(_, _), i) => s"std::get<$i>($name)" - } - ) + attrs.zipWithIndex.map { + case (Attribute(_, tpe: RecordType), _) => raise(s"Nested ${tpe.simpleName} not supported") + case (Attribute(_, DateType), i) => s"print_date(std::get<$i>($name))" + case (Attribute(_, _), i) => s"std::get<$i>($name)" + } .mkString(""""<" <<""", """ << "," << """, """<< ">"""") - case _ => + case _ => assert(tpe.isScalar) name } - private val stdCout = s"std::cout << std::setprecision (std::numeric_limits::digits10)" + private val stdCout = s"std::cout << std::setprecision (std::numeric_limits::digits10)" } diff --git a/src/main/scala/sdql/backend/CppCompile.scala b/src/main/scala/sdql/backend/CppCompile.scala index ae7e065a..e4f7904f 100644 --- a/src/main/scala/sdql/backend/CppCompile.scala +++ b/src/main/scala/sdql/backend/CppCompile.scala @@ -13,7 +13,7 @@ object CppCompile { def writeFormat(sdqlFilePath: String, cpp: String): Unit = { val noExtension = getNoExtension(sdqlFilePath) write(cppPath(noExtension), cpp) - val _ = inGeneratedDir(clangFormat(noExtension)).!! + val _ = inGeneratedDir(clangFormat(noExtension)).!! } private def compileRun(sdqlFilePath: String) = { @@ -42,7 +42,7 @@ object CppCompile { val path = Paths.get(generatedDir.toString, cmakeFileName) val _ = write(path, contents) } - private def cmakeContents(noExtensions: Seq[String]) = { + private def cmakeContents(noExtensions: Seq[String]) = { val init = s"""# auto-generated config - handy for Clion |cmake_minimum_required(VERSION $cmakeVersion) |project(generated) @@ -58,8 +58,8 @@ object CppCompile { |""".stripMargin noExtensions.map(noExtension => s"add_executable($noExtension.out $noExtension.cpp)").mkString(init, "\n", "") } - private val cmakeVersion = "3.28" - private val cmakeFileName = "CMakeLists.txt" + private val cmakeVersion = "3.28" + private val cmakeFileName = "CMakeLists.txt" def inGeneratedDir(seq: Seq[String]): ProcessBuilder = Process(seq, generatedDir) private val generatedDir = new java.io.File("generated") diff --git a/src/main/scala/sdql/backend/Interpreter.scala b/src/main/scala/sdql/backend/Interpreter.scala index 344198d8..7d9ed0d4 100644 --- a/src/main/scala/sdql/backend/Interpreter.scala +++ b/src/main/scala/sdql/backend/Interpreter.scala @@ -10,26 +10,26 @@ import scala.annotation.{ nowarn, tailrec } type Value = Any type Var = Sym type Ctx = Map[Var, Value] - def apply(e: Exp): Value = run(e)(Map()) - def run(e: Exp)(implicit ctx: Ctx): Value = e match { - case Const(v) => v - case RecNode(vals) => RecordValue(vals.map(x => x._1 -> run(x._2))) - case sym @ Sym(name) => + def apply(e: Exp): Value = run(e)(Map()) + def run(e: Exp)(implicit ctx: Ctx): Value = e match { + case Const(v) => v + case RecNode(vals) => RecordValue(vals.map(x => x._1 -> run(x._2))) + case sym @ Sym(name) => ctx.get(sym) match { case Some(v) => v case None => raise(s"Variable `$name` not in scope!") } - case DictNode(vals, _) => normalize(vals.map(x => run(x._1) -> run(x._2)).toMap) - case LetBinding(x, e1, e2) => + case DictNode(vals, _) => normalize(vals.map(x => run(x._1) -> run(x._2)).toMap) + case LetBinding(x, e1, e2) => val v1 = run(e1) run(e2)(ctx ++ Map(x -> v1)) - case Mult(e1, e2) => + case Mult(e1, e2) => val (v1, v2) = (run(e1), run(e2)) mult(v1, v2) - case Add(e1, e2) => + case Add(e1, e2) => val (v1, v2) = (run(e1), run(e2)) add(v1, v2) - case Neg(e1) => + case Neg(e1) => val v1 = run(e1) v1 match { case ZeroValue => ZeroValue @@ -44,8 +44,8 @@ import scala.annotation.{ nowarn, tailrec } case false | ZeroValue => run(e2) case _ => raise(s"`if($vcond)` not handled") } - case Cmp(e1, e2, op) => - val (v1, v2) = (run(e1), run(e2)) + case Cmp(e1, e2, op) => + val (v1, v2) = (run(e1), run(e2)) op match { case "==" => return equal(v1, v2) case "!=" => return !equal(v1, v2) @@ -55,10 +55,10 @@ import scala.annotation.{ nowarn, tailrec } def cmp(d1: Double, d2: Double): Int = { @tailrec def findPrec(d: Double, res: Double): Double = if (d <= 0) res else findPrec(d * 10 - (d * 10).toInt, res / 10) - val precision = findPrec(d2 - d2.toInt, 1) * 0.001 + val precision = findPrec(d2 - d2.toInt, 1) * 0.001 if (math.abs(d1 - d2) < precision) 0 else if (d1 < d2 + precision) -1 else 1 } - val res = (v1, v2) match { + val res = (v1, v2) match { case (r1: Int, r2: Int) => cmp(r1.toDouble, r2.toDouble) case (r1: Int, r2: Double) => cmp(r1.toDouble, r2) case (r1: Double, r2: Int) => cmp(r1, r2.toDouble) @@ -71,7 +71,7 @@ import scala.annotation.{ nowarn, tailrec } case "<=" => res == -1 || res == 0 case _ => raise(s"`$v1 $op $v2`'s operator not handled") } - case FieldNode(e1, f) => + case FieldNode(e1, f) => val v1 = run(e1) v1 match { case RecordValue(vals) => @@ -79,10 +79,10 @@ import scala.annotation.{ nowarn, tailrec } case Some((_, v)) => v case None => raise(s"`$v1.$f`: not field named `$f`") } - case _ => + case _ => raise(s"`$v1.$f`: `$v1` is not a record") } - case Get(e1, e2) => + case Get(e1, e2) => val (v1, v2) = (run(e1), run(e2)) v1 match { case m: Map[Value, _] => @@ -90,10 +90,10 @@ import scala.annotation.{ nowarn, tailrec } case Some(vv2) => vv2 case None => ZeroValue } - case _ => + case _ => raise(s"`$v1($v2)` needs a dictionary but given `${v1.getClass}`") } - case Concat(e1, e2) => + case Concat(e1, e2) => val (v1, v2) = (run(e1), run(e2)) (v1, v2) match { case (RecordValue(fs1), RecordValue(fs2)) => @@ -105,13 +105,13 @@ import scala.annotation.{ nowarn, tailrec } RecordValue(fs1 ++ fs2.filter(x2 => !fs1m.contains(x2._1))) else raise(s"`concat($v1, $v2)` with different values for the same field name") - case _ => + case _ => raise(s"`concat($v1,$v2)` needs records, but given `${v1.getClass}`, `${v2.getClass}`") } - case Sum(k, v, e1, e2) => + case Sum(k, v, e1, e2) => val v1 = run(e1) v1 match { - case ZeroValue => + case ZeroValue => ZeroValue case range: Map[Value, Value] => if (range.isEmpty) @@ -140,33 +140,33 @@ import scala.annotation.{ nowarn, tailrec } else res } - case _ => + case _ => raise(s"`sum(<$k,$v> <- $v1) ...` doesn't have a dictionary range: `${v1.getClass}`") } - case Load(path, tp, _) => + case Load(path, tp, _) => tp match { case DictType(RecordType(fs), IntType, _) => val arr = Loader.loadTable(Table(path, fs, path)) arr.map(x => x -> 1).toMap - case _ => + case _ => raise(s"`load[$tp]('${path}')` only supports the type `{ < ... > -> int }`") } - case Promote(tp, e1) => + case Promote(tp, e1) => val v1 = run(e1) def notSupported() = raise(s"`promote[$tp]($v1)` not supported for value of type `${v1.getClass}`") tp match { - case tsrt: TropicalSemiRingType => + case tsrt: TropicalSemiRingType => val d1 = v1 match { case d: Double => d case i: Int => i.toDouble case _ => notSupported() } TropicalSemiRing(tsrt, d1) - case etp @ EnumSemiRingType(tp1) => + case etp @ EnumSemiRingType(tp1) => EnumSemiRing(etp, SingletonEnumSemiRing(v1)) case etp @ NullableSemiRingType(tp1) => NullableSemiRing(etp, Some(v1)) - case RealType => + case RealType => v1 match { case i: Int => i.toDouble case d: Double => d @@ -175,7 +175,7 @@ import scala.annotation.{ nowarn, tailrec } case nt: NullableSemiRing[Double] if nt.value.nonEmpty => nt.value.get case _ => notSupported() } - case IntType => + case IntType => v1 match { case i: Int => i case d: Double => d.toInt @@ -184,20 +184,20 @@ import scala.annotation.{ nowarn, tailrec } case nt: NullableSemiRing[Int] if nt.value.nonEmpty => nt.value.get case _ => notSupported() } - case _ => notSupported() + case _ => notSupported() } - case External(name, args) => + case External(name, args) => val vs = args.map(x => run(x)(ctx)) external(name, vs) } - def equal(v1: Value, v2: Value): Boolean = + def equal(v1: Value, v2: Value): Boolean = (v1, v2) match { case (ZeroValue, ZeroValue) => true case (_, ZeroValue) => equal(v2, v1) case (ZeroValue, v2) if isZero(v2) => true case _ => v1 == v2 } - def isZero(v: Value): Boolean = v match { + def isZero(v: Value): Boolean = v match { case false => true case 0 => true case 0.0 => true @@ -221,178 +221,177 @@ import scala.annotation.{ nowarn, tailrec } else if (isZero(rhs) && v1.contains(x._1)) v1.remove(x._1) } - def add(v1: Value, v2: Value): Value = + def add(v1: Value, v2: Value): Value = (v1, v2) match { - case (MinSumSemiRing(Some(s1)), MinSumSemiRing(Some(s2))) => + case (MinSumSemiRing(Some(s1)), MinSumSemiRing(Some(s2))) => MinSumSemiRing(Some(math.min(s1, s2))) - case (MaxSumSemiRing(Some(s1)), MaxSumSemiRing(Some(s2))) => + case (MaxSumSemiRing(Some(s1)), MaxSumSemiRing(Some(s2))) => MaxSumSemiRing(Some(math.max(s1, s2))) - case (MinProdSemiRing(Some(s1)), MinProdSemiRing(Some(s2))) => + case (MinProdSemiRing(Some(s1)), MinProdSemiRing(Some(s2))) => MinProdSemiRing(Some(math.min(s1, s2))) - case (MaxProdSemiRing(Some(s1)), MaxProdSemiRing(Some(s2))) => + case (MaxProdSemiRing(Some(s1)), MaxProdSemiRing(Some(s2))) => MaxProdSemiRing(Some(math.max(s1, s2))) - case (EnumSemiRing(tp1, en1), EnumSemiRing(tp2, en2)) if (tp1 == tp2) => + case (EnumSemiRing(tp1, en1), EnumSemiRing(tp2, en2)) if (tp1 == tp2) => (en1, en2) match { - case (TopEnumSemiRing, _) | (_, TopEnumSemiRing) => EnumSemiRing(tp1, TopEnumSemiRing) - case (BottomEnumSemiRing, _) => v2 - case (_, BottomEnumSemiRing) => v1 + case (TopEnumSemiRing, _) | (_, TopEnumSemiRing) => EnumSemiRing(tp1, TopEnumSemiRing) + case (BottomEnumSemiRing, _) => v2 + case (_, BottomEnumSemiRing) => v1 case (SingletonEnumSemiRing(s1), SingletonEnumSemiRing(s2)) => if (equal(s1, s2)) v1 else EnumSemiRing(tp1, TopEnumSemiRing) - case _ => + case _ => raise(s"enum addition not supported: `$en1`, `$en2`") } case (NullableSemiRing(tp1, en1), NullableSemiRing(tp2, en2)) if (tp1 == tp2) => (en1, en2) match { - case (None, _) => v2 - case (_, None) => v1 + case (None, _) => v2 + case (_, None) => v1 case (Some(s1), Some(s2)) => NullableSemiRing(tp1, Some(add(s1, s2))) - case _ => + case _ => raise(s"nullable addition not supported: `$en1`, `$en2`") } - case (ZeroValue, r2) => r2 - case (r1, ZeroValue) => r1 - case (r1: Int, r2: Int) => r1 + r2 - case (r1: Int, r2: Double) => r1.toDouble + r2 - case (r1: Double, r2: Int) => r1 + r2.toDouble - case (r1: Double, r2: Double) => r1 + r2 - case (RecordValue(vs1), RecordValue(vs2)) => + case (ZeroValue, r2) => r2 + case (r1, ZeroValue) => r1 + case (r1: Int, r2: Int) => r1 + r2 + case (r1: Int, r2: Double) => r1.toDouble + r2 + case (r1: Double, r2: Int) => r1 + r2.toDouble + case (r1: Double, r2: Double) => r1 + r2 + case (RecordValue(vs1), RecordValue(vs2)) => if (vs1.map(_._1) != vs2.map(_._1)) raise(s"record addition with incompatible types") else RecordValue(vs1.zip(vs2).map(x => x._1._1 -> add(x._1._2, x._2._2))) - case (r1: Map[Value, _], r2: Map[Value, _]) => + case (r1: Map[Value, _], r2: Map[Value, _]) => val res = (r1.keys ++ r2.keys) - .map( - k => - k -> { - (r1.get(k), r2.get(k)) match { - case (Some(vv1), Some(vv2)) => - add(vv1, vv2) - case (Some(vv1), None) => - vv1 - case (None, Some(vv2)) => - vv2 - case _ => raise("not implemented") - } + .map(k => + k -> { + (r1.get(k), r2.get(k)) match { + case (Some(vv1), Some(vv2)) => + add(vv1, vv2) + case (Some(vv1), None) => + vv1 + case (None, Some(vv2)) => + vv2 + case _ => raise("not implemented") + } } ) .toMap normalize(res) - case _ => raise(s"`$v1 + $v2` not handled") + case _ => raise(s"`$v1 + $v2` not handled") } - def normalize(v: Value): Value = v match { + def normalize(v: Value): Value = v match { case m: Map[Value, Value] => m.map(kv => kv._1 -> normalize(kv._2)).filter(kv => !isZero(kv._2)) - case RecordValue(vals) => + case RecordValue(vals) => RecordValue(vals.map(v => v._1 -> normalize(v._2))) - case _ => + case _ => v } - def mult(v1: Value, v2: Value): Value = + def mult(v1: Value, v2: Value): Value = (v1, v2) match { - case (MinSumSemiRing(Some(s1)), MinSumSemiRing(Some(s2))) => + case (MinSumSemiRing(Some(s1)), MinSumSemiRing(Some(s2))) => MinSumSemiRing(Some(s1 + s2)) - case (MaxSumSemiRing(Some(s1)), MaxSumSemiRing(Some(s2))) => + case (MaxSumSemiRing(Some(s1)), MaxSumSemiRing(Some(s2))) => MaxSumSemiRing(Some(s1 + s2)) - case (MinProdSemiRing(Some(s1)), MinProdSemiRing(Some(s2))) => + case (MinProdSemiRing(Some(s1)), MinProdSemiRing(Some(s2))) => MinProdSemiRing(Some(s1 * s2)) - case (MaxProdSemiRing(Some(s1)), MaxProdSemiRing(Some(s2))) => + case (MaxProdSemiRing(Some(s1)), MaxProdSemiRing(Some(s2))) => MaxProdSemiRing(Some(s1 * s2)) - case (EnumSemiRing(tp1, en1), EnumSemiRing(tp2, en2)) if (tp1 == tp2) => + case (EnumSemiRing(tp1, en1), EnumSemiRing(tp2, en2)) if (tp1 == tp2) => (en1, en2) match { - case (BottomEnumSemiRing, _) | (_, BottomEnumSemiRing) => EnumSemiRing(tp1, BottomEnumSemiRing) - case (TopEnumSemiRing, _) => v2 - case (_, TopEnumSemiRing) => v1 + case (BottomEnumSemiRing, _) | (_, BottomEnumSemiRing) => EnumSemiRing(tp1, BottomEnumSemiRing) + case (TopEnumSemiRing, _) => v2 + case (_, TopEnumSemiRing) => v1 case (SingletonEnumSemiRing(s1), SingletonEnumSemiRing(s2)) => if (equal(s1, s2)) v1 else EnumSemiRing(tp1, BottomEnumSemiRing) - case _ => + case _ => raise(s"enum multiplication not supported: `$en1`, `$en2`") } case (NullableSemiRing(tp1, en1), NullableSemiRing(tp2, en2)) if (tp1 == tp2) => (en1, en2) match { case (None, _) | (_, None) => NullableSemiRing(tp1, None) - case (Some(s1), Some(s2)) => + case (Some(s1), Some(s2)) => NullableSemiRing(tp1, Some(mult(s1, s2))) - case _ => + case _ => raise(s"nullable multiplciation not supported: `$en1`, `$en2`") } - case (ZeroValue, r2) => ZeroValue - case (r1, ZeroValue) => ZeroValue - case (r1: Int, r2: Int) => r1 * r2 - case (r1: Int, r2: Double) => r1 * r2 - case (r1: Double, r2: Int) => r1 * r2 - case (r1: Double, r2: Double) => r1 * r2 - case (r1: Int, r2: Map[Value, _]) => + case (ZeroValue, r2) => ZeroValue + case (r1, ZeroValue) => ZeroValue + case (r1: Int, r2: Int) => r1 * r2 + case (r1: Int, r2: Double) => r1 * r2 + case (r1: Double, r2: Int) => r1 * r2 + case (r1: Double, r2: Double) => r1 * r2 + case (r1: Int, r2: Map[Value, _]) => r2.map(kv => kv._1 -> mult(r1, kv._2)) - case (r1: Double, r2: Map[Value, _]) => + case (r1: Double, r2: Map[Value, _]) => r2.map(kv => kv._1 -> mult(r1, kv._2)) - case (r1: Map[Value, _], r2) => + case (r1: Map[Value, _], r2) => r1.map(kv => kv._1 -> mult(kv._2, r2)) - case _ => raise(s"`$v1 * $v2` not handled") + case _ => raise(s"`$v1 * $v2` not handled") } - def external(name: String, args: Seq[Value]): Value = { + def external(name: String, args: Seq[Value]): Value = { import ExternalFunctions.* - def raiseTp(tp: String) = raise(s"ext(`$name`, ...) expects $tp, but given: ${args.mkString(", ")}.") + def raiseTp(tp: String) = raise(s"ext(`$name`, ...) expects $tp, but given: ${args.mkString(", ")}.") import scala.language.implicitConversions implicit def bool2double(b: Boolean): Double = if (b) 1 else 0 name match { - case ParseDate.SYMBOL => + case ParseDate.SYMBOL => args(0) match { case v: String => val arr = v.split('-') val Array(y, m, d) = arr.map(_.toInt) DateValue(y * 10000 + m * 100 + d) - case _ => raiseTp("string") + case _ => raiseTp("string") } - case Year.SYMBOL => + case Year.SYMBOL => args(0) match { case DateValue(r) => r / 10000 case _ => raiseTp("date") } - case SubString.SYMBOL => + case SubString.SYMBOL => (args(0), args(1), args(2)) match { case (str: String, s: Int, l: Int) => str.substring(s, s + l) - case _ => raiseTp("string, int, int") + case _ => raiseTp("string, int, int") } case StrStartsWith.SYMBOL => (args(0), args(1)) match { case (str1: String, str2: String) => str1.startsWith(str2) case _ => raiseTp("string, string") } - case StrEndsWith.SYMBOL => + case StrEndsWith.SYMBOL => (args(0), args(1)) match { case (str1: String, str2: String) => str1.endsWith(str2) case _ => raiseTp("string, string") } - case StrContains.SYMBOL => + case StrContains.SYMBOL => (args(0), args(1)) match { case (str1: String, str2: String) => str1.contains(str2) case _ => raiseTp("string, string") } - case StrContainsN.SYMBOL => + case StrContainsN.SYMBOL => val as = args.map(_.asInstanceOf[String]) val (obj, xs) = as.head -> as.tail xs.forall(x => obj.contains(x)) - case StrIndexOf.SYMBOL => + case StrIndexOf.SYMBOL => (args(0), args(1), args(2)) match { case (str1: String, str2: String, idx: Int) => str1.indexOf(str2, idx) case _ => raiseTp("string, string, int") } - case Inv.SYMBOL => + case Inv.SYMBOL => args(0) match { case r: Int => 1.0 / r case r: Double => 1 / r case v => raise(s"`inv($v)` not handled") } - case name => + case name => raise(s"ext(`$name`, ${args.mkString(",")}) not handled") } } diff --git a/src/main/scala/sdql/driver/Main.scala b/src/main/scala/sdql/driver/Main.scala index 23e3d7c7..8e12fdfe 100644 --- a/src/main/scala/sdql/driver/Main.scala +++ b/src/main/scala/sdql/driver/Main.scala @@ -24,7 +24,7 @@ object Main { println({ Value.toString(res) }) println() } - case "compile" => + case "compile" => if (args.length < 3) { raise("usage: `run compile *`") } val dirPath = Path.of(args(1)) val fileNames = args.drop(2) @@ -50,7 +50,7 @@ object Main { val res = CppCodegen(llql, benchmarkRuns = n) CppCompile.writeFormat(filePath.toString, res) } - case arg => raise(s"`run $arg` not supported") + case arg => raise(s"`run $arg` not supported") } } } diff --git a/src/main/scala/sdql/frontend/Parser.scala b/src/main/scala/sdql/frontend/Parser.scala index 319254e3..66dd13c6 100644 --- a/src/main/scala/sdql/frontend/Parser.scala +++ b/src/main/scala/sdql/frontend/Parser.scala @@ -7,7 +7,7 @@ import fastparse.NoWhitespace.* import sdql.ir.* object Parser { - private def keywords(implicit ctx: P[?]) = P( + private def keywords(implicit ctx: P[?]) = P( StringIn( "if", "then", @@ -40,7 +40,7 @@ object Parser { "min_sum", "max_sum", "enum", - "nullable", + "nullable" ) ~ !idRest ) @@ -59,20 +59,15 @@ object Parser { private def fractional(implicit ctx: P[?]) = P("." ~ digits) private def integral(implicit ctx: P[?]) = P("0" | CharIn("1-9") ~ digits.?) - private def int(implicit ctx: P[?]) = P(CharIn("+\\-").? ~ integral).!.map( - x => Const(x.toInt) - ) + private def int(implicit ctx: P[?]) = P(CharIn("+\\-").? ~ integral).!.map(x => Const(x.toInt)) - private def number(implicit ctx: P[?]) = P(CharIn("+\\-").? ~ integral ~ fractional ~ exponent.?).!.map( - x => Const(x.toDouble) - ) + private def number(implicit ctx: P[?]) = + P(CharIn("+\\-").? ~ integral ~ fractional ~ exponent.?).!.map(x => Const(x.toDouble)) - private def dateValue(implicit ctx: P[?]) = P("date" ~ "(" ~ (integral.!.map(_.toInt)) ~ space ~ ")").map( - x => Const(DateValue(x)) - ) - private def concat(implicit ctx: P[?]) = P("concat" ~ "(" ~ expr ~ space ~ "," ~ space ~/ expr ~ ")").map( - x => Concat(x._1, x._2) - ) + private def dateValue(implicit ctx: P[?]) = + P("date" ~ "(" ~ (integral.!.map(_.toInt)) ~ space ~ ")").map(x => Const(DateValue(x))) + private def concat(implicit ctx: P[?]) = + P("concat" ~ "(" ~ expr ~ space ~ "," ~ space ~/ expr ~ ")").map(x => Concat(x._1, x._2)) private def stringChars(c: Char) = c != '\"' && c != '\\' private def hexDigit(implicit ctx: P[?]) = P(CharIn("0-9a-fA-F")) @@ -80,82 +75,81 @@ object Parser { private def escape(implicit ctx: P[?]) = P("\\" ~ (CharIn("\"/\\\\bfnrt") | unicodeEscape)) private def alpha(implicit ctx: P[?]) = P(CharPred(isLetter)) - private def tpeTropSR(implicit ctx: P[?]) = + private def tpeTropSR(implicit ctx: P[?]) = P(("mnpr" | "mxpr" | "mnsm" | "mxsm" | "min_prod" | "max_prod" | "min_sum" | "max_sum").!) .map(TropicalSemiRingType.apply) - private def tpeEnum(implicit ctx: P[?]) = + private def tpeEnum(implicit ctx: P[?]) = P("enum" ~ ("[" ~ space ~/ tpe ~/ space ~/ "]")).map(EnumSemiRingType.apply) - private def tpeNullable(implicit ctx: P[?]) = + private def tpeNullable(implicit ctx: P[?]) = P("nullable" ~ ("[" ~ space ~/ tpe ~/ space ~/ "]")).map(NullableSemiRingType.apply) - private def tpeBool(implicit ctx: P[?]) = P("bool").map(_ => BoolType) - private def tpeInt(implicit ctx: P[?]) = P("int").map(_ => IntType) - private def tpeReal(implicit ctx: P[?]) = P("double" | "real").map(_ => RealType) - private def tpeString(implicit ctx: P[?]) = P("string").map(_ => StringType()) - private def tpeVarChar(implicit ctx: P[?]) = + private def tpeBool(implicit ctx: P[?]) = P("bool").map(_ => BoolType) + private def tpeInt(implicit ctx: P[?]) = P("int").map(_ => IntType) + private def tpeReal(implicit ctx: P[?]) = P("double" | "real").map(_ => RealType) + private def tpeString(implicit ctx: P[?]) = P("string").map(_ => StringType()) + private def tpeVarChar(implicit ctx: P[?]) = P("varchar" ~ "(" ~ integral.!.map(_.toInt) ~ space ~ ")").map(VarCharType.apply) - private def tpeDate(implicit ctx: P[?]) = P("date").map(_ => DateType) - private def fieldTpe(implicit ctx: P[?]) = P(variable ~/ ":" ~ space ~/ tpe).map(x => Attribute(x._1.name, x._2)) - private def tpeRec(implicit ctx: P[?]) = + private def tpeDate(implicit ctx: P[?]) = P("date").map(_ => DateType) + private def fieldTpe(implicit ctx: P[?]) = P(variable ~/ ":" ~ space ~/ tpe).map(x => Attribute(x._1.name, x._2)) + private def tpeRec(implicit ctx: P[?]) = P("<" ~/ fieldTpe.rep(sep = ","./) ~ space ~/ ">").map(l => RecordType(l)) - private def tpeDict(implicit ctx: P[?]) = P(hinted.? ~ tpeDictNoHint).map { + private def tpeDict(implicit ctx: P[?]) = P(hinted.? ~ tpeDictNoHint).map { case (Some(hint), DictType(kt, vt, _)) => DictType(kt, vt, hint) case (None, dict) => dict } private def tpeDictNoHint(implicit ctx: P[?]) = P("{" ~/ tpe ~ space ~ "->" ~ space ~/ tpe ~ "}").map(x => DictType(x._1, x._2)) - private def tpe(implicit ctx: P[?]): P[Type] = + private def tpe(implicit ctx: P[?]): P[Type] = tpeBool | tpeInt | tpeReal | tpeString | tpeVarChar | tpeDate | tpeRec | tpeDict | tpeTropSR | tpeEnum | tpeNullable private def strChars(implicit ctx: P[?]) = P(CharsWhile(stringChars)) - private def string(implicit ctx: P[?]) = + private def string(implicit ctx: P[?]) = P(space ~ "\"" ~/ (strChars | escape).rep.! ~ "\"").map(Const.apply) private def fieldChars(implicit ctx: P[?]) = P(CharsWhile(_ != '`')) private def fieldConst(implicit ctx: P[?]) = P(space ~ "`" ~/ (fieldChars | escape).rep.! ~ "`").map(x => Const(Symbol(x))) - private def const(implicit ctx: P[?]) = `true` | `false` | unit | number | int | string | dateValue - private def idRest(implicit ctx: P[?]) = P(CharPred(c => isLetter(c) | isDigit(c) | c == '_').!).map(_(0)) - private def variable(implicit ctx: P[?]) = + private def const(implicit ctx: P[?]) = `true` | `false` | unit | number | int | string | dateValue + private def idRest(implicit ctx: P[?]) = P(CharPred(c => isLetter(c) | isDigit(c) | c == '_').!).map(_(0)) + private def variable(implicit ctx: P[?]) = P(space ~ !keywords ~ ((alpha | "_" | "$") ~ idRest.rep).! ~ space).map(Sym.apply) private def ifThenElse(implicit ctx: P[?]) = P(ifThen ~/ maybeElse.?).map { case (cond: Exp, thenp: Exp, Some(elsep: Exp)) => IfThenElse(cond, thenp, elsep) case (cond: Exp, thenp: Exp, None) => IfThenElse(cond, thenp, DictNode(Nil)) } - private def ifThen(implicit ctx: P[?]) = P("if" ~/ expr ~/ "then" ~/ expr) - private def maybeElse(implicit ctx: P[?]) = P("else" ~/ expr) + private def ifThen(implicit ctx: P[?]) = P("if" ~/ expr ~/ "then" ~/ expr) + private def maybeElse(implicit ctx: P[?]) = P("else" ~/ expr) private def letBinding(implicit ctx: P[?]) = P("let" ~/ variable ~/ "=" ~/ expr ~/ "in".? ~/ expr).map(x => LetBinding(x._1, x._2, x._3)) - private def sum(implicit ctx: P[?]) = + private def sum(implicit ctx: P[?]) = P( "sum" ~ space ~/ "(" ~/ "<" ~/ variable ~/ "," ~/ variable ~/ ">" ~/ space ~/ ("<-" | "in") ~/ expr ~/ ")" ~/ expr ).map(x => Sum(x._1, x._2, x._3, x._4)) - private def range(implicit ctx: P[?]) = P(("range(" ~ expr ~ space ~ ")")).map(RangeNode.apply) - private def ext(implicit ctx: P[?]) = + private def range(implicit ctx: P[?]) = P(("range(" ~ expr ~ space ~ ")")).map(RangeNode.apply) + private def ext(implicit ctx: P[?]) = P("ext(" ~/ fieldConst ~/ "," ~/ expr.rep(1, sep = ","./) ~ space ~/ ")") .map(x => External(x._1.v.asInstanceOf[Symbol].name, x._2)) - private def promote(implicit ctx: P[?]) = + private def promote(implicit ctx: P[?]) = P("promote" ~/ "[" ~/ tpe ~ space ~/ "]" ~/ "(" ~/ expr ~/ ")").map(x => Promote(x._1, x._2)) private def unique(implicit ctx: P[?]) = P("unique" ~/ "(" ~/ expr ~/ ")").map(Unique.apply) private def fieldValue(implicit ctx: P[?]) = P(variable ~/ "=" ~/ expr).map(x => (x._1.name, x._2)) private def rec(implicit ctx: P[?]) = P("<" ~/ fieldValue.rep(sep = ","./) ~ space ~/ ">").map(RecNode.apply) - private def load(implicit ctx: P[?]) = - P("load" ~/ "[" ~/ tpe ~ space ~/ "]" ~/ "(" ~/ string ~/ skipCols.? ~ ")") - .map(x => { - val skipCols = x._3 match { - case Some(cols) => cols - case None => SetNode(Nil) - } - Load(x._2.v.asInstanceOf[String], x._1, skipCols) - }) + private def load(implicit ctx: P[?]) = + P("load" ~/ "[" ~/ tpe ~ space ~/ "]" ~/ "(" ~/ string ~/ skipCols.? ~ ")").map { x => + val skipCols = x._3 match { + case Some(cols) => cols + case None => SetNode(Nil) + } + Load(x._2.v.asInstanceOf[String], x._1, skipCols) + } private def skipCols(implicit ctx: P[?]) = P("," ~ space ~ set) - private def dictOrSet(implicit ctx: P[?]) = dict | set - private def keyNoValue(implicit ctx: P[?]) = P(expr ~/ !"->") - private def keyValue(implicit ctx: P[?]) = P(expr ~ "->" ~/ expr) - private def set(implicit ctx: P[?]) = P("{" ~ keyNoValue.rep(sep = ",") ~ space ~ "}").map(SetNode.apply) - private def dict(implicit ctx: P[?]) = P(hinted.? ~ dictNoHint).map { + private def dictOrSet(implicit ctx: P[?]) = dict | set + private def keyNoValue(implicit ctx: P[?]) = P(expr ~/ !"->") + private def keyValue(implicit ctx: P[?]) = P(expr ~ "->" ~/ expr) + private def set(implicit ctx: P[?]) = P("{" ~ keyNoValue.rep(sep = ",") ~ space ~ "}").map(SetNode.apply) + private def dict(implicit ctx: P[?]) = P(hinted.? ~ dictNoHint).map { case (Some(hint), DictNode(map, _)) => DictNode(map, hint) case (None, dict) => dict } @@ -176,61 +170,54 @@ object Parser { ext | parens) ~ space ) - private def neg(implicit ctx: P[?]): P[Neg] = P("-" ~ !(">") ~ factor).map(Neg.apply) - private def not(implicit ctx: P[?]): P[Exp] = P("!" ~ factor).map(Not.apply) - private def factorMult(implicit ctx: P[?]) = + private def neg(implicit ctx: P[?]): P[Neg] = P("-" ~ !(">") ~ factor).map(Neg.apply) + private def not(implicit ctx: P[?]): P[Exp] = P("!" ~ factor).map(Not.apply) + private def factorMult(implicit ctx: P[?]) = P( factor ~ ((".".! ~/ variable) | ("^".! ~/ factor) | ("(".! ~/ expr ~/ ")" ~ space)).rep - ).map( - x => - x._2.foldLeft(x._1)( - (acc, cur) => - cur match { - case (".", Sym(name)) => FieldNode(acc, name) - case ("(", e) => Get(acc, e) - case ("^", _) => - cur._2 match { - case Const(0) => Const(1) - case Const(1) => acc - case Const(n: Int) => (1 to n).map(_ => acc).reduceLeft(Mult.apply) - case _ => raise("Parsing for power failed") - } - case _ => raise("Parsing for factorMult failed") - } + ).map(x => + x._2.foldLeft(x._1)((acc, cur) => + cur match { + case (".", Sym(name)) => FieldNode(acc, name) + case ("(", e) => Get(acc, e) + case ("^", _) => + cur._2 match { + case Const(0) => Const(1) + case Const(1) => acc + case Const(n: Int) => (1 to n).map(_ => acc).reduceLeft(Mult.apply) + case _ => raise("Parsing for power failed") + } + case _ => raise("Parsing for factorMult failed") + } ) ) - private def divMul(implicit ctx: P[?]) = - P(factorMult ~ (StringIn("*", "/", "|", "&&", "||").! ~/ factorMult).rep).map( - x => - x._2.foldLeft(x._1)( - (acc, cur) => - cur._1 match { - case "*" => Mult(acc, cur._2) - case "/" => Mult(acc, ExternalFunctions.Inv(cur._2)) - case "&&" => And(acc, cur._2) - case "||" => Or(acc, cur._2) - } + private def divMul(implicit ctx: P[?]) = + P(factorMult ~ (StringIn("*", "/", "|", "&&", "||").! ~/ factorMult).rep).map(x => + x._2.foldLeft(x._1)((acc, cur) => + cur._1 match { + case "*" => Mult(acc, cur._2) + case "/" => Mult(acc, ExternalFunctions.Inv(cur._2)) + case "&&" => And(acc, cur._2) + case "||" => Or(acc, cur._2) + } ) ) - private def addSub(implicit ctx: P[?]) = - P(divMul ~ (StringIn("+", "-").! ~ !(">") ~/ divMul).rep).map( - x => - x._2.foldLeft(x._1)( - (acc, cur) => - cur._1 match { - case "+" => Add(acc, cur._2) - case "-" => Add(acc, Neg(cur._2)) - } + private def addSub(implicit ctx: P[?]) = + P(divMul ~ (StringIn("+", "-").! ~ !(">") ~/ divMul).rep).map(x => + x._2.foldLeft(x._1)((acc, cur) => + cur._1 match { + case "+" => Add(acc, cur._2) + case "-" => Add(acc, Neg(cur._2)) + } ) ) - private def addSubCmp(implicit ctx: P[?]) = - P(addSub ~ (StringIn("<", "==", "<=", "!=", "∈").! ~/ addSub).?).map( - x => - x._2 match { - case Some((op, y)) => Cmp(x._1, y, op) - case None => x._1 + private def addSubCmp(implicit ctx: P[?]) = + P(addSub ~ (StringIn("<", "==", "<=", "!=", "∈").! ~/ addSub).?).map(x => + x._2 match { + case Some((op, y)) => Cmp(x._1, y, op) + case None => x._1 } ) private def parens(implicit ctx: P[?]) = P("(" ~/ expr ~/ ")") diff --git a/src/main/scala/sdql/frontend/SourceCode.scala b/src/main/scala/sdql/frontend/SourceCode.scala index 0d82803d..9aff4b08 100644 --- a/src/main/scala/sdql/frontend/SourceCode.scala +++ b/src/main/scala/sdql/frontend/SourceCode.scala @@ -7,9 +7,10 @@ class SourceCode(val fileName: String, val exp: Exp) object SourceCode { def fromFile(fileName: String): SourceCode = { - val source = scala.io.Source.fromFile(fileName) - val content = try source.mkString - finally source.close() + val source = scala.io.Source.fromFile(fileName) + val content = + try source.mkString + finally source.close() new SourceCode(fileName, Parser(content)) } } diff --git a/src/main/scala/sdql/frontend/package.scala b/src/main/scala/sdql/frontend/package.scala index a82db712..cea1f54f 100644 --- a/src/main/scala/sdql/frontend/package.scala +++ b/src/main/scala/sdql/frontend/package.scala @@ -4,7 +4,7 @@ import sdql.ir.* package object frontend { implicit class Interpolator(val sc: StringContext) { def valueToString(v: Any): String = Value.toString(v) - def sdql(args: Any*): Exp = { + def sdql(args: Any*): Exp = { val strings = sc.parts.iterator val expressions = args.iterator val buf = new StringBuffer(strings.next()) diff --git a/src/main/scala/sdql/ir/Exp.scala b/src/main/scala/sdql/ir/Exp.scala index 3918f988..91839178 100644 --- a/src/main/scala/sdql/ir/Exp.scala +++ b/src/main/scala/sdql/ir/Exp.scala @@ -6,9 +6,8 @@ import munit.Assertions.munitPrint import scala.annotation.tailrec /** - * This trait models expressions used in the SDQL language, without also - * computing them. It can be used to generate an abstract syntax tree of a - * given program. + * This trait models expressions used in the SDQL language, without also computing them. It can be used to generate an + * abstract syntax tree of a given program. */ sealed trait Exp { def prettyPrint: String = munitPrint(this) @@ -20,23 +19,22 @@ sealed trait Exp { } /** - * This class models a symbol, e.g. "x1" or "x2", and it also includes options - * to get a fresh symbol, or to reset the used symbol counter (start over from - * x1). + * This class models a symbol, e.g. "x1" or "x2", and it also includes options to get a fresh symbol, or to reset the + * used symbol counter (start over from x1). */ case class Sym(name: String) extends Exp -object Sym { +object Sym { private val DEFAULT_NAME = "x" private val START_ID = 1 private var lastId = START_ID.toLong /** Get a fresh symbol, i.e. xi, where i is the smallest number not used. */ - def fresh: Sym = fresh(DEFAULT_NAME) + def fresh: Sym = fresh(DEFAULT_NAME) private def fresh(name: String): Sym = { val cur = freshId Sym(s"$name$cur") } - private def freshId: Long = { + private def freshId: Long = { val cur = lastId lastId += 1 cur @@ -49,15 +47,17 @@ object Sym { /** * A constant - * @param v integer, double, string, boolean + * @param v + * integer, double, string, boolean */ case class Const(v: Any) extends Exp /** * A record (tuple), with field labels - * @param values a sequence of expression values with a field label + * @param values + * a sequence of expression values with a field label */ -case class RecNode(values: Seq[(Field, Exp)]) extends Exp { +case class RecNode(values: Seq[(Field, Exp)]) extends Exp { def apply(name: Field): Option[Exp] = values.find(_._1 == name).map(_._2) def concat(other: RecNode): RecNode = other match { @@ -75,7 +75,8 @@ case class RecNode(values: Seq[(Field, Exp)]) extends Exp { /** * A dictionary that maps expressions to other expressions - * @param map a dictionary from expression to other expressions + * @param map + * a dictionary from expression to other expressions */ case class DictNode(map: Seq[(Exp, Exp)], hint: DictHint = PHmap()) extends Exp { @tailrec @@ -84,66 +85,78 @@ case class DictNode(map: Seq[(Exp, Exp)], hint: DictHint = PHmap()) extends Exp case DictNode(map, _) if map.length != 1 => raise(s"unsupported: $this") case _ => this } - def isSetNode: Boolean = this == SetNode(this.map.map(_._1)) - def toSkipColsSet: Set[String] = + def isSetNode: Boolean = this == SetNode(this.map.map(_._1)) + def toSkipColsSet: Set[String] = this.map.map(_._1).map(k => (k: @unchecked) match { case Const(s: String) => s }).toSet } sealed trait DictHint -case class PHmap(e: Option[Exp] = None) extends DictHint -case class SmallVecDict(size: Int) extends DictHint -case class SmallVecDicts(size: Int) extends DictHint -case class Vec(size: Option[Int] = None) extends DictHint +case class PHmap(e: Option[Exp] = None) extends DictHint +case class SmallVecDict(size: Int) extends DictHint +case class SmallVecDicts(size: Int) extends DictHint +case class Vec(size: Option[Int] = None) extends DictHint /** * Integer numbers between 0 and n - * @param e an expression evaluating to n + * @param e + * an expression evaluating to n */ case class RangeNode(e: Exp) extends Exp /** * Addition of two expressions - * @param e1 exp1 - * @param e2 exp2 + * @param e1 + * exp1 + * @param e2 + * exp2 */ case class Add(e1: Exp, e2: Exp) extends Exp /** * Multiplication of two expressions - * @param e1 exp1 - * @param e2 exp2 + * @param e1 + * exp1 + * @param e2 + * exp2 */ case class Mult(e1: Exp, e2: Exp) extends Exp /** * Negative of an expression - * @param e exp1 + * @param e + * exp1 */ case class Neg(e: Exp) extends Exp /** * Comparison of two expressions - * @param e1 exp1 - * @param e2 exp2 - * @param cmp Comparison operator + * @param e1 + * exp1 + * @param e2 + * exp2 + * @param cmp + * Comparison operator */ case class Cmp(e1: Exp, e2: Exp, cmp: String) extends Exp /** * Conditional statement - * @param cond condition - * @param thenp expression if true - * @param elsep expression if false + * @param cond + * condition + * @param thenp + * expression if true + * @param elsep + * expression if false */ -case class IfThenElse(cond: Exp, thenp: Exp, elsep: Exp) extends Exp -case class FieldNode(e: Exp, f: Field) extends Exp -case class Sum(key: Sym, value: Sym, e1: Exp, body: Exp) extends Exp -case class Get(e1: Exp, e2: Exp) extends Exp -case class Concat(e1: Exp, e2: Exp) extends Exp -case class LetBinding(x: Sym, e1: Exp, e2: Exp) extends Exp { +case class IfThenElse(cond: Exp, thenp: Exp, elsep: Exp) extends Exp +case class FieldNode(e: Exp, f: Field) extends Exp +case class Sum(key: Sym, value: Sym, e1: Exp, body: Exp) extends Exp +case class Get(e1: Exp, e2: Exp) extends Exp +case class Concat(e1: Exp, e2: Exp) extends Exp +case class LetBinding(x: Sym, e1: Exp, e2: Exp) extends Exp { override def hashCode(): Int = this match { case LetBindingN(xs, res) => xs.map(xe => xe._1.hashCode() + xe._2.hashCode()).sum + res.hashCode() - case _ => super.hashCode() + case _ => super.hashCode() } } case class Load(path: String, tp: Type, skipCols: DictNode = SetNode(Nil)) extends Exp @@ -152,15 +165,14 @@ case class Promote(tp: Type, e: Exp) extends Exp case class External(name: String, args: Seq[Exp]) extends Exp case class Unique(e: Exp) extends Exp -object SetNode { +object SetNode { def apply(es: Seq[Exp]): DictNode = DictNode(es.map(x => x -> Const(1))) def fromSkipColsSet(set: Set[String]): DictNode = SetNode(set.map(Const.apply).toSeq) } /** - * This object models a sequence of let bindings, where the bindings can either - * be applied in an expression body, or unapplied and returned from a given - * expression. + * This object models a sequence of let bindings, where the bindings can either be applied in an expression body, or + * unapplied and returned from a given expression. */ object LetBindingN { def apply(bindings: Seq[(Sym, Exp)], body: Exp): Exp = @@ -170,11 +182,14 @@ object LetBindingN { @tailrec def rec(exp: Exp, res: Res): Res = exp match { case LetBinding(x, e1, e2) => - rec(e2, res match { - case Some((seq, _)) => Some((seq :+ (x -> e1)) -> e2) - case None => Some(Seq(x -> e1) -> e2) - }) - case _ => res + rec( + e2, + res match { + case Some((seq, _)) => Some((seq :+ (x -> e1)) -> e2) + case None => Some(Seq(x -> e1) -> e2) + } + ) + case _ => res } } @@ -183,7 +198,7 @@ object And { def apply(a: Exp, b: Exp): Exp = IfThenElse(a, b, Const(false)) } object Or { def apply(a: Exp, b: Exp): Exp = IfThenElse(a, Const(true), b) } object Not { - def apply(a: Exp): Exp = IfThenElse(a, Const(false), Const(true)) + def apply(a: Exp): Exp = IfThenElse(a, Const(false), Const(true)) def unapply(e: Exp): Option[Exp] = e match { case IfThenElse(a, Const(false), Const(true)) => Some(a) case _ => None @@ -191,7 +206,7 @@ object Not { } object SingleDict { - def apply(k: Exp, v: Exp): Exp = DictNode(Seq((k, v))) + def apply(k: Exp, v: Exp): Exp = DictNode(Seq((k, v))) def unapply(exp: Exp): Option[(Exp, Exp)] = exp match { case DictNode(Seq((k, v)), _) => Some((k, v)) case _ => None @@ -204,6 +219,6 @@ object SingleDict { * The expressions below bridge the gap between functional and imperative styles, simplifying code generation in C++. */ sealed trait LLQL -case class Initialise(tpe: Type, e: Sum) extends Exp with LLQL +case class Initialise(tpe: Type, e: Sum) extends Exp with LLQL case class Update(e: Exp, agg: Aggregation, dest: Sym) extends Exp with LLQL case class Modify(e: Exp, dest: Sym) extends Exp with LLQL diff --git a/src/main/scala/sdql/ir/ExternalFunctions.scala b/src/main/scala/sdql/ir/ExternalFunctions.scala index 31e556c9..765e5a65 100644 --- a/src/main/scala/sdql/ir/ExternalFunctions.scala +++ b/src/main/scala/sdql/ir/ExternalFunctions.scala @@ -2,14 +2,14 @@ package sdql package ir abstract class ExternalFactory(symbol: String) { - val SYMBOL: String = symbol - def apply(es: Exp*): Exp = External(symbol, es.toSeq) + val SYMBOL: String = symbol + def apply(es: Exp*): Exp = External(symbol, es.toSeq) def unapplySeq(e: Exp): Option[Seq[Exp]] = e match { case External(sym, seq) if sym == symbol => Some(seq) case _ => None } } -object ExternalFunctions { +object ExternalFunctions { object TopN extends ExternalFactory("TopN") object ConstantString extends ExternalFactory("ConstantString") object StrContains extends ExternalFactory("StrContains") diff --git a/src/main/scala/sdql/ir/SemiRing.scala b/src/main/scala/sdql/ir/SemiRing.scala index a1fddb9b..1f6cc161 100644 --- a/src/main/scala/sdql/ir/SemiRing.scala +++ b/src/main/scala/sdql/ir/SemiRing.scala @@ -13,14 +13,14 @@ object TropicalSemiRing { case t => raise(s"unexpected: ${t.prettyPrint}") } @nowarn - def unapply(e: Any): Option[(TropicalSemiRingType, Option[Double])] = e match { + def unapply(e: Any): Option[(TropicalSemiRingType, Option[Double])] = e match { case t: TropicalSemiRing[Double] => Some((t.kind, t.value)) case _ => None } - val MinSumSemiRingType: TropicalSemiRingType = TropicalSemiRingType(isMax = false, isProd = false) - val MaxSumSemiRingType: TropicalSemiRingType = TropicalSemiRingType(isMax = true, isProd = false) - val MinProdSemiRingType: TropicalSemiRingType = TropicalSemiRingType(isMax = false, isProd = true) - val MaxProdSemiRingType: TropicalSemiRingType = TropicalSemiRingType(isMax = true, isProd = true) + val MinSumSemiRingType: TropicalSemiRingType = TropicalSemiRingType(isMax = false, isProd = false) + val MaxSumSemiRingType: TropicalSemiRingType = TropicalSemiRingType(isMax = true, isProd = false) + val MinProdSemiRingType: TropicalSemiRingType = TropicalSemiRingType(isMax = false, isProd = true) + val MaxProdSemiRingType: TropicalSemiRingType = TropicalSemiRingType(isMax = true, isProd = true) } case class MinSumSemiRing(override val value: Option[Double]) @@ -34,22 +34,24 @@ case class MaxProdSemiRing(override val value: Option[Double]) // type isn't known at time of parsing, it can be populated later by type inference case class TropicalSemiRingType(isMax: Boolean, isProd: Boolean, tp: Option[Type] = None) - extends CustomSemiRingType(s"${if (isMax) "max" else "min"}_${if (isProd) "prod" else "sum"}", - Seq(isMax, isProd, tp)) + extends CustomSemiRingType( + s"${if (isMax) "max" else "min"}_${if (isProd) "prod" else "sum"}", + Seq(isMax, isProd, tp) + ) object TropicalSemiRingType { - def apply(name: String): TropicalSemiRingType = name match { + def apply(name: String): TropicalSemiRingType = name match { case "min_sum" | "mnsm" => TropicalSemiRingType(isMax = false, isProd = false) case "max_sum" | "mxsm" => TropicalSemiRingType(isMax = true, isProd = false) case "min_prod" | "mnpr" => TropicalSemiRingType(isMax = false, isProd = true) case "max_prod" | "mxpr" => TropicalSemiRingType(isMax = true, isProd = true) } - def apply(name: String, tp: Type): TropicalSemiRingType = this.pack(this.apply(name), tp) + def apply(name: String, tp: Type): TropicalSemiRingType = this.pack(this.apply(name), tp) def pack(tsrt: TropicalSemiRingType, tp: Type): TropicalSemiRingType = tsrt match { case TropicalSemiRingType(isMax, isProd, None) => TropicalSemiRingType(isMax, isProd, Some(tp)) case TropicalSemiRingType(_, _, Some(tpe)) if tpe != tp => raise(s"$tpe ≠ $tp") case _ => tsrt } - def unpack(tp: Type): Type = tp match { + def unpack(tp: Type): Type = tp match { case TropicalSemiRingType(_, _, Some(tp)) => tp case tsrt @ TropicalSemiRingType(_, _, None) => raise(s"${tsrt.simpleName} is missing type information") case _ => tp @@ -62,12 +64,12 @@ sealed trait EnumSemiRingValue[+T] { case SingletonEnumSemiRing(_) => true case _ => false } - def get: T = this match { + def get: T = this match { case SingletonEnumSemiRing(v) => v case _ => raise("not implemented") } } -case object TopEnumSemiRing extends EnumSemiRingValue[Nothing] +case object TopEnumSemiRing extends EnumSemiRingValue[Nothing] case object BottomEnumSemiRing extends EnumSemiRingValue[Nothing] case class SingletonEnumSemiRing[+T](value: T) extends EnumSemiRingValue[T] diff --git a/src/main/scala/sdql/ir/Type.scala b/src/main/scala/sdql/ir/Type.scala index d4754c0b..58e3b454 100644 --- a/src/main/scala/sdql/ir/Type.scala +++ b/src/main/scala/sdql/ir/Type.scala @@ -4,20 +4,20 @@ package ir sealed trait Type { def =~=(o: Type): Boolean = equals(o) def isScalar: Boolean = ScalarType.isScalar(this) - def prettyPrint: String = this match { + def prettyPrint: String = this match { case DictType(kt, vt, _) => s"{${kt.prettyPrint} -> ${vt.prettyPrint}}" - case RecordType(attrs) => + case RecordType(attrs) => attrs.map(_.tpe.prettyPrint).mkString("<", ", ", ">") - case _ => + case _ => this.simpleName } - def simpleName: String = { + def simpleName: String = { val name = this.getClass.getSimpleName if (name.endsWith("$")) name.dropRight(1) else name } } -case class StringType(maxLen: Option[Int] = None) extends Type +case class StringType(maxLen: Option[Int] = None) extends Type case object RealType extends Type case object BoolType extends Type case object IntType extends Type @@ -25,28 +25,28 @@ case object DateType extends Ty case class DictType(key: Type, value: Type, hint: DictHint = PHmap()) extends Type object SetType { def apply(key: Type): DictType = DictType(key, IntType) } case class RecordType(attrs: Seq[Attribute]) extends Type { - override def equals(o: Any): Boolean = o match { + override def equals(o: Any): Boolean = o match { case RecordType(attrs2) if attrs.size == attrs2.size => attrs.zip(attrs2).forall(x => x._1.name == x._2.name && x._1.tpe == x._2.tpe) - case _ => false + case _ => false } - override def =~=(o: Type): Boolean = o match { + override def =~=(o: Type): Boolean = o match { case RecordType(attrs2) if attrs.size == attrs2.size => attrs.zip(attrs2).forall(x => x._1.name == x._2.name && x._1.tpe =~= x._2.tpe) - case _ => false + case _ => false } - override def hashCode(): Int = + override def hashCode(): Int = attrs.map(_.name).hashCode() - def indexOf(name: Field): Option[Int] = { + def indexOf(name: Field): Option[Int] = { val names = attrs.map(_.name) assert(names.diff(names.distinct).isEmpty) names.zipWithIndex.find(_._1 == name).map(_._2) } - def apply(name: Field): Option[Type] = attrs.find(_.name == name).map(_.tpe) + def apply(name: Field): Option[Type] = attrs.find(_.name == name).map(_.tpe) def concat(other: RecordType): RecordType = other match { case RecordType(attrs2) => val (fs1m, fs2m) = attrs.map(x1 => (x1.name, x1.tpe)).toMap -> attrs2.map(x2 => (x2.name, x2.tpe)).toMap - val common = + val common = attrs.filter(x1 => fs2m.contains(x1.name)).map(x1 => (x1.name, x1.tpe, fs2m(x1.name))) if (common.isEmpty) RecordType(attrs ++ attrs2) @@ -61,7 +61,7 @@ object VarCharType { def apply(maxLen: Int): Type = StringType(Some(maxLen)) } object ScalarType { def unapply(tp: Type): Option[Type] = Some(tp).filter(isScalar) - def isScalar(tp: Type): Boolean = tp match { + def isScalar(tp: Type): Boolean = tp match { case RealType | IntType | _: StringType | DateType | BoolType => true case _ => false } diff --git a/src/main/scala/sdql/ir/Value.scala b/src/main/scala/sdql/ir/Value.scala index eea35a8a..d5294634 100644 --- a/src/main/scala/sdql/ir/Value.scala +++ b/src/main/scala/sdql/ir/Value.scala @@ -7,20 +7,20 @@ case object ZeroValue object Value { def toString(v: Value): String = v match { - case b: Boolean => b.toString - case DateValue(v) => + case b: Boolean => b.toString + case DateValue(v) => val s = v.toString s"${s.substring(0, 4)}-${s.substring(4, 6)}-${s.substring(6, 8)}" - case s: String => "\"" + s + "\"" - case d: Double => d.toString - case i: Int => i.toString - case m: Map[?, ?] => + case s: String => "\"" + s + "\"" + case d: Double => d.toString + case i: Int => i.toString + case m: Map[?, ?] => m.map(kv => s"${toString(kv._1)} -> ${toString(kv._2)}").mkString("{", ", ", "}") case RecordValue(vals) => vals.map(fv => s"${fv._1} = ${fv._2}").mkString("<", ", ", ">") - case ZeroValue => + case ZeroValue => ZeroValue.toString - case _ => raise(s"Doesn't know how to convert `$v` to string") + case _ => raise(s"Doesn't know how to convert `$v` to string") } def normalize(v: Value): Value = raise("normalize not supported yet") diff --git a/src/main/scala/sdql/storage/FastScanner.scala b/src/main/scala/sdql/storage/FastScanner.scala index c8d3910a..eecef58e 100644 --- a/src/main/scala/sdql/storage/FastScanner.scala +++ b/src/main/scala/sdql/storage/FastScanner.scala @@ -5,10 +5,9 @@ import java.io.{ BufferedReader, FileReader } /** * Code from: https://github.com/epfldata/dblab/blob/develop/components/ - * src/main/scala/ch/epfl/data/dblab/storagemanager/ + * src/main/scala/ch/epfl/data/dblab/storagemanager/ * * An efficient Scanner defined for reading from files. - * */ class FastScanner(filename: String) { @@ -58,7 +57,7 @@ class FastScanner(filename: String) { def next_char(): Char = { byteRead = br.read() - val del = br.read() //delimiter + val del = br.read() // delimiter if ((del != delimiter) && (del != '\n')) throw new RuntimeException("Expected delimiter after char. Not found. Sorry!") byteRead.asInstanceOf[Char] @@ -78,11 +77,11 @@ class FastScanner(filename: String) { cnt } - private val buffer = new Array[Byte](1 << 10) + private val buffer = new Array[Byte](1 << 10) def next_string: String = { java.util.Arrays.fill(buffer, 0.toByte) byteRead = br.read() - var cnt = 0 + var cnt = 0 while (br.ready() && (byteRead != delimiter) && (byteRead != '\n')) { buffer(cnt) = byteRead.asInstanceOf[Byte] byteRead = br.read() @@ -98,8 +97,8 @@ class FastScanner(filename: String) { val year = next_int() val month = next_int() delimiter = '|' - val day = next_int() - //val date_str = year + "-" + month + "-" + day + val day = next_int() + // val date_str = year + "-" + month + "-" + day year * 10000 + month * 100 + day } diff --git a/src/main/scala/sdql/storage/Loader.scala b/src/main/scala/sdql/storage/Loader.scala index 314de586..afd9c132 100644 --- a/src/main/scala/sdql/storage/Loader.scala +++ b/src/main/scala/sdql/storage/Loader.scala @@ -5,10 +5,9 @@ import sdql.ir.* /** * Code from: https://github.com/epfldata/dblab/blob/develop/components/ - * src/main/scala/ch/epfl/data/dblab/storagemanager/ + * src/main/scala/ch/epfl/data/dblab/storagemanager/ * * An efficient Scanner defined for reading from files. - * */ object Loader { @@ -25,14 +24,13 @@ object Loader { var i = 0 while (i < size && ldr.hasNext()) { - val values = table.attributes.map( - arg => - arg.tpe match { - case IntType => ldr.next_int() - case RealType => ldr.next_double() - case _: StringType => ldr.next_string - case DateType => DateValue(ldr.next_date) - case t => raise(s"Not handled type `$t` in the loader.") + val values = table.attributes.map(arg => + arg.tpe match { + case IntType => ldr.next_int() + case RealType => ldr.next_double() + case _: StringType => ldr.next_string + case DateType => DateValue(ldr.next_date) + case t => raise(s"Not handled type `$t` in the loader.") } ) arr(i) = RecordValue(table.attributes.map(_.name).zip(values).toSeq) diff --git a/src/main/scala/sdql/transformations/Rewriter.scala b/src/main/scala/sdql/transformations/Rewriter.scala index da8bd750..a7d910a3 100644 --- a/src/main/scala/sdql/transformations/Rewriter.scala +++ b/src/main/scala/sdql/transformations/Rewriter.scala @@ -16,7 +16,7 @@ class TermRewriter(transformations: Transformation*) extends Transformation { def apply(e: Exp): Exp = transformations.foldLeft(e)((acc, f) => f.apply(acc)) } -object TermRewriter { def apply(transformations: Transformation*): TermRewriter = new TermRewriter(transformations *) } +object TermRewriter { def apply(transformations: Transformation*): TermRewriter = new TermRewriter(transformations*) } /** Applies all transformations and lowers an expression to LLQL */ object Rewriter { @@ -32,60 +32,66 @@ object Rewriter { // 0-ary case _: Sym | _: Const | _: Load => e // 1-ary - case Neg(e) => Neg(f(e)) - case FieldNode(e, field) => FieldNode(f(e), field) - case Promote(tp, e) => Promote(tp, f(e)) - case RangeNode(e) => RangeNode(f(e)) - case Unique(e) => Unique(f(e)) + case Neg(e) => Neg(f(e)) + case FieldNode(e, field) => FieldNode(f(e), field) + case Promote(tp, e) => Promote(tp, f(e)) + case RangeNode(e) => RangeNode(f(e)) + case Unique(e) => Unique(f(e)) // 2-ary - case Add(e1, e2) => Add(f(e1), f(e2)) - case Mult(e1, e2) => Mult(f(e1), f(e2)) - case Cmp(e1, e2, cmp) => Cmp(f(e1), f(e2), cmp) - case Sum(key, value, e1, e2) => Sum(key, value, f(e1), f(e2)) - case Get(e1, e2) => Get(f(e1), f(e2)) - case Concat(e1, e2) => Concat(f(e1), f(e2)) - case LetBinding(x, e1, e2) => LetBinding(x, f(e1), f(e2)) + case Add(e1, e2) => Add(f(e1), f(e2)) + case Mult(e1, e2) => Mult(f(e1), f(e2)) + case Cmp(e1, e2, cmp) => Cmp(f(e1), f(e2), cmp) + case Sum(key, value, e1, e2) => Sum(key, value, f(e1), f(e2)) + case Get(e1, e2) => Get(f(e1), f(e2)) + case Concat(e1, e2) => Concat(f(e1), f(e2)) + case LetBinding(x, e1, e2) => LetBinding(x, f(e1), f(e2)) // 3-ary - case IfThenElse(e1, e2, e3) => IfThenElse(f(e1), f(e2), f(e3)) + case IfThenElse(e1, e2, e3) => IfThenElse(f(e1), f(e2), f(e3)) // n-ary - case RecNode(values) => RecNode(values.map(v => (v._1, f(v._2)))) - case DictNode(map, hint) => - DictNode(map.map(x => (f(x._1), f(x._2))), hint match { - case PHmap(Some(e)) => PHmap(Some(f(e))) - case _ => hint - }) - case External(name, args) => External(name, args.map(f)) - case _ => raise(f"unhandled ${e.simpleName} in\n${e.prettyPrint}") + case RecNode(values) => RecNode(values.map(v => (v._1, f(v._2)))) + case DictNode(map, hint) => + DictNode( + map.map(x => (f(x._1), f(x._2))), + hint match { + case PHmap(Some(e)) => PHmap(Some(f(e))) + case _ => hint + } + ) + case External(name, args) => External(name, args.map(f)) + case _ => raise(f"unhandled ${e.simpleName} in\n${e.prettyPrint}") } def mapInnerReduce[T](f: Exp => T, g: (T, T) => T, default: T)(e: Exp): T = e match { // 0-ary case _: Sym | _: Const | _: Load => default // 1-ary - case Neg(e) => f(e) - case FieldNode(e, _) => f(e) - case Promote(_, e) => f(e) - case RangeNode(e) => f(e) - case Unique(e) => f(e) + case Neg(e) => f(e) + case FieldNode(e, _) => f(e) + case Promote(_, e) => f(e) + case RangeNode(e) => f(e) + case Unique(e) => f(e) // 2-ary - case Add(e1, e2) => g(f(e1), f(e2)) - case Mult(e1, e2) => g(f(e1), f(e2)) - case Cmp(e1, e2, _) => g(f(e1), f(e2)) - case Sum(_, _, e1, e2) => g(f(e1), f(e2)) - case Get(e1, e2) => g(f(e1), f(e2)) - case Concat(e1, e2) => g(f(e1), f(e2)) - case LetBinding(_, e1, e2) => g(f(e1), f(e2)) + case Add(e1, e2) => g(f(e1), f(e2)) + case Mult(e1, e2) => g(f(e1), f(e2)) + case Cmp(e1, e2, _) => g(f(e1), f(e2)) + case Sum(_, _, e1, e2) => g(f(e1), f(e2)) + case Get(e1, e2) => g(f(e1), f(e2)) + case Concat(e1, e2) => g(f(e1), f(e2)) + case LetBinding(_, e1, e2) => g(f(e1), f(e2)) // 3-ary - case IfThenElse(e1, e2, e3) => g(g(f(e1), f(e2)), f(e3)) + case IfThenElse(e1, e2, e3) => g(g(f(e1), f(e2)), f(e3)) // n-ary - case RecNode(values) => values.map(_._2).map(f).foldLeft(default)(g) - case DictNode(map, hint) => - g(g(map.map(_._1).map(f).foldLeft(default)(g), map.map(_._2).map(f).foldLeft(default)(g)), hint match { - case PHmap(Some(e)) => f(e) - case _ => default - }) - case External(_, args) => args.map(f).foldLeft(default)(g) - case _ => raise(f"unhandled ${e.simpleName} in\n${e.prettyPrint}") + case RecNode(values) => values.map(_._2).map(f).foldLeft(default)(g) + case DictNode(map, hint) => + g( + g(map.map(_._1).map(f).foldLeft(default)(g), map.map(_._2).map(f).foldLeft(default)(g)), + hint match { + case PHmap(Some(e)) => f(e) + case _ => default + } + ) + case External(_, args) => args.map(f).foldLeft(default)(g) + case _ => raise(f"unhandled ${e.simpleName} in\n${e.prettyPrint}") } } @@ -145,7 +151,7 @@ private object SkipUnusedColumns extends Transformation { val new_ = tp.attrs.map(_.name: String).filter(!columnsCtx.getOrElse(x, Set()).contains(_)).toSet val skip = SetNode.fromSkipColsSet(old | new_) LetBinding(x, Load(path, tp, skip), run(e2)) - case _ => Rewriter.mapInner(run)(e) + case _ => Rewriter.mapInner(run)(e) } private def find(e: Exp)(implicit columnsCtx: Columns = Map()): Columns = e match { @@ -200,10 +206,10 @@ private object LowerToLLQL extends Transformation { case LetBinding(x, e1: Sum, e2) => run(LetBinding(x, sumToInitialise(e1)(ctx, Some(x)), e2)) case LetBinding(x, e1, e2) => LetBinding(x, e1, run(e2)(ctx ++ Map(x -> TypeInference.run(e1)), dest)) case IfThenElse(cond, e1, e2) => IfThenElse(cond, run(e1), run(e2)) - case Sum(key, value, e1, e2) => + case Sum(key, value, e1, e2) => val (_, ctxLocal) = TypeInference.sumInferTypeAndCtx(key, value, e1, e2) Sum(key, value, run(e1)(ctx, None), run(e2)(ctxLocal, dest)) - case _ => + case _ => dest match { case Some(dest) => sumBodyToLLQL(e, dest)(ctx) case None => Rewriter.mapInner(run)(e) @@ -215,25 +221,25 @@ private object LowerToLLQL extends Transformation { private def isUpdate(e: Exp)(implicit ctx: TypesCtx) = sumHint(e) match { case Some(_: PHmap) if cond(e) { case dict: DictNode => checkIsUnique(dict) } => false - case None | Some(_: PHmap | _: SmallVecDict | _: SmallVecDicts) => true - case Some(_: Vec) => false + case None | Some(_: PHmap | _: SmallVecDict | _: SmallVecDicts) => true + case Some(_: Vec) => false } private def sumHint(e: Exp)(implicit ctx: TypesCtx) = e match { case dict @ DictNode(map, _) if map.nonEmpty => (TypeInference.run(dict.getInnerDict): @unchecked) match { case DictType(_, _, hint) => Some(hint) } - case _ => None + case _ => None } - private def checkIsUnique(dict: DictNode) = cond(dict.getInnerDict) { - case DictNode(Seq((_: Unique, _)), _: PHmap) => true + private def checkIsUnique(dict: DictNode) = cond(dict.getInnerDict) { case DictNode(Seq((_: Unique, _)), _: PHmap) => + true } private def sumToInitialise(e: Sum)(implicit ctx: TypesCtx, dest: Option[Sym]) = e match { case Sum(k, v, e1, e2) => val (tpe, ctxLocal) = TypeInference.sumInferTypeAndCtx(k, v, e1, e2) - val packed = getSumBody(e2) match { + val packed = getSumBody(e2) match { case Promote(tsrt: TropicalSemiRingType, _) => TropicalSemiRingType.pack(tsrt, tpe) case _ => tpe } From 34fbefd1ad15fadeae90271f90c364a3574be3fa Mon Sep 17 00:00:00 2001 From: Amir Shaikhha Date: Fri, 30 Aug 2024 05:49:20 +0100 Subject: [PATCH 3/5] Restage type class added --- src/main/scala/sdql/ir/Restage.scala | 75 ++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 src/main/scala/sdql/ir/Restage.scala diff --git a/src/main/scala/sdql/ir/Restage.scala b/src/main/scala/sdql/ir/Restage.scala new file mode 100644 index 00000000..d71a69dc --- /dev/null +++ b/src/main/scala/sdql/ir/Restage.scala @@ -0,0 +1,75 @@ +package sdql +package ir + +/** + * The type-class interface of the Restage design pattern. + * + * Refer to the following paper for more details: Amir Shaikhha, "Restaging Domain-Specific Languages: A Flexible Design + * Pattern for Rapid Development of Optimizing Compilers", GPCE'24. + */ +trait Restage[T] { + def restage(e: T): (Seq[T], Seq[T] => T) = + children(e) -> factory(e) + def children(e: T): Seq[T] + def factory(e: T): Seq[T] => T +} + +object Restage { + type Fact[T] = Seq[T] => T + def unapply[T: Restage](e: T): Some[(Seq[T], Seq[T] => T)] = + Some(implicitly[Restage[T]].restage(e)) + + implicit object RestageExp extends Restage[Exp] { + def children(e: Exp): Seq[Exp] = e match { + // 0-ary + case _: Sym | _: Const | _: Load => Seq() + // 1-ary + case Neg(e) => Seq(e) + case FieldNode(e, _) => Seq(e) + case Promote(_, e) => Seq(e) + case RangeNode(e) => Seq(e) + case Unique(e) => Seq(e) + // 2-ary + case Add(e1, e2) => Seq(e1, e2) + case Mult(e1, e2) => Seq(e1, e2) + case Cmp(e1, e2, _) => Seq(e1, e2) + case Sum(_, _, e1, e2) => Seq(e1, e2) + case Get(e1, e2) => Seq(e1, e2) + case Concat(e1, e2) => Seq(e1, e2) + case LetBinding(_, e1, e2) => Seq(e1, e2) + // 3-ary + case IfThenElse(e1, e2, e3) => Seq(e1, e2, e3) + // n-ary + case RecNode(values) => values.map(_._2) + case DictNode(map, _) => map.flatMap(x => Seq(x._1, x._2)) + case External(_, args) => args + case _ => raise(f"unhandled ${e.simpleName} in\n${e.prettyPrint}") + } + def factory(e: Exp): Seq[Exp] => Exp = e match { + // 0-ary + case _: Sym | _: Const | _: Load => _ => e + // 1-ary + case Neg(_) => seq => Neg(seq(0)) + case FieldNode(_, field) => seq => FieldNode(seq(0), field) + case Promote(tp, _) => seq => Promote(tp, seq(0)) + case RangeNode(_) => seq => RangeNode(seq(0)) + case Unique(_) => seq => Unique(seq(0)) + // 2-ary + case Add(_, _) => seq => Add(seq(0), seq(1)) + case Mult(_, _) => seq => Mult(seq(0), seq(1)) + case Cmp(_, _, cmp) => seq => Cmp(seq(0), seq(1), cmp) + case Sum(key, value, _, _) => seq => Sum(key, value, seq(0), seq(1)) + case Get(_, _) => seq => Get(seq(0), seq(1)) + case Concat(_, _) => seq => Concat(seq(0), seq(1)) + case LetBinding(x, _, _) => seq => LetBinding(x, seq(0), seq(1)) + // 3-ary + case IfThenElse(_, _, _) => seq => IfThenElse(seq(0), seq(1), seq(2)) + // n-ary + case RecNode(values) => seq => RecNode(values.zip(seq).map(vs => (vs._1._1, vs._2))) + case DictNode(map, hint) => + seq => DictNode((0 until map.length).map(i => seq(i * 2) -> seq(i * 2 + 1)).toSeq, hint) + case External(name, _) => seq => External(name, seq) + case _ => raise(f"unhandled ${e.simpleName} in\n${e.prettyPrint}") + } + } +} From 3e84200671e5ee238fb91d307fa54d1d7848848d Mon Sep 17 00:00:00 2001 From: Amir Shaikhha Date: Fri, 30 Aug 2024 05:49:36 +0100 Subject: [PATCH 4/5] Rewriter uses the Restage type class --- .../scala/sdql/transformations/Rewriter.scala | 67 ++----------------- 1 file changed, 7 insertions(+), 60 deletions(-) diff --git a/src/main/scala/sdql/transformations/Rewriter.scala b/src/main/scala/sdql/transformations/Rewriter.scala index a7d910a3..a8004194 100644 --- a/src/main/scala/sdql/transformations/Rewriter.scala +++ b/src/main/scala/sdql/transformations/Rewriter.scala @@ -29,69 +29,16 @@ object Rewriter { LowerToLLQL def mapInner(f: Exp => Exp)(e: Exp): Exp = e match { - // 0-ary - case _: Sym | _: Const | _: Load => e - // 1-ary - case Neg(e) => Neg(f(e)) - case FieldNode(e, field) => FieldNode(f(e), field) - case Promote(tp, e) => Promote(tp, f(e)) - case RangeNode(e) => RangeNode(f(e)) - case Unique(e) => Unique(f(e)) - // 2-ary - case Add(e1, e2) => Add(f(e1), f(e2)) - case Mult(e1, e2) => Mult(f(e1), f(e2)) - case Cmp(e1, e2, cmp) => Cmp(f(e1), f(e2), cmp) - case Sum(key, value, e1, e2) => Sum(key, value, f(e1), f(e2)) - case Get(e1, e2) => Get(f(e1), f(e2)) - case Concat(e1, e2) => Concat(f(e1), f(e2)) - case LetBinding(x, e1, e2) => LetBinding(x, f(e1), f(e2)) - // 3-ary - case IfThenElse(e1, e2, e3) => IfThenElse(f(e1), f(e2), f(e3)) - // n-ary - case RecNode(values) => RecNode(values.map(v => (v._1, f(v._2)))) - case DictNode(map, hint) => - DictNode( - map.map(x => (f(x._1), f(x._2))), - hint match { - case PHmap(Some(e)) => PHmap(Some(f(e))) - case _ => hint - } - ) - case External(name, args) => External(name, args.map(f)) - case _ => raise(f"unhandled ${e.simpleName} in\n${e.prettyPrint}") + case Restage(cs, fact) => fact(cs.map(f)) } def mapInnerReduce[T](f: Exp => T, g: (T, T) => T, default: T)(e: Exp): T = e match { - // 0-ary - case _: Sym | _: Const | _: Load => default - // 1-ary - case Neg(e) => f(e) - case FieldNode(e, _) => f(e) - case Promote(_, e) => f(e) - case RangeNode(e) => f(e) - case Unique(e) => f(e) - // 2-ary - case Add(e1, e2) => g(f(e1), f(e2)) - case Mult(e1, e2) => g(f(e1), f(e2)) - case Cmp(e1, e2, _) => g(f(e1), f(e2)) - case Sum(_, _, e1, e2) => g(f(e1), f(e2)) - case Get(e1, e2) => g(f(e1), f(e2)) - case Concat(e1, e2) => g(f(e1), f(e2)) - case LetBinding(_, e1, e2) => g(f(e1), f(e2)) - // 3-ary - case IfThenElse(e1, e2, e3) => g(g(f(e1), f(e2)), f(e3)) - // n-ary - case RecNode(values) => values.map(_._2).map(f).foldLeft(default)(g) - case DictNode(map, hint) => - g( - g(map.map(_._1).map(f).foldLeft(default)(g), map.map(_._2).map(f).foldLeft(default)(g)), - hint match { - case PHmap(Some(e)) => f(e) - case _ => default - } - ) - case External(_, args) => args.map(f).foldLeft(default)(g) - case _ => raise(f"unhandled ${e.simpleName} in\n${e.prettyPrint}") + case Restage(cs, _) => + val fcs = cs.map(f) + if(fcs.isEmpty) + default + else + fcs.reduceLeft(g) } } From 5ce5791d9d01c2cdada204c34f91873e6aba8b1f Mon Sep 17 00:00:00 2001 From: Amir Shaikhha Date: Fri, 30 Aug 2024 12:10:11 +0100 Subject: [PATCH 5/5] Restage handles more cases, more cases rewritten with Restage * Restage handles LLQL IR * Restage fully handles DictNode * Rewritten part CppCodegen using Restage --- src/main/scala/sdql/backend/CppCodegen.scala | 35 ++-------------- src/main/scala/sdql/ir/Exp.scala | 2 +- src/main/scala/sdql/ir/Restage.scala | 40 +++++++++++++------ .../scala/sdql/transformations/Rewriter.scala | 4 +- 4 files changed, 34 insertions(+), 47 deletions(-) diff --git a/src/main/scala/sdql/backend/CppCodegen.scala b/src/main/scala/sdql/backend/CppCodegen.scala index 1532712f..59d03a0e 100644 --- a/src/main/scala/sdql/backend/CppCodegen.scala +++ b/src/main/scala/sdql/backend/CppCodegen.scala @@ -399,38 +399,9 @@ object CppCodegen { (readCols ++ readSize).mkString(",\n") } private def iterExps(e: Exp): Iterator[Exp] = - Iterator(e) ++ ( - e match { - // 0-ary - case _: Sym | _: Const | _: Load => Iterator() - // 1-ary - case Neg(e) => iterExps(e) - case FieldNode(e, _) => iterExps(e) - case Promote(_, e) => iterExps(e) - case RangeNode(e) => iterExps(e) - case Unique(e) => iterExps(e) - // 2-ary - case Add(e1, e2) => iterExps(e1) ++ iterExps(e2) - case Mult(e1, e2) => iterExps(e1) ++ iterExps(e2) - case Cmp(e1, e2, _) => iterExps(e1) ++ iterExps(e2) - case Sum(_, _, e1, e2) => iterExps(e1) ++ iterExps(e2) - case Get(e1, e2) => iterExps(e1) ++ iterExps(e2) - case Concat(e1, e2) => iterExps(e1) ++ iterExps(e2) - case LetBinding(_, e1, e2) => iterExps(e1) ++ iterExps(e2) - // 3-ary - case IfThenElse(e1, e2, e3) => iterExps(e1) ++ iterExps(e2) ++ iterExps(e3) - // n-ary - case RecNode(values) => values.map(_._2).flatMap(iterExps) - case DictNode(map, PHmap(Some(e))) => map.flatMap(x => iterExps(x._1) ++ iterExps(x._2)) ++ iterExps(e) - case DictNode(map, _) => map.flatMap(x => iterExps(x._1) ++ iterExps(x._2)) - case External(_, args) => args.flatMap(iterExps) - // LLQL - case Initialise(_, e) => iterExps(e) - case Update(e, _, _) => iterExps(e) - case Modify(e, _) => iterExps(e) - case _ => raise(f"unhandled ${e.simpleName} in\n${e.prettyPrint}") - } - ) + Iterator(e) ++ (e match { + case Restage(cs, _) => cs.flatMap(iterExps) + }) private def cppPrintResult(tpe: Type): String = tpe match { case DictType(kt, vt, _: PHmap) => diff --git a/src/main/scala/sdql/ir/Exp.scala b/src/main/scala/sdql/ir/Exp.scala index 91839178..e15fe50c 100644 --- a/src/main/scala/sdql/ir/Exp.scala +++ b/src/main/scala/sdql/ir/Exp.scala @@ -219,6 +219,6 @@ object SingleDict { * The expressions below bridge the gap between functional and imperative styles, simplifying code generation in C++. */ sealed trait LLQL -case class Initialise(tpe: Type, e: Sum) extends Exp with LLQL +case class Initialise(tpe: Type, e: Exp) extends Exp with LLQL case class Update(e: Exp, agg: Aggregation, dest: Sym) extends Exp with LLQL case class Modify(e: Exp, dest: Sym) extends Exp with LLQL diff --git a/src/main/scala/sdql/ir/Restage.scala b/src/main/scala/sdql/ir/Restage.scala index d71a69dc..86b4257e 100644 --- a/src/main/scala/sdql/ir/Restage.scala +++ b/src/main/scala/sdql/ir/Restage.scala @@ -15,7 +15,6 @@ trait Restage[T] { } object Restage { - type Fact[T] = Seq[T] => T def unapply[T: Restage](e: T): Some[(Seq[T], Seq[T] => T)] = Some(implicitly[Restage[T]].restage(e)) @@ -41,8 +40,17 @@ object Restage { case IfThenElse(e1, e2, e3) => Seq(e1, e2, e3) // n-ary case RecNode(values) => values.map(_._2) - case DictNode(map, _) => map.flatMap(x => Seq(x._1, x._2)) + case DictNode(map, hint) => + val cs = map.flatMap(x => Seq(x._1, x._2)) + val hcs = hint match { + case PHmap(Some(e0)) => Seq(e0) + case _ => Seq() + } + cs ++ hcs case External(_, args) => args + case Initialise(_, e) => Seq(e) + case Update(e, _, _) => Seq(e) + case Modify(e, _) => Seq(e) case _ => raise(f"unhandled ${e.simpleName} in\n${e.prettyPrint}") } def factory(e: Exp): Seq[Exp] => Exp = e match { @@ -55,20 +63,28 @@ object Restage { case RangeNode(_) => seq => RangeNode(seq(0)) case Unique(_) => seq => Unique(seq(0)) // 2-ary - case Add(_, _) => seq => Add(seq(0), seq(1)) - case Mult(_, _) => seq => Mult(seq(0), seq(1)) - case Cmp(_, _, cmp) => seq => Cmp(seq(0), seq(1), cmp) - case Sum(key, value, _, _) => seq => Sum(key, value, seq(0), seq(1)) - case Get(_, _) => seq => Get(seq(0), seq(1)) - case Concat(_, _) => seq => Concat(seq(0), seq(1)) - case LetBinding(x, _, _) => seq => LetBinding(x, seq(0), seq(1)) + case Add(_, _) => seq => Add(seq(0), seq(1)) + case Mult(_, _) => seq => Mult(seq(0), seq(1)) + case Cmp(_, _, cmp) => seq => Cmp(seq(0), seq(1), cmp) + case Sum(key, value, _, _) => seq => Sum(key, value, seq(0), seq(1)) + case Get(_, _) => seq => Get(seq(0), seq(1)) + case Concat(_, _) => seq => Concat(seq(0), seq(1)) + case LetBinding(x, _, _) => seq => LetBinding(x, seq(0), seq(1)) // 3-ary - case IfThenElse(_, _, _) => seq => IfThenElse(seq(0), seq(1), seq(2)) + case IfThenElse(_, _, _) => seq => IfThenElse(seq(0), seq(1), seq(2)) // n-ary case RecNode(values) => seq => RecNode(values.zip(seq).map(vs => (vs._1._1, vs._2))) case DictNode(map, hint) => - seq => DictNode((0 until map.length).map(i => seq(i * 2) -> seq(i * 2 + 1)).toSeq, hint) - case External(name, _) => seq => External(name, seq) + seq => + val nhint = hint match { + case PHmap(Some(_)) => PHmap(Some(seq.last)) + case _ => hint + } + DictNode(map.indices.map(i => seq(i * 2) -> seq(i * 2 + 1)).toSeq, nhint) + case External(name, _) => seq => External(name, seq) + case Initialise(tpe, _) => seq => Initialise(tpe, seq(0)) + case Update(_, agg, dest) => seq => Update(seq(0), agg, dest) + case Modify(_, dest) => seq => Modify(seq(0), dest) case _ => raise(f"unhandled ${e.simpleName} in\n${e.prettyPrint}") } } diff --git a/src/main/scala/sdql/transformations/Rewriter.scala b/src/main/scala/sdql/transformations/Rewriter.scala index a8004194..c87d6cd3 100644 --- a/src/main/scala/sdql/transformations/Rewriter.scala +++ b/src/main/scala/sdql/transformations/Rewriter.scala @@ -33,9 +33,9 @@ object Rewriter { } def mapInnerReduce[T](f: Exp => T, g: (T, T) => T, default: T)(e: Exp): T = e match { - case Restage(cs, _) => + case Restage(cs, _) => val fcs = cs.map(f) - if(fcs.isEmpty) + if (fcs.isEmpty) default else fcs.reduceLeft(g)