diff --git a/amm/src/main/scala/ammonite/AmmoniteMain.scala b/amm/src/main/scala/ammonite/AmmoniteMain.scala index d1311ad2e..b96474656 100644 --- a/amm/src/main/scala/ammonite/AmmoniteMain.scala +++ b/amm/src/main/scala/ammonite/AmmoniteMain.scala @@ -109,7 +109,12 @@ object AmmoniteMain{ false case (None, head :: rest) => - val success = runner.runScript(os.Path(head, os.pwd), rest) + val success = runner.runScript( + os.Path(head, os.pwd), + allowPositional = !cliConfig.scripts.noPositionalArgs.value, + allowRepeats = cliConfig.scripts.allowRepeatArgs.value, + rest + ) success } } diff --git a/amm/src/main/scala/ammonite/Main.scala b/amm/src/main/scala/ammonite/Main.scala index 720cf967c..31c990541 100644 --- a/amm/src/main/scala/ammonite/Main.scala +++ b/amm/src/main/scala/ammonite/Main.scala @@ -244,12 +244,14 @@ case class Main(predefCode: String = "", * of `args` and a map of keyword `kwargs` to pass to that file. */ def runScript(path: os.Path, - scriptArgs: Seq[String]) + scriptArgs: Seq[String], + allowPositional: Boolean = true, + allowRepeats: Boolean = false) : (Res[Any], Seq[(Watchable, Long)]) = { instantiateInterpreter() match{ case Right(interp) => - val result = main.Scripts.runScript(wd, path, interp, scriptArgs) + val result = main.Scripts.runScript(wd, path, interp, allowPositional, allowRepeats, scriptArgs) (result, interp.watchedValues.toSeq) case Left(problems) => problems } diff --git a/amm/src/main/scala/ammonite/MainRunner.scala b/amm/src/main/scala/ammonite/MainRunner.scala index 3f00c8d3c..e485a0316 100644 --- a/amm/src/main/scala/ammonite/MainRunner.scala +++ b/amm/src/main/scala/ammonite/MainRunner.scala @@ -60,11 +60,11 @@ class MainRunner(cliConfig: Config, } } - def runScript(scriptPath: os.Path, scriptArgs: List[String]) = + def runScript(scriptPath: os.Path, allowPositional: Boolean, allowRepeats: Boolean, scriptArgs: List[String]) = watchLoop( isRepl = false, printing = true, - _.runScript(scriptPath, scriptArgs) + _.runScript(scriptPath, scriptArgs, allowPositional = allowPositional, allowRepeats = allowRepeats) ) def runCode(code: String) = watchLoop(isRepl = false, printing = false, _.runCode(code)) @@ -184,4 +184,4 @@ object MainRunner{ } -} \ No newline at end of file +} diff --git a/amm/src/main/scala/ammonite/main/Config.scala b/amm/src/main/scala/ammonite/main/Config.scala index 10792451f..c4583a7a9 100644 --- a/amm/src/main/scala/ammonite/main/Config.scala +++ b/amm/src/main/scala/ammonite/main/Config.scala @@ -6,6 +6,7 @@ import ammonite.repl.tools.Util.PathRead case class Config(core: Config.Core, predef: Config.Predef, repl: Config.Repl, + scripts: Config.Scripts, rest: String*) object Config{ @@ -95,5 +96,20 @@ object Config{ ) implicit val replParser = ParserForClass[Repl] + @main + case class Scripts( + @arg( + name = "no-positional-args", + doc = "Disallow positional arguments for scripts" + ) + noPositionalArgs: Flag, + @arg( + name = "allow-repeat-args", + doc = "Allow repeated arguments for scripts" + ) + allowRepeatArgs: Flag + ) + implicit val scriptsParser = ParserForClass[Scripts] + val parser = mainargs.ParserForClass[Config] } diff --git a/amm/src/main/scala/ammonite/main/Scripts.scala b/amm/src/main/scala/ammonite/main/Scripts.scala index 021868e5b..70d4b10f0 100644 --- a/amm/src/main/scala/ammonite/main/Scripts.scala +++ b/amm/src/main/scala/ammonite/main/Scripts.scala @@ -14,6 +14,8 @@ object Scripts { def runScript(wd: os.Path, path: os.Path, interp: ammonite.interp.Interpreter, + allowPositional: Boolean = true, + allowRepeats: Boolean = false, scriptArgs: Seq[String] = Nil) = { interp.watch(path) val (pkg, wrapper) = Util.pathToPackageWrapper(Seq(), path relativeTo wd) @@ -106,8 +108,8 @@ object Scripts { }else mainargs.Invoker.runMains( parser.mains, scriptArgs, - allowPositional = true, - allowRepeats = false + allowPositional = allowPositional, + allowRepeats = allowRepeats ) match{ case Left(earlyError) => Res.Failure(mainargs.Renderer.renderEarlyError(earlyError)) diff --git a/amm/src/test/scala/ammonite/interp/CachingTests.scala b/amm/src/test/scala/ammonite/interp/CachingTests.scala index c99428c9c..fff6b86c1 100644 --- a/amm/src/test/scala/ammonite/interp/CachingTests.scala +++ b/amm/src/test/scala/ammonite/interp/CachingTests.scala @@ -17,7 +17,7 @@ object CachingTests extends TestSuite{ path: os.Path, interp: ammonite.interp.Interpreter, scriptArgs: Seq[String] = Nil) = - Scripts.runScript(wd, path, interp, scriptArgs) match { + Scripts.runScript(wd, path, interp, scriptArgs = scriptArgs) match { case Res.Success(_) => case Res.Skip => case Res.Exception(t, s) => throw new Exception(s"Error running script: $s", t) @@ -231,7 +231,7 @@ object CachingTests extends TestSuite{ val storage = new Storage.Folder(storageFolder) def runScript(script: os.Path, expectedCount: Int) = { val interp = createTestInterp(storage) - val res = Scripts.runScript(script / os.up, script, interp, Nil) + val res = Scripts.runScript(script / os.up, script, interp) val count = interp.compilationCount assert(count == expectedCount) diff --git a/amm/src/test/scala/ammonite/main/MainTests.scala b/amm/src/test/scala/ammonite/main/MainTests.scala index 474da8e72..42918851b 100755 --- a/amm/src/test/scala/ammonite/main/MainTests.scala +++ b/amm/src/test/scala/ammonite/main/MainTests.scala @@ -13,6 +13,9 @@ class MainTests extends TestSuite{ def exec(p: String, args: String*) = new InProcessMainMethodRunner(InProcessMainMethodRunner.base / 'mains / p, Nil, args) + def execPreArgs(preArgs: String*)(p: String, args: String*) = + new InProcessMainMethodRunner(InProcessMainMethodRunner.base / 'mains / p, preArgs.toList, args) + def stripInvisibleMargin(s: String): String = { val lines = Predef.augmentString(s).lines.toArray val leftMargin = lines.filter(_.trim.nonEmpty).map(_.takeWhile(_ == ' ').length).min @@ -287,6 +290,27 @@ class MainTests extends TestSuite{ assert(Predef.augmentString(evaled.err).lines.length < 20) } + test("noPositionalArgs") { + val evaled = execPreArgs("--no-positional-args")("Args.sc", "1", "moo") + assert(!evaled.success) + + assert(evaled.err.contains( + Util.normalizeNewlines( + s"""Missing arguments: -i -s + |Unknown arguments: "1" "moo" + |$argsUsageMsg""".stripMargin + ) + )) + } + test("allowRepeatArgs") { + val evaled = execPreArgs("--allow-repeat-args")("Args.sc", "1", "Moo", "-i", "3") + assert(evaled.success) + assert( + evaled.out == ("\"Hello! MooMooMoo Ammonite.\"" + Util.newLine) || + // For some reason, on windows CI machines the repo gets clone as lowercase (???) + evaled.out == ("\"Hello! MooMooMoo ammonite.\"" + Util.newLine) + ) + } } } }