diff --git a/cask/src/cask/internal/DispatchTrie.scala b/cask/src/cask/internal/DispatchTrie.scala index cc9f7e7252..351ce26217 100644 --- a/cask/src/cask/internal/DispatchTrie.scala +++ b/cask/src/cask/internal/DispatchTrie.scala @@ -38,17 +38,25 @@ object DispatchTrie{ validateGroup(groupTerminals, groupContinuations) } + val dynamicChildren = continuations.filter(_._1.startsWith(":")) + .flatMap(_._2).toIndexedSeq + DispatchTrie[T]( - current = terminals.headOption.map(x => x._2 -> x._3), - children = continuations + current = terminals.headOption + .map{ case (path, value, capturesSubpath) => + val argNames = path.filter(_.startsWith(":")).map(_.drop(1)).toVector + (value, capturesSubpath, argNames) + }, + staticChildren = continuations + .filter(!_._1.startsWith(":")) .map{ case (k, vs) => (k, construct(index + 1, vs)(validationGroups))} - .toMap + .toMap, + dynamicChildren = if (dynamicChildren.isEmpty) None else Some(construct(index + 1, dynamicChildren)(validationGroups)) ) } def validateGroup[T, V](terminals: collection.Seq[(collection.Seq[String], T, Boolean, V)], continuations: mutable.Map[String, mutable.Buffer[(collection.IndexedSeq[String], T, Boolean, V)]]) = { - val wildcards = continuations.filter(_._1(0) == ':') def renderTerminals = terminals .map{case (path, v, allowSubpath, group) => s"$group${renderPath(path)}"} @@ -65,12 +73,6 @@ object DispatchTrie{ ) } - if (wildcards.size >= 1 && continuations.size > 1) { - throw new Exception( - s"Routes overlap with wildcards: $renderContinuations" - ) - } - if (terminals.headOption.exists(_._3) && continuations.size == 1) { throw new Exception( s"Routes overlap with subpath capture: $renderTerminals, $renderContinuations" @@ -88,32 +90,37 @@ object DispatchTrie{ * segments starting with `:`) and any remaining un-used path segments * (only when `current._2 == true`, indicating this route allows trailing * segments) + * current = (value, captures subpaths, argument names) */ -case class DispatchTrie[T](current: Option[(T, Boolean)], - children: Map[String, DispatchTrie[T]]){ +case class DispatchTrie[T]( + current: Option[(T, Boolean, Vector[String])], + staticChildren: Map[String, DispatchTrie[T]], + dynamicChildren: Option[DispatchTrie[T]] +) { + final def lookup(remainingInput: List[String], - bindings: Map[String, String]) + bindings: Vector[String]) : Option[(T, Map[String, String], Seq[String])] = { - remainingInput match{ + remainingInput match { case Nil => - current.map(x => (x._1, bindings, Nil)) + current.map(x => (x._1, x._3.zip(bindings).toMap, Nil)) case head :: rest if current.exists(_._2) => - current.map(x => (x._1, bindings, head :: rest)) + current.map(x => (x._1, x._3.zip(bindings).toMap, head :: rest)) case head :: rest => - if (children.size == 1 && children.keys.head.startsWith(":")){ - children.values.head.lookup(rest, bindings + (children.keys.head.drop(1) -> head)) - }else{ - children.get(head) match{ - case None => None - case Some(continuation) => continuation.lookup(rest, bindings) - } + staticChildren.get(head) match { + case Some(continuation) => continuation.lookup(rest, bindings) + case None => + dynamicChildren match { + case Some(continuation) => continuation.lookup(rest, bindings :+ head) + case None => None + } } - } } def map[V](f: T => V): DispatchTrie[V] = DispatchTrie( - current.map{case (t, v) => (f(t), v)}, - children.map { case (k, v) => (k, v.map(f))} + current.map{case (t, v, a) => (f(t), v, a)}, + staticChildren.map { case (k, v) => (k, v.map(f))}, + dynamicChildren.map { case v => v.map(f)}, ) } diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala index 6ac6e80e57..15c365666b 100644 --- a/cask/src/cask/main/Main.scala +++ b/cask/src/cask/main/Main.scala @@ -106,7 +106,7 @@ object Main{ .map(java.net.URLDecoder.decode(_, "UTF-8")) .toList - dispatchTrie.lookup(decodedSegments, Map()) match { + dispatchTrie.lookup(decodedSegments, Vector()) match { case None => Main.writeResponse(exchange, handleNotFound(Request(exchange, decodedSegments))) case Some((methodMap, routeBindings, remaining)) => methodMap.get(effectiveMethod) match { diff --git a/cask/test/src/test/cask/DispatchTrieTests.scala b/cask/test/src/test/cask/DispatchTrieTests.scala index 1da0df3286..b7b2ce04ff 100644 --- a/cask/test/src/test/cask/DispatchTrieTests.scala +++ b/cask/test/src/test/cask/DispatchTrieTests.scala @@ -11,9 +11,9 @@ object DispatchTrieTests extends TestSuite { )(Seq(_)) assert( - x.lookup(List("hello"), Map()) == Some((1, Map(), Nil)), - x.lookup(List("hello", "world"), Map()) == None, - x.lookup(List("world"), Map()) == None + x.lookup(List("hello"), Vector()) == Some((1, Map(), Nil)), + x.lookup(List("hello", "world"), Vector()) == None, + x.lookup(List("world"), Vector()) == None ) } "nested" - { @@ -24,11 +24,11 @@ object DispatchTrieTests extends TestSuite { ) )(Seq(_)) assert( - x.lookup(List("hello", "world"), Map()) == Some((1, Map(), Nil)), - x.lookup(List("hello", "cow"), Map()) == Some((2, Map(), Nil)), - x.lookup(List("hello"), Map()) == None, - x.lookup(List("hello", "moo"), Map()) == None, - x.lookup(List("hello", "world", "moo"), Map()) == None + x.lookup(List("hello", "world"), Vector()) == Some((1, Map(), Nil)), + x.lookup(List("hello", "cow"), Vector()) == Some((2, Map(), Nil)), + x.lookup(List("hello"), Vector()) == None, + x.lookup(List("hello", "moo"), Vector()) == None, + x.lookup(List("hello", "world", "moo"), Vector()) == None ) } "bindings" - { @@ -36,11 +36,11 @@ object DispatchTrieTests extends TestSuite { Seq((Vector(":hello", ":world"), 1, false)) )(Seq(_)) assert( - x.lookup(List("hello", "world"), Map()) == Some((1, Map("hello" -> "hello", "world" -> "world"), Nil)), - x.lookup(List("world", "hello"), Map()) == Some((1, Map("hello" -> "world", "world" -> "hello"), Nil)), + x.lookup(List("hello", "world"), Vector()) == Some((1, Map("hello" -> "hello", "world" -> "world"), Nil)), + x.lookup(List("world", "hello"), Vector()) == Some((1, Map("hello" -> "world", "world" -> "hello"), Nil)), - x.lookup(List("hello", "world", "cow"), Map()) == None, - x.lookup(List("hello"), Map()) == None + x.lookup(List("hello", "world", "cow"), Vector()) == None, + x.lookup(List("hello"), Vector()) == None ) } @@ -50,35 +50,21 @@ object DispatchTrieTests extends TestSuite { )(Seq(_)) assert( - x.lookup(List("hello", "world"), Map()) == Some((1,Map(), Seq("world"))), - x.lookup(List("hello", "world", "cow"), Map()) == Some((1,Map(), Seq("world", "cow"))), - x.lookup(List("hello"), Map()) == Some((1,Map(), Seq())), - x.lookup(List(), Map()) == None + x.lookup(List("hello", "world"), Vector()) == Some((1,Map(), Seq("world"))), + x.lookup(List("hello", "world", "cow"), Vector()) == Some((1,Map(), Seq("world", "cow"))), + x.lookup(List("hello"), Vector()) == Some((1,Map(), Seq())), + x.lookup(List(), Vector()) == None ) } - "errors" - { + "wildcards" - { test - { DispatchTrie.construct(0, Seq( (Vector("hello", ":world"), 1, false), - (Vector("hello", "world"), 2, false) + (Vector("hello", "world"), 1, false) ) )(Seq(_)) - - val ex = intercept[Exception]{ - DispatchTrie.construct(0, - Seq( - (Vector("hello", ":world"), 1, false), - (Vector("hello", "world"), 1, false) - ) - )(Seq(_)) - } - - assert( - ex.getMessage == - "Routes overlap with wildcards: 1 /hello/:world, 1 /hello/world" - ) } test - { DispatchTrie.construct(0, @@ -87,21 +73,9 @@ object DispatchTrieTests extends TestSuite { (Vector("hello", "world", "omg"), 2, false) ) )(Seq(_)) - - val ex = intercept[Exception]{ - DispatchTrie.construct(0, - Seq( - (Vector("hello", ":world"), 1, false), - (Vector("hello", "world", "omg"), 1, false) - ) - )(Seq(_)) - } - - assert( - ex.getMessage == - "Routes overlap with wildcards: 1 /hello/:world, 1 /hello/world/omg" - ) } + } + "errors" - { test - { DispatchTrie.construct(0, Seq( @@ -143,7 +117,7 @@ object DispatchTrieTests extends TestSuite { assert( ex.getMessage == - "Routes overlap with wildcards: 1 /hello/:world, 1 /hello/:cow" + "More than one endpoint has the same path: 1 /hello/:world, 1 /hello/:cow" ) } test - { diff --git a/example/queryParams/app/src/QueryParams.scala b/example/queryParams/app/src/QueryParams.scala index 5e73574e82..5e79409101 100644 --- a/example/queryParams/app/src/QueryParams.scala +++ b/example/queryParams/app/src/QueryParams.scala @@ -2,7 +2,7 @@ package app object QueryParams extends cask.MainRoutes{ @cask.get("/article/:articleId") // Mandatory query param, e.g. HOST/article/foo?param=bar - def getArticle(articleId: Int, param: String) = { + def getArticle(articleId: Int, param: String) = { s"Article $articleId $param" } @@ -31,5 +31,20 @@ object QueryParams extends cask.MainRoutes{ s"User $userName " + params.value } + @cask.get("/statics/foo") + def getStatic() = { + "static route takes precedence" + } + + @cask.get("/statics/:foo") + def getDynamics(foo: String) = { + s"dynamic route $foo" + } + + @cask.get("/statics/bar") + def getStatic2() = { + "another static route" + } + initialize() } diff --git a/example/queryParams/app/test/src/ExampleTests.scala b/example/queryParams/app/test/src/ExampleTests.scala index 03ae03371f..ca9b4b0752 100644 --- a/example/queryParams/app/test/src/ExampleTests.scala +++ b/example/queryParams/app/test/src/ExampleTests.scala @@ -90,6 +90,16 @@ object ExampleTests extends TestSuite{ res3 == "User lihaoyi Map(unknown1 -> WrappedArray(123), unknown2 -> WrappedArray(abc))" || res3 == "User lihaoyi Map(unknown1 -> ArraySeq(123), unknown2 -> ArraySeq(abc))" ) + + assert( + requests.get(s"$host/statics/foo").text() == "static route takes precedence" + ) + assert( + requests.get(s"$host/statics/hello").text() == "dynamic route hello" + ) + assert( + requests.get(s"$host/statics/bar").text() == "another static route" + ) } } }