diff --git a/Cargo.toml b/Cargo.toml index b6cd275..d780807 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,13 +22,19 @@ paste = "1.0.15" ctor = "0.2.8" once_cell = "1.20.2" clap = { version = "4.5.21", features = ["derive"] } -assert_cmd = "2.0.16" +shlex = "1.3.0" [dev-dependencies] rand_distr = "0.4.3" tempfile = "3.3" ordered-float = "4.3.0" +predicates = "3.1.2" +assert_cmd = "2.0.16" [[bin]] name = "runner_test_custom_args" path = "tests/bin/runner_test_custom_args.rs" + +[[bin]] +name = "runner_test_debug" +path = "tests/bin/runner_test_debug.rs" diff --git a/examples/runner/main.rs b/examples/runner/main.rs index ce08e16..813a9c3 100644 --- a/examples/runner/main.rs +++ b/examples/runner/main.rs @@ -1,23 +1,40 @@ use clap::Args; use ixa::runner::run_with_custom_args; +use ixa::ContextPeopleExt; #[derive(Args, Debug)] -struct Extra { - #[arg(short, long)] - foo: bool, +struct CustomArgs { + #[arg(short = 'p', long)] + starting_population: Option, } fn main() { - // Try running this with `cargo run --example runner -- --seed 42` - run_with_custom_args(|context, args, extra: Option| { - context.add_plan(1.0, |_| { - println!("Hello, world!"); - }); - println!("{}", args.random_seed); - if let Some(extra) = extra { - println!("{}", extra.foo); + // Try running the following: + // cargo run --example runner -- --seed 42 + // cargo run --example runner -- --starting-population 5 + // cargo run --example runner -- -p 5 --debugger + let context = run_with_custom_args(|context, args, custom_args: Option| { + println!("Setting random seed to {}", args.random_seed); + + // If an initial population was provided, add each person + if let Some(custom_args) = custom_args { + if let Some(population) = custom_args.starting_population { + for _ in 0..population { + context.add_person(()).unwrap(); + } + } } + + context.add_plan(2.0, |context| { + println!("Adding two people at t=2"); + context.add_person(()).unwrap(); + context.add_person(()).unwrap(); + }); + Ok(()) }) .unwrap(); + + let final_count = context.get_current_population(); + println!("Simulation complete. The number of people is: {final_count}"); } diff --git a/src/context.rs b/src/context.rs index 14c8529..0009808 100644 --- a/src/context.rs +++ b/src/context.rs @@ -210,6 +210,12 @@ impl Context { self.plan_queue.cancel_plan(id); } + #[doc(hidden)] + #[allow(dead_code)] + pub(crate) fn remaining_plan_count(&self) -> usize { + self.plan_queue.remaining_plan_count() + } + /// Add a `Callback` to the queue to be executed before the next plan pub fn queue_callback(&mut self, callback: impl FnOnce(&mut Context) + 'static) { self.callback_queue.push_back(Box::new(callback)); diff --git a/src/debugger.rs b/src/debugger.rs new file mode 100644 index 0000000..0c21848 --- /dev/null +++ b/src/debugger.rs @@ -0,0 +1,306 @@ +use crate::Context; +use crate::ContextPeopleExt; +use crate::IxaError; +use clap::value_parser; +use clap::{Arg, ArgMatches, Command}; +use std::cell::RefCell; +use std::collections::HashMap; +use std::io::Write; + +trait DebuggerCommand { + /// Handle the command and any inputs; returning true will exit the debugger + fn handle( + &self, + context: &mut Context, + matches: &ArgMatches, + ) -> Result<(bool, Option), String>; + fn about(&self) -> &'static str; + fn extend(&self, subcommand: Command) -> Command { + subcommand + } +} + +struct DebuggerRepl { + commands: HashMap<&'static str, Box>, + output: RefCell>, +} + +impl DebuggerRepl { + fn new(output: Box) -> Self { + DebuggerRepl { + commands: HashMap::new(), + output: RefCell::new(output), + } + } + + fn register_command(&mut self, name: &'static str, handler: Box) { + self.commands.insert(name, handler); + } + + fn get_command(&self, name: &str) -> Option<&dyn DebuggerCommand> { + self.commands.get(name).map(|command| &**command) + } + + fn writeln(&self, formatted_string: &str) { + let mut output = self.output.borrow_mut(); + writeln!(output, "{formatted_string}") + .map_err(|e| e.to_string()) + .unwrap(); + output.flush().unwrap(); + } + + fn build_cli(&self) -> Command { + let mut cli = Command::new("repl") + .multicall(true) + .arg_required_else_help(true) + .subcommand_required(true) + .subcommand_value_name("DEBUGGER") + .subcommand_help_heading("IXA DEBUGGER") + .help_template("{all-args}"); + + for (name, handler) in &self.commands { + let subcommand = + handler.extend(Command::new(*name).about(handler.about()).help_template( + "{about-with-newline}\n{usage-heading}\n {usage}\n\n{all-args}{after-help}", + )); + cli = cli.subcommand(subcommand); + } + + cli + } + + fn process_line(&self, l: &str, context: &mut Context) -> Result { + let args = shlex::split(l).ok_or("Error splitting lines")?; + let matches = self + .build_cli() + .try_get_matches_from(args) + .map_err(|e| e.to_string())?; + + if let Some((command, sub_matches)) = matches.subcommand() { + // If the provided command is known, run its handler + if let Some(handler) = self.get_command(command) { + let (quit, output) = handler.handle(context, sub_matches)?; + if let Some(output) = output { + self.writeln(&output); + } + return Ok(quit); + } + // Unexpected command: print an error + return Err(format!("Unknown command: {command}")); + } + + unreachable!("subcommand required"); + } +} + +/// Returns the current population of the simulation +struct PopulationCommand; +impl DebuggerCommand for PopulationCommand { + fn about(&self) -> &'static str { + "Get the total number of people" + } + fn handle( + &self, + context: &mut Context, + _matches: &ArgMatches, + ) -> Result<(bool, Option), String> { + let output = format!("{}", context.get_current_population()); + Ok((false, Some(output))) + } +} + +/// Adds a new debugger breakpoint at t +struct NextCommand; +impl DebuggerCommand for NextCommand { + fn about(&self) -> &'static str { + "Continue until the given time and then pause again" + } + fn extend(&self, subcommand: Command) -> Command { + subcommand.arg( + Arg::new("t") + .help("The next breakpoint (e.g., 4.2)") + .value_parser(value_parser!(f64)) + .required(true), + ) + } + fn handle( + &self, + context: &mut Context, + matches: &ArgMatches, + ) -> Result<(bool, Option), String> { + let t = *matches.get_one::("t").unwrap(); + context.schedule_debugger(t); + Ok((true, None)) + } +} + +/// Exits the debugger and continues the simulation +struct ContinueCommand; +impl DebuggerCommand for ContinueCommand { + fn about(&self) -> &'static str { + "Continue the simulation and exit the debugger" + } + fn handle( + &self, + _context: &mut Context, + _matches: &ArgMatches, + ) -> Result<(bool, Option), String> { + Ok((true, None)) + } +} + +// Assemble all the commands +fn build_repl(output: W) -> DebuggerRepl { + let mut repl = DebuggerRepl::new(Box::new(output)); + + repl.register_command("population", Box::new(PopulationCommand)); + repl.register_command("next", Box::new(NextCommand)); + repl.register_command("continue", Box::new(ContinueCommand)); + + repl +} + +// Helper function to read a line from stdin +fn readline(t: f64) -> Result { + write!(std::io::stdout(), "t={t} $ ").map_err(|e| e.to_string())?; + std::io::stdout().flush().map_err(|e| e.to_string())?; + let mut buffer = String::new(); + std::io::stdin() + .read_line(&mut buffer) + .map_err(|e| e.to_string())?; + Ok(buffer) +} + +/// Starts the debugger and pauses execution +fn start_debugger(context: &mut Context) -> Result<(), IxaError> { + let t = context.get_current_time(); + let repl = build_repl(std::io::stdout()); + println!("Debugging simulation at t={t}"); + loop { + let line = readline(t).expect("Error reading input"); + let line = line.trim(); + if line.is_empty() { + continue; + } + + match repl.process_line(line, context) { + Ok(quit) => { + if quit { + break; + } + } + Err(err) => { + write!(std::io::stdout(), "{err}").map_err(|e| e.to_string())?; + std::io::stdout().flush().unwrap(); + } + } + } + Ok(()) +} + +pub trait ContextDebugExt { + /// Schedule the simulation to pause at time t and start the debugger. + /// This will give you a REPL which allows you to inspect the state of + /// the simulation (type help to see a list of commands) + /// + /// # Errors + /// Internal debugger errors e.g., reading or writing to stdin/stdout; + /// errors in Ixa are printed to stdout + fn schedule_debugger(&mut self, t: f64); +} + +impl ContextDebugExt for Context { + fn schedule_debugger(&mut self, t: f64) { + self.add_plan(t, |context| { + start_debugger(context).expect("Error in debugger"); + }); + } +} + +#[cfg(test)] +mod tests { + use crate::{Context, ContextPeopleExt}; + use std::{cell::RefCell, io::Write, rc::Rc}; + + use super::build_repl; + + #[derive(Clone)] + struct StdoutMock { + storage: Rc>>, + } + + impl StdoutMock { + fn new() -> Self { + StdoutMock { + storage: Rc::new(RefCell::new(Vec::new())), + } + } + fn into_inner(self) -> Vec { + Rc::try_unwrap(self.storage) + .expect("Multiple references to storage") + .into_inner() + } + fn into_string(self) -> String { + String::from_utf8(self.into_inner()).unwrap() + } + } + impl Write for StdoutMock { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.storage.borrow_mut().write(buf) + } + fn flush(&mut self) -> std::io::Result<()> { + self.storage.borrow_mut().flush() + } + } + + #[test] + fn test_cli_debugger_integration() { + assert_cmd::Command::cargo_bin("runner_test_debug") + .unwrap() + .args(["--debugger", "1.0"]) + .write_stdin("population\n") + .write_stdin("continue\n") + .assert() + .success(); + } + + #[test] + fn test_cli_debugger_population() { + let context = &mut Context::new(); + // Add 2 people + context.add_person(()).unwrap(); + context.add_person(()).unwrap(); + + let output = StdoutMock::new(); + let repl = build_repl(output.clone()); + let quits = repl.process_line("population\n", context).unwrap(); + assert!(!quits, "should not exit"); + + drop(repl); + assert!(output.into_string().contains('2')); + } + + #[test] + fn test_cli_continue() { + let context = &mut Context::new(); + let output = StdoutMock::new(); + let repl = build_repl(output.clone()); + let quits = repl.process_line("continue\n", context).unwrap(); + assert!(quits, "should exit"); + } + + #[test] + fn test_cli_next() { + let context = &mut Context::new(); + assert_eq!(context.remaining_plan_count(), 0); + let output = StdoutMock::new(); + let repl = build_repl(output.clone()); + let quits = repl.process_line("next 2\n", context).unwrap(); + assert!(quits, "should exit"); + assert_eq!( + context.remaining_plan_count(), + 1, + "should schedule a plan for the debugger to pause" + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index 6f9663c..28ba27a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,5 +49,8 @@ pub use random::{ContextRandomExt, RngId}; pub mod report; pub use report::{ConfigReportOptions, ContextReportExt, Report}; + pub mod runner; pub use runner::{run_with_args, run_with_custom_args, BaseArgs}; + +pub mod debugger; diff --git a/src/plan.rs b/src/plan.rs index e27a173..7d9c70c 100644 --- a/src/plan.rs +++ b/src/plan.rs @@ -97,6 +97,11 @@ impl Queue { } } } + + #[doc(hidden)] + pub(crate) fn remaining_plan_count(&self) -> usize { + self.queue.len() + } } impl Default for Queue { diff --git a/src/runner.rs b/src/runner.rs index 57360fe..01f0b81 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -1,10 +1,10 @@ use std::path::{Path, PathBuf}; -use crate::context::Context; use crate::error::IxaError; use crate::global_properties::ContextGlobalPropertiesExt; use crate::random::ContextRandomExt; use crate::report::ContextReportExt; +use crate::{context::Context, debugger::ContextDebugExt}; use clap::{Args, Command, FromArgMatches as _}; /// Default cli arguments for ixa runner @@ -21,6 +21,10 @@ pub struct BaseArgs { /// Optional path for report output #[arg(short, long, default_value = "")] pub output_dir: String, + + /// Set a breakpoint at a given time and start the debugger. Defaults to t=0.0 + #[arg(short, long)] + pub debugger: Option>, } #[derive(Args)] @@ -42,7 +46,7 @@ fn create_ixa_cli() -> Command { /// # Errors /// Returns an error if argument parsing or the setup function fails #[allow(clippy::missing_errors_doc)] -pub fn run_with_custom_args(setup_fn: F) -> Result<(), Box> +pub fn run_with_custom_args(setup_fn: F) -> Result> where A: Args, F: Fn(&mut Context, BaseArgs, Option) -> Result<(), IxaError>, @@ -66,7 +70,7 @@ where /// # Errors /// Returns an error if argument parsing or the setup function fails #[allow(clippy::missing_errors_doc)] -pub fn run_with_args(setup_fn: F) -> Result<(), Box> +pub fn run_with_args(setup_fn: F) -> Result> where F: Fn(&mut Context, BaseArgs, Option) -> Result<(), IxaError>, { @@ -81,7 +85,7 @@ fn run_with_args_internal( args: BaseArgs, custom_args: Option, setup_fn: F, -) -> Result<(), Box> +) -> Result> where F: Fn(&mut Context, BaseArgs, Option) -> Result<(), IxaError>, { @@ -104,19 +108,23 @@ where context.init_random(args.random_seed); + // If a breakpoint is provided, stop at that time + if let Some(t) = args.debugger { + context.schedule_debugger(t.unwrap_or(0.0)); + } + // Run the provided Fn setup_fn(&mut context, args, custom_args)?; // Execute the context context.execute(); - Ok(()) + Ok(context) } #[cfg(test)] mod tests { use super::*; use crate::{define_global_property, define_rng}; - use assert_cmd::Command; use serde::Deserialize; #[derive(Args, Debug)] @@ -135,7 +143,7 @@ mod tests { fn test_cli_invocation_with_custom_args() { // Note this target is defined in the bin section of Cargo.toml // and the entry point is in tests/bin/runner_test_custom_args - Command::cargo_bin("runner_test_custom_args") + assert_cmd::Command::cargo_bin("runner_test_custom_args") .unwrap() .args(["--field", "42"]) .assert() @@ -155,6 +163,7 @@ mod tests { random_seed: 42, config: String::new(), output_dir: String::new(), + debugger: None, }; // Use a comparison context to verify the random seed was set @@ -183,6 +192,7 @@ mod tests { random_seed: 42, config: "tests/data/global_properties_runner.json".to_string(), output_dir: String::new(), + debugger: None, }; let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| { let p3 = ctx.get_global_property_value(RunnerProperty).unwrap(); @@ -198,6 +208,7 @@ mod tests { random_seed: 42, config: String::new(), output_dir: "data".to_string(), + debugger: None, }; let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| { let output_dir = &ctx.report_options().directory; @@ -213,6 +224,7 @@ mod tests { random_seed: 42, config: String::new(), output_dir: String::new(), + debugger: None, }; let custom = CustomArgs { field: 42 }; let result = run_with_args_internal(test_args, Some(custom), |_, _, c| { diff --git a/tests/bin/runner_test_debug.rs b/tests/bin/runner_test_debug.rs new file mode 100644 index 0000000..d060d6f --- /dev/null +++ b/tests/bin/runner_test_debug.rs @@ -0,0 +1,12 @@ +use ixa::runner::run_with_args; +use ixa::ContextPeopleExt; +fn main() { + run_with_args(|context, _args, _| { + context.add_person(()).unwrap(); + context.add_person(()).unwrap(); + context.add_person(()).unwrap(); + + Ok(()) + }) + .unwrap(); +}