diff --git a/explainability/src/musmcs_enumeration/marco/mod.rs b/explainability/src/musmcs_enumeration/marco/mod.rs index 2e0472d9..ff7d3d0d 100644 --- a/explainability/src/musmcs_enumeration/marco/mod.rs +++ b/explainability/src/musmcs_enumeration/marco/mod.rs @@ -61,3 +61,65 @@ trait SubsetSolver { fn get_internal_solver(&mut self) -> &mut Solver; } + +#[cfg(test)] +mod tests { + + use std::collections::BTreeSet; + + use aries::model::lang::expr::lt; + + use super::{Marco, MusMcsEnumerationConfig}; + + type Model = aries::model::Model<&'static str>; + type SimpleMarco = super::simple_marco::SimpleMarco<&'static str>; + + #[test] + fn test_simple_marco_simple() { + let mut model: Model = Model::new(); + let x0 = model.new_ivar(0, 10, "x0"); + let x1 = model.new_ivar(0, 10, "x1"); + let x2 = model.new_ivar(0, 10, "x2"); + let soft_constrs = vec![lt(x0, x1), lt(x1, x2), lt(x2, x0), lt(x0, x2)]; + + let mut simple_marco = SimpleMarco::new( + model, + soft_constrs, + MusMcsEnumerationConfig { + return_muses: true, + return_mcses: true, + }, + ); + let res = simple_marco.run(); + + let computed_muses = res.muses_reif_lits.unwrap().into_iter().collect::>(); + let computed_mcses = res.mcses_reif_lits.unwrap().into_iter().collect::>(); + + let expected_muses = BTreeSet::from_iter(vec![ + BTreeSet::from_iter(vec![ + res.soft_constrs_reifs.get_reif_lit(0, 0), + res.soft_constrs_reifs.get_reif_lit(1, 0), + res.soft_constrs_reifs.get_reif_lit(2, 0), + ]), + BTreeSet::from_iter(vec![ + res.soft_constrs_reifs.get_reif_lit(2, 0), + res.soft_constrs_reifs.get_reif_lit(3, 0), + ]), + ]); + + let expected_mcses = BTreeSet::from_iter(vec![ + BTreeSet::from_iter(vec![res.soft_constrs_reifs.get_reif_lit(2, 0)]), + BTreeSet::from_iter(vec![ + res.soft_constrs_reifs.get_reif_lit(0, 0), + res.soft_constrs_reifs.get_reif_lit(3, 0), + ]), + BTreeSet::from_iter(vec![ + res.soft_constrs_reifs.get_reif_lit(1, 0), + res.soft_constrs_reifs.get_reif_lit(3, 0), + ]), + ]); + + assert_eq!(computed_muses, expected_muses); + assert_eq!(computed_mcses, expected_mcses); + } +}