Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add script argument control flags #1268

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion amm/src/main/scala/ammonite/AmmoniteMain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,12 @@ object AmmoniteMain{
false

case (None, head :: rest) =>
val success = runner.runScript(os.Path(head, os.pwd), rest)
val success = runner.runScript(
os.Path(head, os.pwd),
allowPositional = !cliConfig.scripts.noPositionalArgs.value,
allowRepeats = cliConfig.scripts.allowRepeatArgs.value,
rest
)
success
}
}
Expand Down
6 changes: 4 additions & 2 deletions amm/src/main/scala/ammonite/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,14 @@ case class Main(predefCode: String = "",
* of `args` and a map of keyword `kwargs` to pass to that file.
*/
def runScript(path: os.Path,
scriptArgs: Seq[String])
scriptArgs: Seq[String],
allowPositional: Boolean = true,
allowRepeats: Boolean = false)
: (Res[Any], Seq[(Watchable, Long)]) = {

instantiateInterpreter() match{
case Right(interp) =>
val result = main.Scripts.runScript(wd, path, interp, scriptArgs)
val result = main.Scripts.runScript(wd, path, interp, allowPositional, allowRepeats, scriptArgs)
(result, interp.watchedValues.toSeq)
case Left(problems) => problems
}
Expand Down
6 changes: 3 additions & 3 deletions amm/src/main/scala/ammonite/MainRunner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ class MainRunner(cliConfig: Config,
}
}

def runScript(scriptPath: os.Path, scriptArgs: List[String]) =
def runScript(scriptPath: os.Path, allowPositional: Boolean, allowRepeats: Boolean, scriptArgs: List[String]) =
watchLoop(
isRepl = false,
printing = true,
_.runScript(scriptPath, scriptArgs)
_.runScript(scriptPath, scriptArgs, allowPositional = allowPositional, allowRepeats = allowRepeats)
)

def runCode(code: String) = watchLoop(isRepl = false, printing = false, _.runCode(code))
Expand Down Expand Up @@ -184,4 +184,4 @@ object MainRunner{
}


}
}
16 changes: 16 additions & 0 deletions amm/src/main/scala/ammonite/main/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import ammonite.repl.tools.Util.PathRead
case class Config(core: Config.Core,
predef: Config.Predef,
repl: Config.Repl,
scripts: Config.Scripts,
rest: String*)

object Config{
Expand Down Expand Up @@ -95,5 +96,20 @@ object Config{
)
implicit val replParser = ParserForClass[Repl]

@main
case class Scripts(
@arg(
name = "no-positional-args",
doc = "Disallow positional arguments for scripts"
)
noPositionalArgs: Flag,
@arg(
name = "allow-repeat-args",
doc = "Allow repeated arguments for scripts"
)
allowRepeatArgs: Flag
)
implicit val scriptsParser = ParserForClass[Scripts]

val parser = mainargs.ParserForClass[Config]
}
6 changes: 4 additions & 2 deletions amm/src/main/scala/ammonite/main/Scripts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ object Scripts {
def runScript(wd: os.Path,
path: os.Path,
interp: ammonite.interp.Interpreter,
allowPositional: Boolean = true,
allowRepeats: Boolean = false,
scriptArgs: Seq[String] = Nil) = {
interp.watch(path)
val (pkg, wrapper) = Util.pathToPackageWrapper(Seq(), path relativeTo wd)
Expand Down Expand Up @@ -106,8 +108,8 @@ object Scripts {
}else mainargs.Invoker.runMains(
parser.mains,
scriptArgs,
allowPositional = true,
allowRepeats = false
allowPositional = allowPositional,
allowRepeats = allowRepeats
) match{
case Left(earlyError) =>
Res.Failure(mainargs.Renderer.renderEarlyError(earlyError))
Expand Down
4 changes: 2 additions & 2 deletions amm/src/test/scala/ammonite/interp/CachingTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ object CachingTests extends TestSuite{
path: os.Path,
interp: ammonite.interp.Interpreter,
scriptArgs: Seq[String] = Nil) =
Scripts.runScript(wd, path, interp, scriptArgs) match {
Scripts.runScript(wd, path, interp, scriptArgs = scriptArgs) match {
case Res.Success(_) =>
case Res.Skip =>
case Res.Exception(t, s) => throw new Exception(s"Error running script: $s", t)
Expand Down Expand Up @@ -231,7 +231,7 @@ object CachingTests extends TestSuite{
val storage = new Storage.Folder(storageFolder)
def runScript(script: os.Path, expectedCount: Int) = {
val interp = createTestInterp(storage)
val res = Scripts.runScript(script / os.up, script, interp, Nil)
val res = Scripts.runScript(script / os.up, script, interp)

val count = interp.compilationCount
assert(count == expectedCount)
Expand Down
24 changes: 24 additions & 0 deletions amm/src/test/scala/ammonite/main/MainTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ class MainTests extends TestSuite{
def exec(p: String, args: String*) =
new InProcessMainMethodRunner(InProcessMainMethodRunner.base / 'mains / p, Nil, args)

def execPreArgs(preArgs: String*)(p: String, args: String*) =
new InProcessMainMethodRunner(InProcessMainMethodRunner.base / 'mains / p, preArgs.toList, args)

def stripInvisibleMargin(s: String): String = {
val lines = Predef.augmentString(s).lines.toArray
val leftMargin = lines.filter(_.trim.nonEmpty).map(_.takeWhile(_ == ' ').length).min
Expand Down Expand Up @@ -287,6 +290,27 @@ class MainTests extends TestSuite{
assert(Predef.augmentString(evaled.err).lines.length < 20)

}
test("noPositionalArgs") {
val evaled = execPreArgs("--no-positional-args")("Args.sc", "1", "moo")
assert(!evaled.success)

assert(evaled.err.contains(
Util.normalizeNewlines(
s"""Missing arguments: -i <int> -s <str>
|Unknown arguments: "1" "moo"
|$argsUsageMsg""".stripMargin
)
))
}
test("allowRepeatArgs") {
val evaled = execPreArgs("--allow-repeat-args")("Args.sc", "1", "Moo", "-i", "3")
assert(evaled.success)
assert(
evaled.out == ("\"Hello! MooMooMoo Ammonite.\"" + Util.newLine) ||
// For some reason, on windows CI machines the repo gets clone as lowercase (???)
evaled.out == ("\"Hello! MooMooMoo ammonite.\"" + Util.newLine)
)
}
}
}
}
Expand Down