diff --git a/.github/workflows/scallopy.yml b/.github/workflows/scallopy.yml
index 6be6a05..ed0e1bb 100644
--- a/.github/workflows/scallopy.yml
+++ b/.github/workflows/scallopy.yml
@@ -23,8 +23,6 @@ jobs:
       max-parallel: 5
       matrix:
         python-version:
-          - "3.8"
-          - "3.9"
          - "3.10"
 
     steps:
diff --git a/changelog.md b/changelog.md
index 30c2eaf..1f096b0 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,3 +1,14 @@
+# v0.2.4, Aug 30, 2024
+
+- Rule tags can now be expressions with potential reference to local variables: `rel 1/n::head() = body(n)`
+- Allowing for sparse gradient computation inside Scallopy to minimize memory footprint
+- Allowing users to specify per-datapoint output mapping inside Scallopy
+- Adding destructor syntax so that ADTs can be used in a more idiomatic way
+- Unifying the behavior of integer overflow inside Scallop
+- Multiple bugs fixed
+
+# v0.2.3, Jun 23, 2024
+
 # v0.2.2, Oct 25, 2023
 
 - Adding `wmc_with_disjunctions` option for provenances that deal with boolean formulas for more accurate probability estimation
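The headline change in v0.2.4 is that rule tags are no longer limited to constants. A minimal sketch of the two forms side by side (the relations `votes` and `total` here are hypothetical; the constant form is unchanged from earlier releases):

```scl
// Constant tag (unchanged): a fact holding with probability 0.8
rel 0.8::edge(0, 1)

// Expression tag (new in v0.2.4): the tag 1/n is computed
// from a variable bound in the rule body
rel 1/n::winner(c) = votes(c) and total(n)
```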
diff --git a/core/Cargo.toml b/core/Cargo.toml
index 8db50af..6175f49 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "scallop-core"
-version = "0.2.2"
+version = "0.2.4"
 authors = ["Ziyang Li "]
 edition = "2018"
diff --git a/core/src/common/foreign_predicate.rs b/core/src/common/foreign_predicate.rs
index 40d1292..9a4169c 100644
--- a/core/src/common/foreign_predicate.rs
+++ b/core/src/common/foreign_predicate.rs
@@ -40,29 +40,29 @@ impl Binding {
   }
 }
 
-/// The identifier of a foreign predicate in a registry
-#[derive(Clone, Debug, Hash, PartialEq, Eq)]
-pub struct ForeignPredicateIdentifier {
-  identifier: String,
-  types: Box<[ValueType]>,
-  binding_pattern: BindingPattern,
-}
-
-impl std::fmt::Display for ForeignPredicateIdentifier {
-  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-    f.write_fmt(format_args!(
-      "pred {}[{}]({})",
-      self.identifier,
-      self.binding_pattern,
-      self
-        .types
-        .iter()
-        .map(|t| format!("{}", t))
-        .collect::<Vec<_>>()
-        .join(", ")
-    ))
-  }
-}
+// /// The identifier of a foreign predicate in a registry
+// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
+// pub struct ForeignPredicateIdentifier {
+//   identifier: String,
+//   types: Box<[ValueType]>,
+//   binding_pattern: BindingPattern,
+// }
+
+// impl std::fmt::Display for ForeignPredicateIdentifier {
+//   fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+//     f.write_fmt(format_args!(
+//       "pred {}[{}]({})",
+//       self.identifier,
+//       self.binding_pattern,
+//       self
+//         .types
+//         .iter()
+//         .map(|t| format!("{}", t))
+//         .collect::<Vec<_>>()
+//         .join(", ")
+//     ))
+//   }
+// }
 
 /// A binding pattern for a predicate, e.g. `bbf`
 #[derive(Clone, Debug, Hash, PartialEq, Eq)]
diff --git a/core/src/compiler/back/compile.rs b/core/src/compiler/back/compile.rs
index a3d8225..8f8b8f6 100644
--- a/core/src/compiler/back/compile.rs
+++ b/core/src/compiler/back/compile.rs
@@ -17,7 +17,7 @@ impl Program {
     // Perform rule level optimizations
     for rule in &mut self.rules {
       // First propagate equality
-      optimizations::propagate_equality(rule);
+      optimizations::propagate_equality(rule, &self.predicate_registry);
 
       // Enter the loop of constant folding/propagation
       loop {
diff --git a/core/src/compiler/back/optimizations/equality_propagation.rs b/core/src/compiler/back/optimizations/equality_propagation.rs
index 792a503..aab4eb4 100644
--- a/core/src/compiler/back/optimizations/equality_propagation.rs
+++ b/core/src/compiler/back/optimizations/equality_propagation.rs
@@ -1,8 +1,10 @@
 use std::collections::*;
 
+use crate::common::foreign_predicate::*;
+
 use super::super::*;
 
-pub fn propagate_equality(rule: &mut Rule) {
+pub fn propagate_equality(rule: &mut Rule, foreign_predicate_registry: &ForeignPredicateRegistry) {
   let mut substitutions = HashMap::<_, Variable>::new();
   let mut ignore_literals = HashSet::new();
   let mut cannot_substitute = HashSet::<Variable>::new();
@@ -18,7 +20,7 @@ pub fn propagate_equality(rule: &mut Rule) {
   }
 
   // Find all the bounded variables by atom and assign
-  let bounded = bounded_by_atom_and_assign(rule);
+  let bounded = bounded_by_atom_and_assign(rule, foreign_predicate_registry);
 
   // Collect all substitutions
   for (i, literal) in rule.body_literals().enumerate() {
@@ -136,14 +138,26 @@ pub fn propagate_equality(rule: &mut Rule) {
     attributes: rule.attributes.clone(),
     head: new_head,
     body: Conjunction { args: new_literals },
-  }
+  };
 }
 
-fn bounded_by_atom_and_assign(rule: &Rule) -> HashSet<Variable> {
+fn bounded_by_atom_and_assign(rule: &Rule, foreign_predicate_registry: &ForeignPredicateRegistry) -> HashSet<Variable> {
   let mut bounded = rule
     .body_literals()
    .flat_map(|l| match l {
-      Literal::Atom(a) => a.variable_args().cloned().collect::<Vec<_>>(),
+      Literal::Atom(atom) => {
+        if let Some(fp) = foreign_predicate_registry.get(&atom.predicate) {
+          // If the atom is on a foreign predicate, only the variables that are free will be bounded
+          atom.args[fp.num_bounded()..fp.arity()]
+            .iter()
+            .filter_map(|term| term.as_variable())
+            .cloned()
+            .collect::<Vec<_>>()
+        } else {
+          // If the atom is on a normal relation, all the variables will be bounded
+          atom.variable_args().cloned().collect::<Vec<_>>()
+        }
+      }
       _ => vec![],
     })
    .collect::<HashSet<_>>();
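Why the registry lookup matters here: for a foreign predicate, the leading `num_bounded()` arguments are inputs, so an atom over it only grounds its trailing free arguments. A sketch using Scallop's built-in `range` foreign predicate (binding pattern `bbf`) together with a hypothetical relation `size`:

```scl
// range<i32>(begin, end, x) requires begin and end to be bounded and produces x.
// Only x counts as "bounded by" this atom; n must be grounded by size(n),
// so equality propagation must not treat n as if the range atom grounded it.
rel indices(x) = size(n) and range<i32>(0, n, x)
```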
diff --git a/core/src/compiler/front/analysis.rs b/core/src/compiler/front/analysis.rs
index 81dd2f4..57cf1da 100644
--- a/core/src/compiler/front/analysis.rs
+++ b/core/src/compiler/front/analysis.rs
@@ -18,6 +18,7 @@ pub struct Analysis {
   pub constant_decl_analysis: ConstantDeclAnalysis,
   pub adt_analysis: AlgebraicDataTypeAnalysis,
   pub head_relation_analysis: HeadRelationAnalysis,
+  pub tagged_rule_analysis: TaggedRuleAnalysis,
   pub type_inference: TypeInference,
   pub boundness_analysis: BoundnessAnalysis,
   pub demand_attr_analysis: DemandAttributeAnalysis,
@@ -41,6 +42,7 @@ impl Analysis {
       constant_decl_analysis: ConstantDeclAnalysis::new(),
       adt_analysis: AlgebraicDataTypeAnalysis::new(),
       head_relation_analysis: HeadRelationAnalysis::new(predicate_registry),
+      tagged_rule_analysis: TaggedRuleAnalysis::new(),
       type_inference: TypeInference::new(function_registry, predicate_registry, aggregate_registry),
       boundness_analysis: BoundnessAnalysis::new(predicate_registry),
       demand_attr_analysis: DemandAttributeAnalysis::new(),
@@ -78,12 +80,15 @@ impl Analysis {
     items.walk(&mut analyzers);
   }
 
-  pub fn post_analysis(&mut self) {
+  pub fn post_analysis(&mut self, foreign_predicate_registry: &mut ForeignPredicateRegistry) {
     self.head_relation_analysis.compute_errors();
     self.type_inference.check_query_predicates();
     self.type_inference.infer_types();
     self.demand_attr_analysis.check_arity(&self.type_inference);
     self.boundness_analysis.check_boundness(&self.demand_attr_analysis);
+    self
+      .tagged_rule_analysis
+      .register_predicates(&self.type_inference, foreign_predicate_registry);
   }
 
   pub fn dump_errors(&mut self, error_ctx: &mut FrontCompileError) {
@@ -98,5 +103,6 @@ impl Analysis {
     error_ctx.extend(&mut self.type_inference.errors);
     error_ctx.extend(&mut self.boundness_analysis.errors);
     error_ctx.extend(&mut self.demand_attr_analysis.errors);
+    error_ctx.extend(&mut self.tagged_rule_analysis.errors);
   }
 }
diff --git a/core/src/compiler/front/analyzers/constant_decl.rs b/core/src/compiler/front/analyzers/constant_decl.rs
index 2031294..3ae4fe3 100644
--- a/core/src/compiler/front/analyzers/constant_decl.rs
+++ b/core/src/compiler/front/analyzers/constant_decl.rs
@@ -216,7 +216,7 @@ impl NodeVisitor<FactDecl> for ConstantDeclAnalysis {
     for v in vars {
       if self.variables.contains_key(v.variable_name()) {
         self.variable_use.insert(v.location().clone(), v.name().to_string());
-      } else {
+      } else if !fact_decl.atom().iter_args().any(|arg| arg.is_destruct()) {
         self.errors.push(ConstantDeclError::UnknownConstantVariable {
           name: v.name().to_string(),
           loc: v.location().clone(),
diff --git a/core/src/compiler/front/analyzers/mod.rs b/core/src/compiler/front/analyzers/mod.rs
index d19caf4..17be824 100644
--- a/core/src/compiler/front/analyzers/mod.rs
+++ b/core/src/compiler/front/analyzers/mod.rs
@@ -10,6 +10,7 @@ pub mod input_files;
 pub mod invalid_constant;
 pub mod invalid_wildcard;
 pub mod output_files;
+pub mod tagged_rule;
 pub mod type_inference;
 
 pub use aggregation::AggregationAnalysis;
@@ -24,6 +25,7 @@ pub use input_files::InputFilesAnalysis;
 pub use invalid_constant::InvalidConstantAnalyzer;
 pub use invalid_wildcard::InvalidWildcardAnalyzer;
 pub use output_files::OutputFilesAnalysis;
+pub use tagged_rule::TaggedRuleAnalysis;
 pub use type_inference::TypeInference;
 
 pub mod errors {
diff --git a/core/src/compiler/front/analyzers/tagged_rule.rs b/core/src/compiler/front/analyzers/tagged_rule.rs
new file mode 100644
index 0000000..c090a23
--- /dev/null
+++ b/core/src/compiler/front/analyzers/tagged_rule.rs
@@ -0,0 +1,185 @@
+use lazy_static::lazy_static;
+use std::collections::*;
+
+use crate::common::expr;
+use crate::common::foreign_predicate::*;
+use crate::common::input_tag::*;
+use crate::common::tuple::*;
+use crate::common::unary_op;
+use crate::common::value::*;
+use crate::common::value_type::*;
+
+use crate::compiler::front::*;
+use crate::runtime::env::*;
+
+lazy_static! {
+  pub static ref TAG_TYPE: Vec<ValueType> = {
+    use ValueType::*;
+    vec![F64, F32, Bool]
+  };
+}
+
+#[derive(Clone, Debug)]
+pub struct TaggedRuleAnalysis {
+  pub to_add_tag_predicates: HashMap<ast::NodeLocation, ToAddTagPredicate>,
+  pub errors: Vec<FrontCompileErrorMessage>,
+}
+
+impl TaggedRuleAnalysis {
+  pub fn new() -> Self {
+    Self {
+      to_add_tag_predicates: HashMap::new(),
+      errors: Vec::new(),
+    }
+  }
+
+  pub fn add_tag_predicate(
+    &mut self,
+    rule_id: ast::NodeLocation,
+    name: String,
+    arg_name: String,
+    tag_loc: ast::NodeLocation,
+  ) {
+    let pred = ToAddTagPredicate::new(name, arg_name, tag_loc);
+    self.to_add_tag_predicates.insert(rule_id, pred);
+  }
+
+  pub fn register_predicates(
+    &mut self,
+    type_inference: &super::TypeInference,
+    foreign_predicate_registry: &mut ForeignPredicateRegistry,
+  ) {
+    for (rule_id, tag_predicate) in self.to_add_tag_predicates.drain() {
+      if let Some(rule_variable_type) = type_inference.rule_variable_type.get(&rule_id) {
+        if let Some(var_ty) = rule_variable_type.get(&tag_predicate.arg_name) {
+          match get_target_tag_type(var_ty, &tag_predicate.tag_loc) {
+            Ok(target_tag_ty) => {
+              // This means that we have an okay tag that is type checked
+              // Create a foreign predicate and register it
+              let fp = TagPredicate::new(tag_predicate.name.clone(), target_tag_ty);
+              if let Err(err) = foreign_predicate_registry.register(fp) {
+                self.errors.push(FrontCompileErrorMessage::error().msg(err.to_string()));
+              }
+            }
+            Err(err) => {
+              self.errors.push(err);
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+fn get_target_tag_type(
+  var_ty: &analyzers::type_inference::TypeSet,
+  loc: &ast::NodeLocation,
+) -> Result<ValueType, FrontCompileErrorMessage> {
+  // Top priority: if var_ty is a base type, directly check if it is among the expected tag types
+  if let Some(base_ty) = var_ty.get_base_type() {
+    if TAG_TYPE.contains(&base_ty) {
+      return Ok(base_ty);
+    }
+  }
+
+  // Then we check if the value can be cast into certain types
+  for tag_ty in TAG_TYPE.iter() {
+    if var_ty.can_type_cast(tag_ty) {
+      return Ok(var_ty.to_default_value_type());
+    }
+  }
+
+  // If not, then report an error
+  return Err(
+    FrontCompileErrorMessage::error()
+      .msg(format!(
+        "A value of type `{var_ty}` cannot be casted into a dynamic tag"
+      ))
+      .src(loc.clone()),
+  );
+}
+
+/// The information of a helper tag predicate
+///
+/// Suppose we have a rule
+/// ``` ignore
+/// rel 1/p :: head() = body(p)
+/// ```
+///
+/// This rule will be transformed into
+/// ``` ignore
+/// rel head() = body(p) and tag#head#1#var == 1 / p and tag#head#1(tag#head#1#var)
+/// ```
+#[derive(Clone, Debug)]
+pub struct ToAddTagPredicate {
+  /// The name of the predicate
+  pub name: String,
+
+  /// The name of the variable holding the main tag expression
+  pub arg_name: String,
+
+  /// Tag location
+  pub tag_loc: ast::NodeLocation,
+}
+
+impl ToAddTagPredicate {
+  pub fn new(name: String, arg_name: String, tag_loc: ast::NodeLocation) -> Self {
+    Self {
+      name,
+      arg_name,
+      tag_loc,
+    }
+  }
+}
+
+/// An actual tag predicate
+#[derive(Clone, Debug)]
+pub struct TagPredicate {
+  /// The name of the predicate
+  pub name: String,
+
+  /// The argument type
+  pub arg_ty: ValueType,
+}
+
+impl TagPredicate {
+  pub fn new(name: String, arg_ty: ValueType) -> Self {
+    Self { name, arg_ty }
+  }
+}
+
+impl ForeignPredicate for TagPredicate {
+  fn name(&self) -> String {
+    self.name.clone()
+  }
+
+  fn arity(&self) -> usize {
+    1
+  }
+
+  fn argument_type(&self, i: usize) -> ValueType {
+    assert_eq!(i, 0);
+    self.arg_ty.clone()
+  }
+
+  fn num_bounded(&self) -> usize {
+    1
+  }
+
+  fn evaluate_with_env(&self, env: &RuntimeEnvironment, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec<Value>)> {
+    // Result tuple
+    let tup = vec![];
+
+    // Create a type cast expression and evaluate it on the given values
+    let tuple = Tuple::from_values(bounded.iter().cloned());
+    let cast_expr = expr::Expr::unary(unary_op::UnaryOp::TypeCast(ValueType::F64), expr::Expr::access(0));
+    let maybe_computed_tag = env.eval(&cast_expr, &tuple);
+
+    // Return the value
+    if let Some(Tuple::Value(Value::F64(f))) = maybe_computed_tag {
+      vec![(DynamicInputTag::Float(f), tup)]
+    } else {
+      vec![]
+    }
+  }
+}
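To make the transformation in the doc comment concrete, a sketch of how a tagged rule evaluates end to end (the relation names are hypothetical, and the operands are assumed float-typed so the division is floating point). Given `body(4.0)`, the rewritten rule grounds `tag#head#1#var` to `1.0 / 4.0`, and the generated `tag#head#1` foreign predicate casts that value to `f64` and emits it as the dynamic tag of the derived fact:

```scl
rel body(4.0)
rel 1.0/p :: head() = body(p)
// head() is derived with dynamic tag 1.0 / 4.0 = 0.25,
// as if the program had declared: rel 0.25::head()
```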
diff --git a/core/src/compiler/front/analyzers/type_inference/local.rs b/core/src/compiler/front/analyzers/type_inference/local.rs
index 2d803c0..eae0604 100644
--- a/core/src/compiler/front/analyzers/type_inference/local.rs
+++ b/core/src/compiler/front/analyzers/type_inference/local.rs
@@ -50,11 +50,14 @@ impl LocalTypeInferenceContext {
   pub fn unify_atom_arities(
     &self,
     predicate_registry: &PredicateTypeRegistry,
+    // ignore_relations: &HashSet<String>,
     inferred_relation_types: &mut HashMap<String, (Vec<TypeSet>, Loc)>,
   ) -> Result<(), Error> {
     for (pred, arities) in &self.atom_arities {
       // Skip foreign predicates
-      if predicate_registry.contains_predicate(pred) {
+      if predicate_registry.contains_predicate(pred)
+      /* || ignore_relations.contains(pred) */
+      {
         continue;
       }
 
@@ -122,6 +125,7 @@ impl LocalTypeInferenceContext {
     predicate_type_registry: &PredicateTypeRegistry,
     aggregate_type_registry: &AggregateTypeRegistry,
     inferred_expr_types: &mut HashMap<Loc, TypeSet>,
+    strict: bool,
   ) -> Result<(), Error> {
     for unif in &self.unifications {
       unif.unify(
@@ -132,6 +136,7 @@ impl LocalTypeInferenceContext {
         predicate_type_registry,
         aggregate_type_registry,
         inferred_expr_types,
+        strict,
       )?;
     }
     Ok(())
@@ -407,7 +412,18 @@ impl NodeVisitor<NewExpr> for LocalTypeInferenceContext {
 
 impl NodeVisitor<NewExpr> for LocalTypeInferenceContext {
   fn visit(&mut self, n: &NewExpr) {
-    let unif = Unification::New(
+    let unif = Unification::Entity(
       n.functor_name().to_string(),
       n.iter_args().map(|a| a.location().clone()).collect(),
       n.location().clone(),
     );
     self.unifications.push(unif)
   }
 }
+
+impl NodeVisitor<DestructExpr> for LocalTypeInferenceContext {
+  fn visit(&mut self, n: &DestructExpr) {
+    let unif = Unification::Entity(
+      n.functor_name().to_string(),
+      n.iter_args().map(|a| a.location().clone()).collect(),
+      n.location().clone(),
diff --git a/core/src/compiler/front/analyzers/type_inference/operator_rules.rs b/core/src/compiler/front/analyzers/type_inference/operator_rules.rs
index 170b137..39ba4ed 100644
--- a/core/src/compiler/front/analyzers/type_inference/operator_rules.rs
+++ b/core/src/compiler/front/analyzers/type_inference/operator_rules.rs
@@ -6,19 +6,19 @@ lazy_static! {
   pub static ref ADD_TYPING_RULES: Vec<(ValueType, ValueType, ValueType)> = {
     use ValueType::*;
     vec![
+      (I32, I32, I32), // Prioritize
       (I8, I8, I8),
       (I16, I16, I16),
-      (I32, I32, I32),
       (I64, I64, I64),
       (I128, I128, I128),
       (ISize, ISize, ISize),
+      (U32, U32, U32), // Prioritize
       (U8, U8, U8),
       (U16, U16, U16),
-      (U32, U32, U32),
       (U64, U64, U64),
       (U128, U128, U128),
       (USize, USize, USize),
-      (F32, F32, F32),
+      (F32, F32, F32), // Prioritize
       (F64, F64, F64),
       (String, String, String),
       (DateTime, Duration, DateTime),
@@ -29,22 +29,23 @@ lazy_static! {
     (F64, Tensor, Tensor),
    ]
  };
+
  pub static ref SUB_TYPING_RULES: Vec<(ValueType, ValueType, ValueType)> = {
    use ValueType::*;
    vec![
+      (I32, I32, I32), // Prioritize
      (I8, I8, I8),
      (I16, I16, I16),
-      (I32, I32, I32),
      (I64, I64, I64),
      (I128, I128, I128),
      (ISize, ISize, ISize),
+      (U32, U32, U32), // Prioritize
      (U8, U8, U8),
      (U16, U16, U16),
-      (U32, U32, U32),
      (U64, U64, U64),
      (U128, U128, U128),
      (USize, USize, USize),
-      (F32, F32, F32),
+      (F32, F32, F32), // Prioritize
      (F64, F64, F64),
      (DateTime, Duration, DateTime),
      (DateTime, DateTime, Duration),
@@ -54,22 +55,23 @@ lazy_static! {
      (F64, Tensor, Tensor),
    ]
  };
+
  pub static ref MULT_TYPING_RULES: Vec<(ValueType, ValueType, ValueType)> = {
    use ValueType::*;
    vec![
+      (I32, I32, I32), // Prioritize
      (I8, I8, I8),
      (I16, I16, I16),
-      (I32, I32, I32),
      (I64, I64, I64),
      (I128, I128, I128),
      (ISize, ISize, ISize),
+      (U32, U32, U32), // Prioritize
      (U8, U8, U8),
      (U16, U16, U16),
-      (U32, U32, U32),
      (U64, U64, U64),
      (U128, U128, U128),
      (USize, USize, USize),
-      (F32, F32, F32),
+      (F32, F32, F32), // Prioritize
      (F64, F64, F64),
      (Duration, I32, Duration),
      (I32, Duration, Duration),
@@ -78,61 +80,64 @@ lazy_static! {
      (F64, Tensor, Tensor),
    ]
  };
+
  pub static ref DIV_TYPING_RULES: Vec<(ValueType, ValueType, ValueType)> = {
    use ValueType::*;
    vec![
+      (I32, I32, I32), // Prioritize
      (I8, I8, I8),
      (I16, I16, I16),
-      (I32, I32, I32),
      (I64, I64, I64),
      (I128, I128, I128),
      (ISize, ISize, ISize),
+      (U32, U32, U32), // Prioritize
      (U8, U8, U8),
      (U16, U16, U16),
-      (U32, U32, U32),
      (U64, U64, U64),
      (U128, U128, U128),
      (USize, USize, USize),
-      (F32, F32, F32),
+      (F32, F32, F32), // Prioritize
      (F64, F64, F64),
      (Duration, I32, Duration),
    ]
  };
+
  pub static ref MOD_TYPING_RULES: Vec<(ValueType, ValueType, ValueType)> = {
    use ValueType::*;
    vec![
+      (I32, I32, I32), // Prioritize
      (I8, I8, I8),
      (I16, I16, I16),
-      (I32, I32, I32),
      (I64, I64, I64),
      (I128, I128, I128),
      (ISize, ISize, ISize),
+      (U32, U32, U32), // Prioritize
      (U8, U8, U8),
      (U16, U16, U16),
-      (U32, U32, U32),
      (U64, U64, U64),
      (U128, U128, U128),
      (USize, USize, USize),
-      (F32, F32, F32),
+      (F32, F32, F32), // Prioritize
      (F64, F64, F64),
    ]
  };
+
  pub static ref COMPARE_TYPING_RULES: Vec<(ValueType, ValueType)> = {
    use ValueType::*;
    vec![
+      (I32, I32), // Prioritize
      (I8, I8),
      (I16, I16),
-      (I32, I32),
      (I64, I64),
      (I128, I128),
      (ISize, ISize),
+      (U32, U32), // Prioritize
      (U8, U8),
      (U16, U16),
-      (U32, U32),
      (U64, U64),
      (U128, U128),
      (USize, USize),
-      (F32, F32),
+      (F32, F32), // Prioritize
      (F64, F64),
      (Duration, Duration),
      (DateTime, DateTime),
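The reordering above is not cosmetic: together with the strict unification pass added below in type_inference.rs, the first matching entry in each table now serves as the default when several typings remain possible. A sketch of the user-visible effect, with hypothetical relations:

```scl
rel s(3)
rel t(4)
// `a + b` matches many ADD_TYPING_RULES entries; with no other
// constraints, strict unification defaults all three types to i32,
// the first (prioritized) rule in the table.
rel r(a + b) = s(a) and t(b)
```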
diff --git a/core/src/compiler/front/analyzers/type_inference/type_inference.rs b/core/src/compiler/front/analyzers/type_inference/type_inference.rs
index f7d1489..7cae78d 100644
--- a/core/src/compiler/front/analyzers/type_inference/type_inference.rs
+++ b/core/src/compiler/front/analyzers/type_inference/type_inference.rs
@@ -304,6 +304,7 @@ impl TypeInference {
         &self.foreign_predicate_type_registry,
         &self.foreign_aggregate_type_registry,
         &mut inferred_expr_types,
+        false,
       )?;
       ctx.propagate_variable_types(&mut inferred_var_expr, &mut inferred_expr_types)?;
       ctx.propagate_relation_types(
@@ -317,6 +318,25 @@ impl TypeInference {
 
     // Final step, iterate through each rule and their local context
     // and make sure everything is fine
     for ctx in &self.rule_local_contexts {
+      // Try to unify one last time; this time using the "strict" option
+      ctx.unify_expr_types(
+        &self.custom_types,
+        &self.constant_types,
+        &self.inferred_relation_types,
+        &self.foreign_function_type_registry,
+        &self.foreign_predicate_type_registry,
+        &self.foreign_aggregate_type_registry,
+        &mut inferred_expr_types,
+        true,
+      )?;
+      ctx.propagate_variable_types(&mut inferred_var_expr, &mut inferred_expr_types)?;
+      ctx.propagate_relation_types(
+        &inferred_relation_expr,
+        &inferred_expr_types,
+        &mut self.inferred_relation_types,
+      )?;
+
+      // Perform type cast and constraint checks
       ctx.check_type_cast(&self.custom_types, &inferred_expr_types)?;
       ctx.check_constraint(&inferred_expr_types)?;
diff --git a/core/src/compiler/front/analyzers/type_inference/type_set.rs b/core/src/compiler/front/analyzers/type_inference/type_set.rs
index cc91800..f8e6152 100644
--- a/core/src/compiler/front/analyzers/type_inference/type_set.rs
+++ b/core/src/compiler/front/analyzers/type_inference/type_set.rs
@@ -217,6 +217,20 @@ impl TypeSet {
     }
   }
 
+  pub fn is_base_type(&self) -> bool {
+    match self {
+      Self::BaseType(_, _) => true,
+      _ => false,
+    }
+  }
+
+  pub fn get_base_type(&self) -> Option<ValueType> {
+    match self {
+      Self::BaseType(vt, _) => Some(vt.clone()),
+      _ => None,
+    }
+  }
+
   pub fn is_boolean(&self) -> bool {
     match self {
       Self::BaseType(b, _) => b.is_boolean(),
diff --git a/core/src/compiler/front/analyzers/type_inference/unification.rs b/core/src/compiler/front/analyzers/type_inference/unification.rs
index 4ef45e8..9030e05 100644
--- a/core/src/compiler/front/analyzers/type_inference/unification.rs
+++ b/core/src/compiler/front/analyzers/type_inference/unification.rs
@@ -73,7 +73,9 @@ pub enum Unification {
   },
 
   /// C, ops*, new C(ops*)
-  New(String, Vec<Loc>, Loc),
+  /// or
+  /// C, ops*, C(ops*)
+  Entity(String, Vec<Loc>, Loc),
 }
 
 impl Unification {
@@ -87,6 +89,7 @@ impl Unification {
     predicate_type_registry: &PredicateTypeRegistry,
     aggregate_type_registry: &AggregateTypeRegistry,
     inferred_expr_types: &mut HashMap<Loc, TypeSet>,
+    strict: bool,
   ) -> Result<(), Error> {
     match self {
       Self::IthArgOfRelation(e, p, i) => {
@@ -159,16 +162,16 @@ impl Unification {
         }
       }
       Self::Add(op1, op2, e) => {
-        unify_polymorphic_binary_expression(op1, op2, e, inferred_expr_types, &ADD_TYPING_RULES)
+        unify_polymorphic_binary_expression(op1, op2, e, inferred_expr_types, &ADD_TYPING_RULES, strict)
       }
       Self::Sub(op1, op2, e) => {
-        unify_polymorphic_binary_expression(op1, op2, e, inferred_expr_types, &SUB_TYPING_RULES)
+        unify_polymorphic_binary_expression(op1, op2, e, inferred_expr_types, &SUB_TYPING_RULES, strict)
       }
       Self::Mult(op1, op2, e) => {
-        unify_polymorphic_binary_expression(op1, op2, e, inferred_expr_types, &MULT_TYPING_RULES)
+        unify_polymorphic_binary_expression(op1, op2, e, inferred_expr_types, &MULT_TYPING_RULES, strict)
       }
       Self::Div(op1, op2, e) => {
-        unify_polymorphic_binary_expression(op1, op2, e, inferred_expr_types, &DIV_TYPING_RULES)
+        unify_polymorphic_binary_expression(op1, op2, e, inferred_expr_types, &DIV_TYPING_RULES, strict)
       }
       Self::Mod(op1, op2, e) => {
         let e_ty = inferred_expr_types
@@ -225,7 +228,7 @@ impl Unification {
         Ok(())
       }
       Self::LtLeqGtGeq(op1, op2, e) => {
-        unify_comparison_expression(op1, op2, e, inferred_expr_types, &COMPARE_TYPING_RULES)
+        unify_comparison_expression(op1, op2, e, inferred_expr_types, &COMPARE_TYPING_RULES, strict)
       }
       Self::PosNeg(op1, e) => {
         let e_ty = inferred_expr_types
@@ -537,7 +540,7 @@ impl Unification {
           )
         }
       }
-      Self::New(functor, args, e) => {
+      Self::Entity(functor, args, e) => {
         let adt_variant_relation_name = format!("adt#{functor}");
 
         // cond should be boolean
@@ -578,7 +581,7 @@ impl Unification {
 enum AppliedRules<T> {
   None,
   One(T),
-  Multiple,
+  Multiple(T, Vec<T>),
 }
 
 impl<T> AppliedRules<T> {
@@ -589,8 +592,8 @@ impl<T> AppliedRules<T> {
   fn add(self, rule: T) -> Self {
     match self {
       Self::None => Self::One(rule),
-      Self::One(_) => Self::Multiple,
-      Self::Multiple => Self::Multiple,
+      Self::One(f) => Self::Multiple(f, vec![]),
+      Self::Multiple(f, r) => Self::Multiple(f, r.into_iter().chain(std::iter::once(rule)).collect()),
     }
   }
 }
@@ -605,6 +608,7 @@ fn unify_polymorphic_binary_expression(
   e: &Loc,
   inferred_expr_types: &mut HashMap<Loc, TypeSet>,
   rules: &[(ValueType, ValueType, ValueType)],
+  strict: bool,
 ) -> Result<(), Error> {
   // First get the already inferred types of op1, op2, and e
   let op1_ty = unify_any(op1, inferred_expr_types).map_err(|e| e.into())?;
@@ -632,10 +636,18 @@ fn unify_polymorphic_binary_expression(
     unify_ty(e, TypeSet::BaseType(te, e.clone()), inferred_expr_types).map_err(|e| e.into())?;
     Ok(())
   }
-  AppliedRules::Multiple => {
-    // If ther are multiple rules that can be applied, we are not sure about the exact types,
-    // but the type inference is still successful
-    Ok(())
+  AppliedRules::Multiple((t1, t2, te), _) => {
+    if strict {
+      // Under strict mode, directly unify with the first matched rule
+      unify_ty(op1, TypeSet::BaseType(t1, e.clone()), inferred_expr_types).map_err(|e| e.into())?;
+      unify_ty(op2, TypeSet::BaseType(t2, e.clone()), inferred_expr_types).map_err(|e| e.into())?;
+      unify_ty(e, TypeSet::BaseType(te, e.clone()), inferred_expr_types).map_err(|e| e.into())?;
+      Ok(())
+    } else {
+      // If there are multiple rules that can be applied, we are not sure about the exact types,
+      // but the type inference is still successful
+      Ok(())
+    }
   }
 }
}
@@ -646,6 +658,7 @@ fn unify_comparison_expression(
   e: &Loc,
   inferred_expr_types: &mut HashMap<Loc, TypeSet>,
   rules: &[(ValueType, ValueType)],
+  strict: bool,
 ) -> Result<(), Error> {
   // The result should be a boolean
   let e_ty = unify_boolean(e, inferred_expr_types).map_err(|e| e.into())?;
@@ -674,10 +687,17 @@ fn unify_comparison_expression(
     unify_ty(op2, TypeSet::BaseType(t2, e.clone()), inferred_expr_types).map_err(|e| e.into())?;
     Ok(())
   }
-  AppliedRules::Multiple => {
-    // If ther are multiple rules that can be applied, we are not sure about the exact types,
-    // but the type inference is still successful
-    Ok(())
+  AppliedRules::Multiple((t1, t2), _) => {
+    if strict {
+      // Under strict mode, directly unify with the first matched rule
+      unify_ty(op1, TypeSet::BaseType(t1, e.clone()), inferred_expr_types).map_err(|e| e.into())?;
+      unify_ty(op2, TypeSet::BaseType(t2, e.clone()), inferred_expr_types).map_err(|e| e.into())?;
+      Ok(())
+    } else {
+      // If there are multiple rules that can be applied, we are not sure about the exact types,
+      // but the type inference is still successful
+      Ok(())
+    }
   }
 }
}
diff --git a/core/src/compiler/front/ast/constant.rs b/core/src/compiler/front/ast/constant.rs
index 138d2f6..3e0dc4d 100644
--- a/core/src/compiler/front/ast/constant.rs
+++ b/core/src/compiler/front/ast/constant.rs
@@ -1,26 +1,8 @@
-use crate::common::input_tag::DynamicInputTag;
 use crate::common::value::Value;
 use crate::common::value_type::ValueType;
 
 use super::*;
 
-/// A tag associated with a fact
-#[derive(Clone, Debug, PartialEq, Serialize, AstNode)]
-#[doc(hidden)]
-pub struct _Tag {
-  pub tag: DynamicInputTag,
-}
-
-impl Tag {
-  pub fn none() -> Self {
-    Self::new(DynamicInputTag::None)
-  }
-
-  pub fn is_some(&self) -> bool {
-    self.tag().is_some()
-  }
-}
-
 #[derive(Clone, Debug, PartialEq, Hash, Serialize, AstNode)]
 #[doc(hidden)]
 pub struct _IntLiteral {
diff --git a/core/src/compiler/front/ast/expr.rs b/core/src/compiler/front/ast/expr.rs
index dfec6b2..41d8e4f 100644
--- a/core/src/compiler/front/ast/expr.rs
+++ b/core/src/compiler/front/ast/expr.rs
@@ -1,3 +1,5 @@
+use crate::common;
+
 use super::*;
 
 #[derive(Clone, Debug, PartialEq, Serialize, AstNode)]
@@ -10,6 +12,7 @@ pub enum Expr {
   IfThenElse(IfThenElseExpr),
   Call(CallExpr),
   New(NewExpr),
+  Destruct(DestructExpr),
 }
 
 impl Expr {
@@ -31,6 +34,7 @@ impl Expr {
       Self::IfThenElse(i) => i.cond().has_variable() || i.then_br().has_variable() || i.else_br().has_variable(),
       Self::Call(c) => c.iter_args().any(|a| a.has_variable()),
       Self::New(n) => n.iter_args().any(|a| a.has_variable()),
+      Self::Destruct(n) => n.iter_args().any(|a| a.has_variable()),
     }
   }
 
@@ -69,6 +73,11 @@ impl Expr {
           a.collect_used_variables_helper(vars);
         }
       }
+      Self::Destruct(n) => {
+        for a in n.iter_args() {
+          a.collect_used_variables_helper(vars);
+        }
+      }
     }
   }
 
@@ -100,6 +109,14 @@ impl Expr {
         }
         None
       }
+      Expr::Destruct(n) => {
+        for arg in n.iter_args() {
+          if let Some(loc) = arg.get_first_variable_location() {
+            return Some(loc);
+          }
+        }
+        None
+      }
     }
   }
 
@@ -290,6 +307,17 @@ impl UnaryOp {
   pub fn is_pos_neg(&self) -> bool {
     self.is_pos() || self.is_neg()
   }
+
+  /// Cast the AST unary operator to a Common unary operator.
+  /// The `TypeCast` operation will be discarded.
+  pub fn to_common_unary_op(&self) -> Option<common::unary_op::UnaryOp> {
+    match self._node {
+      _UnaryOp::Neg => Some(common::unary_op::UnaryOp::Neg),
+      _UnaryOp::Pos => Some(common::unary_op::UnaryOp::Pos),
+      _UnaryOp::Not => Some(common::unary_op::UnaryOp::Not),
+      _UnaryOp::TypeCast(_) => None,
+    }
+  }
 }
 
 #[derive(Clone, Debug, PartialEq, Serialize, AstNode)]
@@ -339,3 +367,16 @@ impl NewExpr {
     self.functor().name()
   }
 }
+
+#[derive(Clone, Debug, PartialEq, Serialize, AstNode)]
+#[doc(hidden)]
+pub struct _DestructExpr {
+  pub functor: Identifier,
+  pub args: Vec<Expr>,
+}
+
+impl DestructExpr {
+  pub fn functor_name(&self) -> &str {
+    self.functor().name()
+  }
+}
diff --git a/core/src/compiler/front/ast/mod.rs b/core/src/compiler/front/ast/mod.rs
index 07d564a..2f10073 100644
--- a/core/src/compiler/front/ast/mod.rs
+++ b/core/src/compiler/front/ast/mod.rs
@@ -10,6 +10,7 @@ mod item;
 mod query;
 mod relation_decl;
 mod rule;
+mod tag;
 mod type_decl;
 mod types;
 mod utils;
@@ -26,6 +27,7 @@ pub use item::*;
 pub use query::*;
 pub use relation_decl::*;
 pub use rule::*;
+pub use tag::*;
 pub use type_decl::*;
 pub use types::*;
 pub use utils::*;
diff --git a/core/src/compiler/front/ast/relation_decl.rs b/core/src/compiler/front/ast/relation_decl.rs
index 260227f..e0b1d85 100644
--- a/core/src/compiler/front/ast/relation_decl.rs
+++ b/core/src/compiler/front/ast/relation_decl.rs
@@ -15,7 +15,7 @@ impl ConstantTuple {
 #[derive(Clone, Debug, PartialEq, Serialize, AstNode)]
 #[doc(hidden)]
 pub struct _ConstantSetTuple {
-  pub tag: Tag,
+  pub tag: Option<Tag>,
   pub tuple: ConstantTuple,
 }
 
@@ -58,7 +58,7 @@ impl ConstantSetDecl {
 #[doc(hidden)]
 pub struct _FactDecl {
   pub attrs: Attributes,
-  pub tag: Tag,
+  pub tag: Option<Tag>,
   pub atom: Atom,
 }
 
@@ -78,17 +78,13 @@ impl FactDecl {
   pub fn iter_constants(&self) -> impl Iterator<Item = &Constant> {
     self.iter_args().filter_map(|expr| expr.as_constant())
   }
-
-  pub fn has_tag(&self) -> bool {
-    self.tag().is_some()
-  }
 }
 
 #[derive(Clone, Debug, PartialEq, Serialize, AstNode)]
 #[doc(hidden)]
 pub struct _RuleDecl {
   pub attrs: Attributes,
-  pub tag: Tag,
+  pub tag: Option<Tag>,
   pub rule: Rule,
 }
diff --git a/core/src/compiler/front/ast/rule.rs b/core/src/compiler/front/ast/rule.rs
index a70f8fc..198146d 100644
--- a/core/src/compiler/front/ast/rule.rs
+++ b/core/src/compiler/front/ast/rule.rs
@@ -11,7 +11,7 @@ impl Into<Vec<Item>> for Rule {
   fn into(self) -> Vec<Item> {
     vec![Item::RelationDecl(RelationDecl::Rule(RuleDecl::new(
       Attributes::new(),
-      Tag::none(),
+      None,
       self,
     )))]
   }
diff --git a/core/src/compiler/front/ast/tag.rs b/core/src/compiler/front/ast/tag.rs
new file mode 100644
index 0000000..429317a
--- /dev/null
+++ b/core/src/compiler/front/ast/tag.rs
@@ -0,0 +1,157 @@
+use std::collections::*;
+
+use crate::common::expr as common_expr;
+use crate::common::input_tag::DynamicInputTag;
+
+use super::*;
+
+/// A tag associated with a fact
+#[derive(Clone, Debug, PartialEq, Serialize, AstNode)]
+#[doc(hidden)]
+pub enum Tag {
+  Constant(ConstantTag),
+  Expr(ExprTag),
+}
+
+impl Tag {
+  pub fn none() -> Self {
+    Self::Constant(ConstantTag::none())
+  }
+
+  pub fn is_some(&self) -> bool {
+    match self.as_constant() {
+      Some(c) => c.is_some(),
+      None => false,
+    }
+  }
+
+  pub fn used_variables(&self) -> BTreeSet<String> {
+    match self {
+      Self::Constant(_) => BTreeSet::new(),
+      Self::Expr(e) => e.used_variables(),
+    }
+  }
+
+  pub fn to_base_expr(&self, vars: &HashMap<String, usize>) -> Option<common_expr::Expr> {
+    match self {
+      Self::Constant(c) => c.to_base_expr(vars),
+      Self::Expr(e) => e.to_base_expr(vars),
+    }
+  }
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, AstNode)]
+#[doc(hidden)]
+pub struct _ConstantTag {
+  pub tag: DynamicInputTag,
+}
+
+impl ConstantTag {
+  pub fn none() -> Self {
+    Self::new(DynamicInputTag::None)
+  }
+
+  pub fn float(f: f64) -> Self {
+    Self::new(DynamicInputTag::Float(f))
+  }
+
+  pub fn boolean(b: bool) -> Self {
+    Self::new(DynamicInputTag::Bool(b))
+  }
+
+  pub fn is_some(&self) -> bool {
+    self.tag().is_some()
+  }
+
+  pub fn to_base_expr(&self, _vars: &HashMap<String, usize>) -> Option<common_expr::Expr> {
+    match self.tag() {
+      DynamicInputTag::Bool(b) => Some(common_expr::Expr::constant(*b)),
+      DynamicInputTag::Float(f) => Some(common_expr::Expr::constant(*f)),
+      _ => None,
+    }
+  }
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, AstNode)]
+#[doc(hidden)]
+pub enum ExprTag {
+  Variable(VariableTag),
+  Binary(BinaryExprTag),
+  Unary(UnaryExprTag),
+}
+
+impl ExprTag {
+  pub fn used_variables(&self) -> BTreeSet<String> {
+    match self {
+      Self::Variable(v) => v.used_variables(),
+      Self::Binary(b) => b.used_variables(),
+      Self::Unary(u) => u.used_variables(),
+    }
+  }
+
+  pub fn to_base_expr(&self, vars: &HashMap<String, usize>) -> Option<common_expr::Expr> {
+    match self {
+      Self::Variable(v) => v.to_base_expr(vars),
+      Self::Binary(b) => b.to_base_expr(vars),
+      Self::Unary(u) => u.to_base_expr(vars),
+    }
+  }
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, AstNode)]
+#[doc(hidden)]
+pub struct _VariableTag {
+  pub variable: Identifier,
+}
+
+impl VariableTag {
+  pub fn used_variables(&self) -> BTreeSet<String> {
+    std::iter::once(self.variable().name().clone()).collect()
+  }
+
+  pub fn to_base_expr(&self, vars: &HashMap<String, usize>) -> Option<common_expr::Expr> {
+    let id = vars.get(self.variable().name())?;
+    Some(common_expr::Expr::access(*id))
+  }
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, AstNode)]
+#[doc(hidden)]
+pub struct _BinaryExprTag {
+  pub op: BinaryOp,
+  pub op1: Box<ExprTag>,
+  pub op2: Box<ExprTag>,
+}
+
+impl BinaryExprTag {
+  pub fn used_variables(&self) -> BTreeSet<String> {
+    let left = self.op1().used_variables().into_iter();
+    let right = self.op2().used_variables().into_iter();
+    left.chain(right).collect()
+  }
+
+  pub fn to_base_expr(&self, vars: &HashMap<String, usize>) -> Option<common_expr::Expr> {
+    let op1 = self.op1().to_base_expr(vars)?;
+    let op2 = self.op2().to_base_expr(vars)?;
+    Some(common_expr::Expr::binary(self.op().op().clone(), op1, op2))
+  }
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, AstNode)]
+#[doc(hidden)]
+pub struct _UnaryExprTag {
+  pub op: UnaryOp,
+  pub op1: Box<ExprTag>,
+}
+
+impl UnaryExprTag {
+  pub fn used_variables(&self) -> BTreeSet<String> {
+    self.op1().used_variables()
+  }
+
+  pub fn to_base_expr(&self, vars: &HashMap<String, usize>) -> Option<common_expr::Expr> {
+    let op = self.op().to_common_unary_op()?;
+    let op1 = self.op1().to_base_expr(vars)?;
+    Some(common_expr::Expr::unary(op, op1))
+  }
+}
diff --git a/core/src/compiler/front/ast/utils.rs b/core/src/compiler/front/ast/utils.rs
index ffe519b..19a3f20 100644
--- a/core/src/compiler/front/ast/utils.rs
+++ b/core/src/compiler/front/ast/utils.rs
@@ -192,6 +192,7 @@ pub trait AstNode: Clone {
   /// Obtain a mutable location of the AstNode
   fn location_mut(&mut self) -> &mut NodeLocation;
 
+  /// Clone with a given location
   fn clone_with_loc(&self, loc: NodeLocation) -> Self;
 }
 
@@ -235,10 +236,11 @@ impl AstNode for AstNodeWrapper {
   }
 }
 
+#[allow(unused)]
 pub trait NodeVisitor<N> {
-  fn visit(&mut self, node: &N);
+  fn visit(&mut self, node: &N) {}
 
-  fn visit_mut(&mut self, node: &mut N);
+  fn visit_mut(&mut self, node: &mut N) {}
 }
 
 #[allow(unused)]
@@ -248,6 +250,30 @@ impl<V, U> NodeVisitor<V> for U {
   default fn visit_mut(&mut self, node: &mut V) {}
 }
 
+impl<V, T: NodeVisitor<V>> NodeVisitor<V> for Vec<T> {
+  fn visit(&mut self, node: &V) {
+    for elem in self {
+      elem.visit(node);
+    }
+  }
+
+  fn visit_mut(&mut self, node: &mut V) {
+    for elem in self {
+      elem.visit_mut(node);
+    }
+  }
+}
+
+impl<V, T: NodeVisitor<V>> NodeVisitor<V> for Box<T> {
+  fn visit(&mut self, node: &V) {
+    (&**self).visit(node)
+  }
+
+  fn visit_mut(&mut self, node: &mut V) {
+    (&mut **self).visit_mut(node)
+  }
+}
+
 macro_rules! impl_node_visitor_tuple {
   ( $($id:ident,)* ) => {
     impl<V, $($id: NodeVisitor<V>,)*> NodeVisitor<V> for ($(&mut $id,)*) {
diff --git a/core/src/compiler/front/compile.rs b/core/src/compiler/front/compile.rs
index 212f7ff..a3a92f5 100644
--- a/core/src/compiler/front/compile.rs
+++ b/core/src/compiler/front/compile.rs
@@ -297,10 +297,18 @@ impl FrontContext {
     });
     ast.walk_mut(&mut dup_ctx.node_id_annotator);
 
+    // Clone the foreign registries (TODO: add other registries)
+    let mut cloned_fp_registry = dup_ctx.foreign_predicate_registry.clone();
+
     // Front analysis
     dup_ctx.analysis.modify(|analysis| {
+      // First do the main traversal of analyzers
       analysis.process_items(&ast);
-      analysis.post_analysis();
+
+      // Do the post-traversal analysis (TODO: add other registries when they are required)
+      analysis.post_analysis(&mut cloned_fp_registry);
+
+      // Dump the errors to the error context
       analysis.dump_errors(&mut error_ctx);
     });
     if error_ctx.has_error() {
@@ -312,8 +320,13 @@ impl FrontContext {
       error_ctx.report_warnings();
     }
 
-    // Update self if nothing goes wrong
+    // Update the AST items
     dup_ctx.items.extend(ast);
+
+    // Update the registries
+    dup_ctx.foreign_predicate_registry = cloned_fp_registry;
+
+    // Update self if nothing goes wrong
     *self = dup_ctx;
 
     // Pull out the last ast items
diff --git a/core/src/compiler/front/f2b/f2b.rs b/core/src/compiler/front/f2b/f2b.rs
index e050335..7b964ff 100644
--- a/core/src/compiler/front/f2b/f2b.rs
+++ b/core/src/compiler/front/f2b/f2b.rs
@@ -1,5 +1,6 @@
 use std::collections::*;
 
+use crate::common::input_tag::DynamicInputTag;
 use crate::common::output_option::OutputOption;
 use crate::common::value_type::ValueType;
 use crate::compiler::back;
@@ -120,8 +121,9 @@ impl FrontContext {
           c.as_constant().unwrap().to_value(t)
         })
         .collect();
+      let tag = to_constant_dyn_input_tag(tuple.tag());
       back::Fact {
-        tag: tuple.tag().tag().clone(),
+        tag,
         predicate: pred.clone(),
         args,
       }
@@ -133,8 +135,9 @@ impl FrontContext {
       let pred = f.predicate_name();
       let tys = self.relation_arg_types(&pred).unwrap();
       let args = f.iter_constants().zip(tys.iter()).map(|(c, t)| c.to_value(t)).collect();
+      let tag = to_constant_dyn_input_tag(&f.tag().clone().and_then(|e| e.as_constant().cloned()));
       let back_fact = back::Fact {
-        tag: f.tag().tag().clone(),
+        tag: tag,
         predicate: pred.clone(),
         args,
       };
@@ -165,8 +168,9 @@ impl FrontContext {
           c.as_constant().unwrap().to_value(t)
         })
         .collect();
+      let tag = to_constant_dyn_input_tag(tuple.tag());
       back::Fact {
-        tag: tuple.tag().tag().clone(),
+        tag,
         predicate: pred.clone(),
         args,
       }
@@ -611,3 +615,16 @@ impl FrontContext {
     self.back_vars_with_types(var_names, var_tys)
   }
 }
+
+fn to_constant_dyn_input_tag(constant: &Option<front::Constant>) -> DynamicInputTag {
+  if let Some(constant) = constant {
+    match constant {
+      front::Constant::Boolean(b) => DynamicInputTag::Bool(*b.value()),
+      front::Constant::Float(f) => DynamicInputTag::Float(*f.float()),
+      front::Constant::Integer(i) => DynamicInputTag::Float(*i.int() as f64),
+      _ => DynamicInputTag::None,
+    }
+  } else {
+    DynamicInputTag::None
+  }
+}
diff --git a/core/src/compiler/front/grammar.lalrpop b/core/src/compiler/front/grammar.lalrpop
index 6a3250b..6ef7e17 100644
--- a/core/src/compiler/front/grammar.lalrpop
+++ b/core/src/compiler/front/grammar.lalrpop
@@ -1,7 +1,6 @@
 use std::str::FromStr;
 
 use super::ast::*;
-use crate::common::input_tag::DynamicInputTag;
 
 grammar;
 
@@ -324,6 +323,72 @@ _ConstDecl: _ConstDecl = {
   }
 }
 
+/// ======================= ///
+/// ========= Tag ========= ///
+/// ======================= ///
+
+Tag: Expr = ExprTagTop;
+
+ExprTagTop: Expr = {
+  AndOrExprTag,
+}
+
+_AndOrBinaryExprTag: _BinaryExpr = {
+   => {
+    _BinaryExpr::new(op, op1, op2)
+  }
+}
+
+AndOrExprTag: Expr = {
+  Spanned<_AndOrBinaryExprTag> => Expr::binary(<>),
+  AddSubExprTag,
+}
+
+_AddSubBinaryExprTag: _BinaryExpr = {
+   => {
+    _BinaryExpr::new(op, op1, op2)
+  }
+}
+
+AddSubExprTag: Expr = {
+  Spanned<_AddSubBinaryExprTag> => Expr::binary(<>),
+  MulDivModExprTag,
+}
+
+_MulDivModBinaryExprTag: _BinaryExpr = {
+   => {
+    _BinaryExpr::new(op, op1, op2)
+  }
+}
+
+MulDivModExprTag: Expr = {
+  Spanned<_MulDivModBinaryExprTag> => Expr::binary(<>),
+  UnaryExprTag,
+}
+
+_UnaryExprTag: _UnaryExpr = {
+   => _UnaryExpr::new(op, op1),
+  //  => _UnaryExpr::new(op, op1),
+}
+
+UnaryExprTag: Expr = {
+  Spanned<_UnaryExprTag> => Expr::unary(<>),
+  UnitExprTag,
+}
+
+ComplexExprTag: Expr = {
+  Spanned<_AndOrBinaryExprTag> => Expr::binary(<>),
+  Spanned<_AddSubBinaryExprTag> => Expr::binary(<>),
+  Spanned<_MulDivModBinaryExprTag> => Expr::binary(<>),
+  Spanned<_UnaryExprTag> => Expr::unary(<>),
+}
+
+UnitExprTag: Expr = {
+  Constant => Expr::constant(<>),
+  Identifier => Expr::variable(Variable::new(<>)),
+  "(" ")" => t,
+}
+
 /// ======================================== ///
 /// ========= Relation Declaration ========= ///
 /// ======================================== ///
@@ -342,12 +407,6 @@ RelationDecl: RelationDecl = {
   ReduceRuleDecl => RelationDecl::ReduceRule(<>),
 }
 
-Tag = Spanned<_Tag>;
-_Tag: _Tag = {
-  Float => _Tag::new(DynamicInputTag::Float(<>)),
-  Bool => _Tag::new(DynamicInputTag::Bool(<>)),
-}
-
 Constant: Constant = {
   IntLiteral => Constant::Integer(<>),
   FloatLiteral => Constant::Float(<>),
@@ -372,11 +431,11 @@ _ConstantTuple: _ConstantTuple = {
 
 ConstantSetTuple = Spanned<_ConstantSetTuple>;
 _ConstantSetTuple: _ConstantSetTuple = {
-   "::"  => {
-    _ConstantSetTuple::new(tag, tuple)
+   "::"  => {
+    _ConstantSetTuple::new(Some(tag), tuple)
   },
    => {
-    _ConstantSetTuple::new(Tag::none(), tuple)
+    _ConstantSetTuple::new(None, tuple)
   }
 }
 
@@ -395,8 +454,8 @@ _ConstantSetDecl: _ConstantSetDecl = {
 
 FactDecl = Spanned<_FactDecl>;
 _FactDecl: _FactDecl = {
-  RelationKeyword  "::"  => _FactDecl::new(attrs, tag, a),
-  RelationKeyword  => _FactDecl::new(attrs, Tag::none(), a),
+  RelationKeyword  "::"  => _FactDecl::new(attrs, Some(tag), a),
+  RelationKeyword  => _FactDecl::new(attrs, None, a),
 }
 
 Wildcard = Spanned<_Wildcard>;
@@ -407,10 +466,10 @@ _Variable: _Variable = Identifier => _Variable::new(<>);
 
 Atom = Spanned<_Atom>;
 _Atom: _Atom = {
-   "(" > ")" => {
+   "(" > ")" => {
     _Atom::new(predicate, vec![], args)
   },
-   "(" > ")" => {
+   "(" > ")" => {
     let (predicate, type_arg_ids) = n;
     let type_args = type_arg_ids.into_iter().map(Type::from).collect();
     _Atom::new(predicate, type_args, args)
@@ -615,6 +674,18 @@ _VariableBinding: _VariableBinding = {
   "("  ":"  ")" => _VariableBinding::new(name, Some(ty)),
 }
 
+PossiblyDestructExpr: Expr = {
+  Expr,
+   => Expr::Destruct(e),
+}
+
+DestructExpr = Spanned<_DestructExpr>;
+_DestructExpr: _DestructExpr = {
+   "(" > ")" => {
+    _DestructExpr::new(functor, args)
+  }
+}
+
 Expr: Expr = IfThenElseExpr;
 
 IfThenElseExpr: Expr = {
@@ -786,10 +857,10 @@ _Rule: _Rule = {
 
 RuleDecl = Spanned<_RuleDecl>;
 _RuleDecl: _RuleDecl = {
   RelationKeyword  "::"  => {
-    _RuleDecl::new(a, tag, r)
+    _RuleDecl::new(a, Some(tag), r)
   },
   RelationKeyword  => {
-    _RuleDecl::new(a, Tag::none(), r)
+    _RuleDecl::new(a, None, r)
   },
 }
diff --git a/core/src/compiler/front/pretty.rs b/core/src/compiler/front/pretty.rs
index c20f6f7..9e87cdc 100644
--- a/core/src/compiler/front/pretty.rs
+++ b/core/src/compiler/front/pretty.rs
@@ -305,6 +305,49 @@ impl Display for Identifier {
   }
 }
 
+impl Display for Tag {
+  fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+    match self {
+      Tag::Constant(c) => c.fmt(f),
+      Tag::Expr(e) => e.fmt(f),
+    }
+  }
+}
+
+impl Display for ConstantTag {
+  fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+    self.tag().fmt(f)
+  }
+}
+
+impl Display for ExprTag {
+  fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+    match self {
+      Self::Variable(v) => v.fmt(f),
+      Self::Binary(b) => b.fmt(f),
+      Self::Unary(u) => u.fmt(f),
+    }
+  }
+}
+
+impl Display for VariableTag {
+  fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+    self.variable().fmt(f)
+  }
+}
+
+impl Display for BinaryExprTag {
+  fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+    f.write_fmt(format_args!("({} {} {})", self.op1(), self.op(), self.op2()))
+  }
+}
+
+impl Display for UnaryExprTag {
+  fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+    f.write_fmt(format_args!("({}{})", self.op(), self.op1()))
+  }
+}
+
 impl Display for ConstantSetDecl {
   fn fmt(&self, f: &mut Formatter<'_>) -> Result {
     for attr in self.attrs() {
@@ -328,8 +371,14 @@ impl Display for FactDecl {
     for attr in self.attrs() {
       f.write_fmt(format_args!("{} ", attr))?;
     }
+    let tag = if let Some(tag) = self.tag() {
+      format!("{}::", tag)
+    } else {
+      "".to_string()
+    };
     f.write_fmt(format_args!(
-      "rel {}({})",
+      "rel {}{}({})",
+      tag,
       self.predicate_name(),
       self
         .iter_args()
@@ -345,7 +394,11 @@ impl Display for RuleDecl {
     for attr in self.attrs() {
       f.write_fmt(format_args!("{} ", attr))?;
     }
-    f.write_fmt(format_args!("rel {}", self.rule()))
+    if let Some(tag) = self.tag() {
+      f.write_fmt(format_args!("rel {}::{}", tag, self.rule()))
+    } else {
+      f.write_fmt(format_args!("rel {}", self.rule()))
+    }
   }
 }
 
@@ -382,8 +435,8 @@ impl std::fmt::Display for Atom {
 
 impl Display for ConstantSetTuple {
   fn fmt(&self, f: &mut Formatter<'_>) -> Result {
-    if self.tag().is_some() {
-      f.write_fmt(format_args!("{}::", self.tag().tag()))?;
+    if let Some(tag) = self.tag() {
+      f.write_fmt(format_args!("{}::", tag))?;
     }
     f.write_fmt(format_args!(
       "({})",
@@ -749,6 +802,11 @@ impl std::fmt::Display for Expr {
         n.functor(),
         n.iter_args().map(|a| format!("{}", a)).collect::<Vec<_>>().join(", ")
       )),
+      Self::Destruct(n) => f.write_fmt(format_args!(
+        "{}({})",
+        n.functor(),
+        n.iter_args().map(|a| format!("{}", a)).collect::<Vec<_>>().join(", ")
+      )),
     }
   }
 }
diff --git a/core/src/compiler/front/transform.rs b/core/src/compiler/front/transform.rs
index 602265f..a06af08 100644
--- a/core/src/compiler/front/transform.rs
+++ b/core/src/compiler/front/transform.rs
@@ -1,18 +1,63 @@
+// use std::collections::*;
+// use petgraph::graph::{Graph, NodeIndex};
+
+// use crate::common::foreign_function::*;
+// use crate::common::foreign_predicate::*;
+
 use super::transformations::*;
 use super::*;
 
+/// A transformation pass
+///
+/// The lifetime constraint `'a` is based on the existing analysis results
+pub trait Transformation<'a> {
+  /// The dependent transformation passes
+  fn dependencies(&self) -> Vec<&'static str> {
+    vec![]
+  }
+
+  /// After running this pass, what transformations are invalidated and need to be rerun
+  fn invalidates(&self) -> Vec<&'static str> {
+    vec![]
+  }
+
+  /// After running this pass, does an item need to be removed
+  #[allow(unused)]
+  fn post_walking_removes_item(&self, item: &Item) -> bool {
+    false
+  }
+
+  /// After running this pass, are there newly generated items
+  fn post_walking_generated_items(&mut self) -> Vec<Item> {
+    vec![]
+  }
+}
+
+pub trait TransformationName {
+  fn name() -> &'static str {
+    std::any::type_name::<Self>()
+  }
+}
+
+impl<'a, T> TransformationName for T where T: Transformation<'a> {}
+
 pub fn apply_transformations(ast: &mut Vec<Item>, analysis: &mut Analysis) {
+  // let transf_passes = Transformations::std(analysis);
+  // transf_passes.run(ast);
   let mut transform_adt = TransformAlgebraicDataType::new(&mut analysis.adt_analysis);
   let mut transform_const_var_to_const = TransformConstVarToConst::new(&analysis.constant_decl_analysis);
   let mut transform_atomic_query = TransformAtomicQuery::new();
   let mut transform_conjunctive_head = TransformConjunctiveHead::new();
-  let mut transform_tagged_rule = TransformTaggedRule::new();
+  let mut desugar_reduce_rule = DesugarReduceRule::new(&analysis.type_inference.foreign_aggregate_type_registry);
+  let mut transform_tagged_rule = TransformTaggedRule::new(
+    /* &mut analysis.type_inference, */ &mut analysis.tagged_rule_analysis,
+  );
   let mut transform_non_const_fact = TransformNonConstantFactToRule;
   let mut desugar_arg_type_adornment = DesugarArgTypeAdornment::new();
+  let mut desugar_destruct = DesugarDestruct::new();
   let mut desugar_case_is = DesugarCaseIs::new();
   let mut desugar_forall_exists = DesugarForallExists::new();
   let mut desugar_range = DesugarRange::new();
-  let mut desugar_reduce_rule = DesugarReduceRule::new(&analysis.type_inference.foreign_aggregate_type_registry);
   let mut forall_to_not_exists = TransformForall;
   let mut implies_to_disjunction = TransformImplies;
   let mut visitors = (
@@ -23,6 +68,7 @@ pub fn apply_transformations(ast: &mut Vec<Item>, analysis: &mut Analysis) {
     &mut transform_tagged_rule,
     &mut transform_non_const_fact,
     &mut desugar_arg_type_adornment,
+    &mut desugar_destruct,
     &mut desugar_case_is,
     &mut desugar_forall_exists,
     &mut desugar_range,
@@ -41,11 +87,169 @@ pub fn apply_transformations(ast: &mut Vec<Item>, analysis: &mut Analysis) {
   new_items.extend(transform_const_var_to_const.generate_items());
   new_items.extend(transform_atomic_query.drain_items());
   new_items.extend(transform_conjunctive_head.generate_items());
-  new_items.extend(transform_tagged_rule.drain_items());
 
   // Some of the transformations need to be applied to new items as well
-  new_items.walk_mut(&mut transform_const_var_to_const);
+  let mut transform_const_var_to_const_2 = TransformConstVarToConst2::new(&analysis.constant_decl_analysis);
+  new_items.walk_mut(&mut transform_const_var_to_const_2);
 
   // Extend the ast to incorporate these new items
-  ast.extend(new_items)
+  ast.extend(new_items);
 }
+
+// pub struct DynTransformation<'a> {
+//   transf: Box<dyn Transformation<'a> + 'a>
+// }
+
+// impl<'a> Transformation<'a> for DynTransformation<'a> {
+//   fn dependencies(&self) -> Vec<&'static str> {
+//     self.transf.dependencies()
+//   }
+
+//   fn invalidates(&self) -> Vec<&'static str> {
+//     self.transf.invalidates()
+//   }
+
+//   fn post_walking_removes_item(&self, item: &Item) -> bool {
+//     self.transf.post_walking_removes_item(item)
+//   }
+
+//   /// After running this pass, is there newly generated items
+//   fn post_walking_generated_items(&mut self) -> Vec<Item> {
+//     self.transf.post_walking_generated_items()
+//   }
+// }
+
+// impl<'a, V> NodeVisitor<V> for DynTransformation<'a> {
+//   fn visit(&mut self, node: &V) {
+//     (&mut *self.transf).visit(node)
+//   }
+
+//   fn visit_mut(&mut self, node: &mut V) {
+//     (&mut *self.transf).visit_mut(node)
+//   }
+// }
+
+// /// A manager of all the transformation passes
+// pub struct Transformations<'a> {
+//   transformations: Graph<DynTransformation<'a>, ()>,
+//   transformation_ids: HashMap<&'static str, NodeIndex>,
+// }
+
+// impl<'a> Transformations<'a> {
+//   pub fn empty() -> Self {
+//     Self {
+//       transformations: Graph::new(),
+//       transformation_ids: HashMap::new(),
+//     }
+//   }
+
+//   pub fn std(analysis: &'a mut Analysis) -> Self {
+//     let mut passes = Self::empty();
+
+//     passes.add(TransformAlgebraicDataType::new(&mut analysis.adt_analysis));
+//     passes.add(TransformConstVarToConst::new(&analysis.constant_decl_analysis));
+//     passes.add(TransformAtomicQuery::new());
+//     passes.add(TransformConjunctiveHead::new());
+//     passes.add(TransformTaggedRule::new());
+//     passes.add(TransformNonConstantFactToRule);
+//     passes.add(DesugarArgTypeAdornment::new());
+//     passes.add(DesugarCaseIs::new());
+//     passes.add(DesugarForallExists::new());
+//     passes.add(DesugarRange::new());
+//     passes.add(DesugarReduceRule::new(&analysis.type_inference.foreign_aggregate_type_registry));
+//     passes.add(TransformForall);
+//     passes.add(TransformImplies);
+
+//     passes.add(TransformConstVarToConst2::new(&analysis.constant_decl_analysis));
+//     // passes.add(DesugarDestruct::new());
+
+//     passes
+
+//   }
+
+//   pub fn add<T: Transformation<'a> + 'a>(&mut self, transformation: T) {
+//     // Get the name and dependencies
+//     let name = T::name();
+//     let dependencies = transformation.dependencies();
+
+//     // Add the transformation into the graph and the id map
+//     let id = self.transformations.add_node(DynTransformation { transf: Box::new(transformation) });
+//     self.transformation_ids.insert(name, id.clone());
+
+//     // Add the dependency edges
+//     for dep_name in dependencies {
+//       let dep_id = self.transformation_ids.get(dep_name).expect(&format!("When adding front-compile pass `{name}`, dependent pass `{dep_name}` does not exist"));
+//       self.transformations.add_edge(dep_id.clone(), id, ());
+//     }
+//   }
+
+//   fn stratumize_transformations(&self) -> Vec<Vec<NodeIndex>> {
+//     use petgraph::visit::EdgeRef;
+//     let sorted_nodes = petgraph::algo::toposort(&self.transformations, None).expect("Cycle found in front-compile passes. Aborting");
+//     let mut stratums: Vec<Vec<NodeIndex>> = Vec::new();
+//     let mut node_to_stratum_map: HashMap<NodeIndex, usize> = HashMap::new();
+//     for node in sorted_nodes {
+//       // Compute what stratum to put the pass in
+//       let to_put_stratum = self
+//         .transformations
+//         .edges_directed(node, petgraph::Direction::Incoming)
+//         .map(|edge| {
+//           let source_stratum_id = node_to_stratum_map.get(&edge.source()).expect("Should contain edge.source");
+//           source_stratum_id + 1
+//         })
+//         .max()
+//         .unwrap_or(0);
+
+//       // Put the pass into the stratum
+//       if to_put_stratum >= stratums.len() {
+//         stratums.push(vec![node]);
+//       } else {
+//         stratums[to_put_stratum].push(node);
+//       }
+
+//       // Record the information in `node_to_stratum_map`
+//       node_to_stratum_map.insert(node, to_put_stratum);
+//     }
+
+//     stratums
+//   }
+
+//   /// Run all the transformations
+//   pub fn run(mut self, ast: &mut Vec<Item>) {
+//     // Stratumize transformations so that transformations can be run in as few iterations as possible
+//     let stratums = self.stratumize_transformations();
+
+//     // Construct a node index to transformation mapping; we are doing this because it is not possible to directly
+//     // access this information through the graph API
+//     let mut pass_id_to_pass_map = HashMap::new();
+//     for pass_id in self.transformations.node_indices().rev() {
+//       pass_id_to_pass_map.insert(pass_id, self.transformations.remove_node(pass_id).expect("Should be expected"));
+//     }
+
+//     // Iterate through stratums
+//     for stratum in stratums {
+//       // Construct all the passes in this stratum
+//       let mut to_run_passes = vec![];
+//       for pass_id in stratum {
+//         let pass = pass_id_to_pass_map.remove(&pass_id).expect("Should be expected");
+//         to_run_passes.push(pass);
+//       }
+
+//       // Walk the AST with the set of passes
+//       ast.walk_mut(&mut to_run_passes);
+
+//       // Check if any item needs to be removed
+//       ast.retain(|item| !to_run_passes.iter().any(|pass| pass.post_walking_removes_item(item)));
+
+//       // Check if any item needs to be added
+//       for pass in &mut to_run_passes {
+//         ast.extend(pass.post_walking_generated_items());
+//       }
+//     }
+//   }
+// }
+
+// pub fn apply_transformations(ast: &mut Vec<Item>, analysis: &mut Analysis) {
+//   let transf_passes = Transformations::std(analysis);
+//   transf_passes.run(ast);
+// }
diff --git a/core/src/compiler/front/transformations/adt_to_relation.rs b/core/src/compiler/front/transformations/adt_to_relation.rs
index 6780316..62ba526 100644
--- a/core/src/compiler/front/transformations/adt_to_relation.rs
+++ b/core/src/compiler/front/transformations/adt_to_relation.rs
@@ -6,6 +6,16 @@ pub struct TransformAlgebraicDataType<'a> {
   analysis: &'a mut AlgebraicDataTypeAnalysis,
 }
 
+impl<'a> Transformation<'a> for TransformAlgebraicDataType<'a> {
+  fn post_walking_removes_item(&self, item: &Item) -> bool {
+    !self.retain(item)
+  }
+
+  fn post_walking_generated_items(&mut self) -> Vec<Item> {
+    self.generate_items()
+  }
+}
+
 impl<'a> NodeVisitor<TypeDecl> for TransformAlgebraicDataType<'a> {
   fn visit_mut(&mut self, type_decl: &mut TypeDecl) {
     match type_decl {
@@ -25,7 +35,7 @@ impl<'a> TransformAlgebraicDataType<'a> {
     Self { analysis }
   }
 
-  pub fn generate_items(self) -> Vec<Item> {
+  pub fn generate_items(&mut self) -> Vec<Item> {
     let result = self
       .analysis
       .adt_variants
diff --git a/core/src/compiler/front/transformations/atomic_query.rs b/core/src/compiler/front/transformations/atomic_query.rs
index e1fe78c..ec2badf 100644
--- a/core/src/compiler/front/transformations/atomic_query.rs
+++ b/core/src/compiler/front/transformations/atomic_query.rs
@@ -1,3 +1,4 @@
+use super::*;
 use crate::compiler::front::*;
 
 #[derive(Clone, Debug)]
@@ -5,17 +6,23 @@ pub struct TransformAtomicQuery {
   pub to_add_rules: Vec<Rule>,
 }
 
+impl<'a> Transformation<'a> for TransformAtomicQuery {
+  fn post_walking_generated_items(&mut self) -> Vec<Item> {
+    self.drain_items()
+  }
+}
+
 impl TransformAtomicQuery {
   pub fn new() -> Self {
     Self { to_add_rules: vec![] }
   }
 
-  pub fn drain_items(self) -> Vec<Item> {
+  pub fn drain_items(&self) -> Vec<Item> {
     self
       .to_add_rules
       .iter()
       .map(|rule| {
-        let rule_decl = RuleDecl::new(vec![], Tag::none(), rule.clone());
+        let rule_decl = RuleDecl::new(vec![], None, rule.clone());
         let rel_decl = RelationDecl::Rule(rule_decl);
         let item = Item::RelationDecl(rel_decl);
         item
diff --git a/core/src/compiler/front/transformations/conjunctive_head.rs b/core/src/compiler/front/transformations/conjunctive_head.rs
index 897bfcd..e395902 100644
--- a/core/src/compiler/front/transformations/conjunctive_head.rs
+++ b/core/src/compiler/front/transformations/conjunctive_head.rs
@@ -5,6 +5,12 @@ pub struct TransformConjunctiveHead {
   to_add_items: Vec<Item>,
 }
 
+impl<'a> Transformation<'a> for TransformConjunctiveHead {
+  fn post_walking_generated_items(&mut self) -> Vec<Item> {
+    self.to_add_items.clone()
+  }
+}
+
 impl TransformConjunctiveHead {
   pub fn new() -> Self {
     Self { to_add_items: vec![] }
@@ -37,7 +43,7 @@ impl NodeVisitor<Rule> for TransformConjunctiveHead {
       .to_add_items
       .push(Item::RelationDecl(RelationDecl::Rule(RuleDecl::new(
         Attributes::new(),
-        Tag::none(),
+        None,
         Rule::new_with_loc(
           RuleHead::atom(atom.clone()),
           rule.body().clone(),
diff --git a/core/src/compiler/front/transformations/const_var_to_const.rs b/core/src/compiler/front/transformations/const_var_to_const.rs
index 2182ccf..d868aa6 100644
--- a/core/src/compiler/front/transformations/const_var_to_const.rs
+++ b/core/src/compiler/front/transformations/const_var_to_const.rs
@@ -1,12 +1,49 @@
-use crate::{common::input_tag::DynamicInputTag, compiler::front::analyzers::ConstantDeclAnalysis};
+use crate::compiler::front::analyzers::ConstantDeclAnalysis;
+use crate::compiler::front::transformations::TransformAlgebraicDataType;
 
 use super::super::*;
 
+pub struct TransformConstVarToConst2<'a> {
+  parent: TransformConstVarToConst<'a>,
+}
+
+impl<'a> TransformConstVarToConst2<'a> {
+  pub fn new(const_decl_analysis: &'a ConstantDeclAnalysis) -> Self {
+    Self {
+      parent: TransformConstVarToConst::new(const_decl_analysis),
+    }
+  }
+}
+
+impl<'a> Transformation<'a> for TransformConstVarToConst2<'a> {
+  fn dependencies(&self) -> Vec<&'static str> {
+    vec![TransformAlgebraicDataType::name()]
+  }
+}
+
+impl<'a> NodeVisitor<Expr> for TransformConstVarToConst2<'a> {
+  fn visit_mut(&mut self, expr: &mut Expr) {
+    self.parent.visit_mut(expr);
+  }
+}
+
+impl<'a> NodeVisitor<ConstantOrVariable> for TransformConstVarToConst2<'a> {
+  fn visit_mut(&mut self, cov: &mut ConstantOrVariable) {
+    self.parent.visit_mut(cov);
+  }
+}
+
 #[derive(Clone, Debug)]
 pub struct TransformConstVarToConst<'a> {
   const_decl_analysis: &'a ConstantDeclAnalysis,
 }
 
+impl<'a> Transformation<'a> for TransformConstVarToConst<'a> {
+  fn post_walking_generated_items(&mut self) -> Vec<Item> {
+    self.generate_items()
+  }
+}
+
 impl<'a> TransformConstVarToConst<'a> {
   pub fn new(const_decl_analysis: &'a ConstantDeclAnalysis) -> Self {
     Self { const_decl_analysis }
@@ -20,7 +57,7 @@ impl<'a> TransformConstVarToConst<'a> {
       .map(|entity_fact| {
Item::RelationDecl(RelationDecl::Fact(FactDecl::new( Attributes::new(), - Tag::new(DynamicInputTag::None), + None, Atom::new_with_loc( { entity_fact diff --git a/core/src/compiler/front/transformations/desugar_arg_type_anno.rs b/core/src/compiler/front/transformations/desugar_arg_type_anno.rs index 39cda88..5848b2b 100644 --- a/core/src/compiler/front/transformations/desugar_arg_type_anno.rs +++ b/core/src/compiler/front/transformations/desugar_arg_type_anno.rs @@ -5,9 +5,11 @@ pub struct DesugarArgTypeAdornment { new_items: Vec, } +impl<'a> Transformation<'a> for DesugarArgTypeAdornment {} + impl DesugarArgTypeAdornment { pub fn new() -> Self { - Self { new_items: Vec::new() } + Self { new_items: vec![] } } pub fn generate_demand_attribute(rel_type: &RelationType) -> Attribute { diff --git a/core/src/compiler/front/transformations/desugar_case_is.rs b/core/src/compiler/front/transformations/desugar_case_is.rs index 132b634..6a78171 100644 --- a/core/src/compiler/front/transformations/desugar_case_is.rs +++ b/core/src/compiler/front/transformations/desugar_case_is.rs @@ -4,6 +4,8 @@ use crate::utils::IdAllocator; #[derive(Clone, Debug)] pub struct DesugarCaseIs; +impl<'a> Transformation<'a> for DesugarCaseIs {} + impl DesugarCaseIs { pub fn new() -> Self { Self @@ -24,9 +26,7 @@ impl DesugarCaseIs { } Entity::Object(o) => { // If the entity is an object, the formula is a conjunction of atoms - let parent_id = case - .location_id() - .expect("Case location id is not populated prior to desugar case is transformation"); + let parent_id = case.variable_name(); let variable = case.variable().clone(); let mut variable_counter = IdAllocator::new(); let mut formulas = vec![]; @@ -44,7 +44,7 @@ impl DesugarCaseIs { &self, variable: Variable, object: &Object, - parent_id: usize, + parent_id: &String, variable_counter: &mut IdAllocator, formulas: &mut Vec, ) { @@ -60,7 +60,7 @@ impl DesugarCaseIs { let variable_id = variable_counter.alloc(); // Create a variable from the variable id - let current_variable = Variable::new(Identifier::new(format!("adt#var#{parent_id}#{variable_id}"))); + let current_variable = Variable::new(Identifier::new(format!("adt#var#({parent_id})#{variable_id}"))); // Recurse on the object self.transform_object_to_formula_helper(current_variable.clone(), o, parent_id, variable_counter, formulas); diff --git a/core/src/compiler/front/transformations/desugar_destruct.rs b/core/src/compiler/front/transformations/desugar_destruct.rs new file mode 100644 index 0000000..720c408 --- /dev/null +++ b/core/src/compiler/front/transformations/desugar_destruct.rs @@ -0,0 +1,146 @@ +use crate::compiler::front::*; +use crate::utils::IdAllocator; + +#[derive(Clone, Debug)] +pub struct DesugarDestruct {} + +impl<'a> Transformation<'a> for DesugarDestruct {} + +impl DesugarDestruct { + pub fn new() -> Self { + Self {} + } + + pub fn has_destruct_in_atom(&self, atom: &Atom) -> bool { + atom.iter_args().any(|arg| arg.is_destruct()) + } + + pub fn transform_atom_with_destructor_to_formula(&self, atom: &Atom) -> (Atom, Vec) { + let mut variable_counter = IdAllocator::new(); + let mut all_desugared_formulas = vec![]; + let mut desugared_atom_args = vec![]; + + for arg in atom.iter_args() { + match arg { + Expr::Destruct(destruct) => { + let parent_id = destruct.location_id().expect("Destruct should have an ID"); + let variable = Variable::new(Identifier::new(format!("adt#destr#var#root#{parent_id}"))); + desugared_atom_args.push(Expr::Variable(variable.clone())); + 
self.transform_destruct_to_formula_helper( + variable, + destruct, + parent_id, + &mut variable_counter, + &mut all_desugared_formulas, + ); + } + _ => { + desugared_atom_args.push(arg.clone()); + } + } + } + + let desugared_atom = Atom::new_with_loc( + atom.predicate().clone(), + atom.type_args().clone(), + desugared_atom_args, + atom.location().clone_without_id(), + ); + + (desugared_atom, all_desugared_formulas) + } + + fn transform_destruct_to_formula_helper( + &self, + variable: Variable, + destruct: &DestructExpr, + parent_id: usize, + variable_counter: &mut IdAllocator, + formulas: &mut Vec, + ) { + // Obtain the predicate of the atom that we are going to generate + let predicate = destruct + .functor() + .clone_without_location_id() + .map(|n| format!("adt#{n}")); + + // Obtain the arguments that will follow the entity variable in the generated atom + let sub_args = destruct.iter_args().map(|arg| { + match &arg { + Expr::Destruct(o) => { + // Obtain a variable id + let variable_id = variable_counter.alloc(); + + // Create a variable from the variable id + let current_variable = Variable::new(Identifier::new(format!("adt#destr#var#{parent_id}#{variable_id}"))); + + // Recurse on the object + self.transform_destruct_to_formula_helper(current_variable.clone(), o, parent_id, variable_counter, formulas); + + // Return the variable as the result + Expr::Variable(current_variable) + } + _ => arg.clone(), + } + }); + + // Create all arguments including the variable + let args = std::iter::once(Expr::Variable(variable)).chain(sub_args).collect(); + + // Add a formula to the formulas + let formula = Formula::Atom(Atom::new(predicate, vec![], args)); + formulas.push(formula); + } +} + +impl NodeVisitor for DesugarDestruct { + fn visit_mut(&mut self, formula: &mut Formula) { + match formula { + Formula::Atom(a) => { + if self.has_destruct_in_atom(a) { + let (atom, rest) = self.transform_atom_with_destructor_to_formula(a); + *formula = Formula::conjunction(Conjunction::new( + std::iter::once(Formula::atom(atom)).chain(rest.into_iter()).collect(), + )); + } + } + _ => {} + } + } +} + +impl NodeVisitor for DesugarDestruct { + fn visit_mut(&mut self, rule: &mut Rule) { + match rule.head_mut() { + RuleHead::Atom(a) => { + if self.has_destruct_in_atom(a) { + let (atom, rest) = self.transform_atom_with_destructor_to_formula(a); + *a = atom; + *rule.body_mut() = Formula::conjunction(Conjunction::new( + std::iter::once(rule.body().clone()).chain(rest.into_iter()).collect(), + )); + } + } + RuleHead::Conjunction(conj_head) => { + if conj_head.iter_atoms().any(|atom| self.has_destruct_in_atom(atom)) { + panic!("[Consider reporting this bug] Conjunction head should be handled by a prior transformation pass") + } + } + RuleHead::Disjunction(disj_head) => { + if disj_head.iter_atoms().any(|atom| self.has_destruct_in_atom(atom)) { + unimplemented!() + } + } + } + } +} + +impl NodeVisitor for DesugarDestruct { + fn visit(&mut self, fact_decl: &FactDecl) { + if self.has_destruct_in_atom(fact_decl.atom()) { + panic!( + "[Consider reporting this bug] Fact declaration with destructor should be handled by a prior transformation pass" + ) + } + } +} diff --git a/core/src/compiler/front/transformations/desugar_forall_exists.rs b/core/src/compiler/front/transformations/desugar_forall_exists.rs index 6c7af44..23ef71b 100644 --- a/core/src/compiler/front/transformations/desugar_forall_exists.rs +++ b/core/src/compiler/front/transformations/desugar_forall_exists.rs @@ -16,6 +16,8 @@ use crate::compiler::front::*; #[derive(Clone, Debug)] pub struct
DesugarForallExists; +impl<'a> Transformation<'a> for DesugarForallExists {} + impl DesugarForallExists { pub fn new() -> Self { Self diff --git a/core/src/compiler/front/transformations/desugar_range.rs b/core/src/compiler/front/transformations/desugar_range.rs index 7c80222..33269aa 100644 --- a/core/src/compiler/front/transformations/desugar_range.rs +++ b/core/src/compiler/front/transformations/desugar_range.rs @@ -16,6 +16,8 @@ use crate::compiler::front::*; #[derive(Clone, Debug)] pub struct DesugarRange; +impl<'a> Transformation<'a> for DesugarRange {} + impl DesugarRange { pub fn new() -> Self { Self diff --git a/core/src/compiler/front/transformations/desugar_reduce_rule.rs b/core/src/compiler/front/transformations/desugar_reduce_rule.rs index 1dbbed6..9775b4c 100644 --- a/core/src/compiler/front/transformations/desugar_reduce_rule.rs +++ b/core/src/compiler/front/transformations/desugar_reduce_rule.rs @@ -6,6 +6,8 @@ pub struct DesugarReduceRule { aggregate_types: AggregateTypeRegistry, } +impl<'a> Transformation<'a> for DesugarReduceRule {} + impl DesugarReduceRule { pub fn new(agg_ty_registry: &AggregateTypeRegistry) -> Self { Self { @@ -57,7 +59,7 @@ impl DesugarReduceRule { // Generate the whole rule let rule = Rule::new(generated_head, Formula::Reduce(generated_aggregate)); - let rule_decl = RuleDecl::new_with_loc(attrs.clone(), Tag::none(), rule, decl_loc.clone()); + let rule_decl = RuleDecl::new_with_loc(attrs.clone(), None, rule, decl_loc.clone()); // Return Some(rule_decl) diff --git a/core/src/compiler/front/transformations/forall_to_not_exists.rs b/core/src/compiler/front/transformations/forall_to_not_exists.rs index e06536c..47d0d7d 100644 --- a/core/src/compiler/front/transformations/forall_to_not_exists.rs +++ b/core/src/compiler/front/transformations/forall_to_not_exists.rs @@ -16,6 +16,8 @@ use crate::compiler::front::*; #[derive(Clone, Debug)] pub struct TransformForall; +impl<'a> Transformation<'a> for TransformForall {} + impl TransformForall { pub fn new() -> Self { Self diff --git a/core/src/compiler/front/transformations/implies_to_disjunction.rs b/core/src/compiler/front/transformations/implies_to_disjunction.rs index abdbffa..876a547 100644 --- a/core/src/compiler/front/transformations/implies_to_disjunction.rs +++ b/core/src/compiler/front/transformations/implies_to_disjunction.rs @@ -4,6 +4,8 @@ use crate::compiler::front::*; #[derive(Clone, Debug)] pub struct TransformImplies; +impl<'a> Transformation<'a> for TransformImplies {} + impl NodeVisitor for TransformImplies { fn visit_mut(&mut self, formula: &mut Formula) { match formula { diff --git a/core/src/compiler/front/transformations/mod.rs b/core/src/compiler/front/transformations/mod.rs index 68386b6..c3d1af4 100644 --- a/core/src/compiler/front/transformations/mod.rs +++ b/core/src/compiler/front/transformations/mod.rs @@ -4,6 +4,7 @@ mod conjunctive_head; mod const_var_to_const; mod desugar_arg_type_anno; mod desugar_case_is; +mod desugar_destruct; mod desugar_forall_exists; mod desugar_range; mod desugar_reduce_rule; @@ -18,6 +19,7 @@ pub use conjunctive_head::*; pub use const_var_to_const::*; pub use desugar_arg_type_anno::*; pub use desugar_case_is::*; +pub use desugar_destruct::*; pub use desugar_forall_exists::*; pub use desugar_range::*; pub use desugar_reduce_rule::*; @@ -25,3 +27,5 @@ pub use forall_to_not_exists::*; pub use implies_to_disjunction::*; pub use non_constant_fact_to_rule::*; pub use tagged_rule::*; + +use super::transform::*; diff --git 
a/core/src/compiler/front/transformations/non_constant_fact_to_rule.rs b/core/src/compiler/front/transformations/non_constant_fact_to_rule.rs index 7dd981d..61af212 100644 --- a/core/src/compiler/front/transformations/non_constant_fact_to_rule.rs +++ b/core/src/compiler/front/transformations/non_constant_fact_to_rule.rs @@ -3,11 +3,31 @@ use crate::compiler::front::*; #[derive(Clone, Debug)] pub struct TransformNonConstantFactToRule; +impl<'a> Transformation<'a> for TransformNonConstantFactToRule {} + +impl TransformNonConstantFactToRule { + pub fn destruct_expr_to_object(&self, destruct: &DestructExpr) -> Object { + Object::new_with_loc( + destruct.functor().clone(), + destruct.iter_args().map(|arg| self.expr_to_entity(arg)).collect(), + destruct.location().clone(), + ) + } + + pub fn expr_to_entity(&self, expr: &Expr) -> Entity { + match expr { + Expr::Destruct(d) => Entity::Object(self.destruct_expr_to_object(d)), + e => Entity::Expr(e.clone()), + } + } +} + impl NodeVisitor for TransformNonConstantFactToRule { fn visit_mut(&mut self, relation_decl: &mut RelationDecl) { // First collect the expressions in the fact that are not constant - let (attrs, tag, head, non_const_var_expr_pairs) = match &relation_decl { + let (loc, attrs, tag, head, non_const_var_expr_pairs) = match &relation_decl { RelationDecl::Fact(f) => { + let loc = f.location(); let attrs = f.attrs().clone(); let tag = f.tag().clone(); let head = f.atom().clone(); @@ -16,7 +36,7 @@ impl NodeVisitor for TransformNonConstantFactToRule { .enumerate() .filter_map(|(i, e)| if e.is_constant() { None } else { Some((i, e.clone())) }) .collect::>(); - (attrs, tag, head, non_const) + (loc, attrs, tag, head, non_const) } _ => return, }; @@ -50,16 +70,23 @@ impl NodeVisitor for TransformNonConstantFactToRule { let eq_consts = non_const_var_expr_pairs .into_iter() .map(|(i, e)| { - let var_expr = Expr::variable(Variable::new(Identifier::new(format!("fnc#{}", i)))); - let eq_expr = Expr::binary(BinaryExpr::new(BinaryOp::new_eq(), var_expr, e)); - Formula::Constraint(Constraint::new(eq_expr)) + if e.is_destruct() { + let var = Variable::new(Identifier::new(format!("fnc#{}", i))); + // `expr_to_entity` produces an `Entity::Object` here, since we just checked that `e` is a destruct + let entity = self.expr_to_entity(&e); + Formula::case(Case::new(var, entity)) + } else { + let var_expr = Expr::variable(Variable::new(Identifier::new(format!("fnc#{}", i)))); + let eq_expr = Expr::binary(BinaryExpr::new(BinaryOp::new_eq(), var_expr, e)); + Formula::Constraint(Constraint::new(eq_expr)) + } }) .collect::>(); let body = Formula::conjunction(Conjunction::new(eq_consts)); // Finally, generate a rule declaration let rule = Rule::new(head, body); - let rule_decl = RuleDecl::new(attrs, tag, rule); + let rule_decl = RuleDecl::new_with_loc(attrs, tag, rule, loc.clone()); // Modify the original relation declaration *relation_decl = RelationDecl::rule(rule_decl); diff --git a/core/src/compiler/front/transformations/tagged_rule.rs b/core/src/compiler/front/transformations/tagged_rule.rs index 67c8647..267585d 100644 --- a/core/src/compiler/front/transformations/tagged_rule.rs +++ b/core/src/compiler/front/transformations/tagged_rule.rs @@ -1,60 +1,67 @@ -use crate::common::input_tag::*; use crate::compiler::front::*; -#[derive(Clone, Debug)] -pub struct TransformTaggedRule { - pub to_add_tags: Vec<(String, DynamicInputTag)>, +#[derive(Debug)] +pub struct TransformTaggedRule<'a> { + pub tagged_rule_analysis: &'a mut analyzers::TaggedRuleAnalysis, } -impl TransformTaggedRule { - pub fn
new() -> Self { - Self { to_add_tags: vec![] } +impl<'a> TransformTaggedRule<'a> { + pub fn new(tagged_rule_analysis: &'a mut analyzers::TaggedRuleAnalysis) -> Self { + Self { tagged_rule_analysis } } pub fn has_prob_attr(rule_decl: &RuleDecl) -> bool { rule_decl.attrs().iter().find(|a| a.name().name() == "tagged").is_some() } - - pub fn transform(rule_decl: &mut RuleDecl) -> String { - // 1. Generate the predicate - let pred = rule_decl.rule_tag_predicate(); - - // 2. Append the atom to the end - let new_atom = Formula::Atom(Atom::new(Identifier::new(pred.clone()), vec![], vec![])); - let new_body = Formula::Conjunction(Conjunction::new(vec![new_atom, rule_decl.rule().body().clone()])); - *rule_decl.rule_mut().body_mut() = new_body; - - // Return the predicate - pred - } - - pub fn drain_items(self) -> Vec { - self - .to_add_tags - .into_iter() - .map(|(pred, tag)| { - let fact = Atom::new(Identifier::new(pred.clone()), vec![], vec![]); - let fact_decl = FactDecl::new(vec![], Tag::new(tag), fact); - let rel_decl = RelationDecl::Fact(fact_decl); - let item = Item::RelationDecl(rel_decl); - item - }) - .collect() - } } -impl NodeVisitor for TransformTaggedRule { +impl<'a> NodeVisitor for TransformTaggedRule<'a> { fn visit_mut(&mut self, rule_decl: &mut RuleDecl) { // If rule is directly declared with probability - if rule_decl.tag().is_some() { + if let Some(tag) = rule_decl.tag().clone() { // Transform the rule - let pred = Self::transform(rule_decl); + let pred = rule_decl.rule_tag_predicate(); + + // We create a new variable to hold the tag + let tag_var_name = format!("{pred}#var"); + let tag_var = Variable::new(Identifier::new(tag_var_name.clone())); + let tag_var_expr = Expr::variable(tag_var); - // Store this probability for later - self.to_add_tags.push((pred.clone(), rule_decl.tag().tag().clone())); + // We generate a constraint encoding that `$variable == $tag` + let eq_constraint = Formula::constraint(Constraint::new(Expr::binary(BinaryExpr::new( + BinaryOp::new_eq(), + tag_var_expr.clone(), + tag.clone(), + )))); + + // Generate the foreign predicate atom with that tag variable as the only argument + let atom = Atom::new(Identifier::new(pred.clone()), vec![], vec![tag_var_expr]); + let atom_formula = Formula::atom(atom); + + // Generate a formula that is the conjunction of constraint and atom + let to_add_formula = Formula::conjunction(Conjunction::new(vec![eq_constraint, atom_formula])); + + // Update the original rule body + let new_body = Formula::Conjunction(Conjunction::new(vec![to_add_formula, rule_decl.rule().body().clone()])); + *rule_decl.rule_mut().body_mut() = new_body; + + // Remove the rule tag surface syntax + *rule_decl.tag_mut() = None; + + // Tell the analyzer to store the information + let rule_id = rule_decl.rule().location().clone(); + self + .tagged_rule_analysis + .add_tag_predicate(rule_id, pred, tag_var_name, tag.location().clone()); } else if Self::has_prob_attr(rule_decl) { - // If the rule is annotated with `@probabilistic` - Self::transform(rule_decl); + // Handle rules with external probabilities + + // If the rule is annotated with `@tagged`, we simply append a nullary atom at the end. + // The fact will be populated by external sources. 
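+      // Sketch of this rewrite (the predicate name is whatever `rule_tag_predicate()`
+      // generates; `tag_pred` below is an illustrative stand-in): a rule
+      //   @tagged rel head() = body()
+      // becomes roughly `rel head() = tag_pred() and body()`, and the tagged facts
+      // for `tag_pred` are later inserted from outside.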
+ let pred = rule_decl.rule_tag_predicate(); + let new_atom = Formula::Atom(Atom::new(Identifier::new(pred.clone()), vec![], vec![])); + let new_body = Formula::Conjunction(Conjunction::new(vec![new_atom, rule_decl.rule().body().clone()])); + *rule_decl.rule_mut().body_mut() = new_body; } } } diff --git a/core/src/lib.rs b/core/src/lib.rs index 3a35eac..eda4c3c 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,12 +1,13 @@ #![feature(min_specialization)] #![feature(extract_if)] #![feature(hash_extract_if)] -#![feature(iter_repeat_n)] #![feature(proc_macro_span)] pub mod common; pub mod compiler; pub mod integrate; pub mod runtime; -pub mod testing; pub mod utils; + +// Testing utilities +pub mod testing; diff --git a/core/src/runtime/dynamic/dataflow/batching/batches.rs b/core/src/runtime/dynamic/dataflow/batching/batches.rs index 6c9ab34..75b490b 100644 --- a/core/src/runtime/dynamic/dataflow/batching/batches.rs +++ b/core/src/runtime/dynamic/dataflow/batching/batches.rs @@ -73,6 +73,7 @@ impl<'a, Prov: Provenance> Batches<'a, Prov> for SingleBatch<'a, Prov> { } } +#[allow(unused)] #[derive(Clone)] pub struct DynamicBatchesOptional<'a, Prov: Provenance> { optional_batches: Option>, diff --git a/core/src/runtime/env/environment.rs b/core/src/runtime/env/environment.rs index f95c320..65b7c59 100644 --- a/core/src/runtime/env/environment.rs +++ b/core/src/runtime/env/environment.rs @@ -275,18 +275,18 @@ impl RuntimeEnvironment { // Compute result let result = match (&expr.op, lhs_v, rhs_v) { // Addition - (Add, Tuple::Value(I8(i1)), Tuple::Value(I8(i2))) => Tuple::Value(I8(i1 + i2)), - (Add, Tuple::Value(I16(i1)), Tuple::Value(I16(i2))) => Tuple::Value(I16(i1 + i2)), - (Add, Tuple::Value(I32(i1)), Tuple::Value(I32(i2))) => Tuple::Value(I32(i1 + i2)), - (Add, Tuple::Value(I64(i1)), Tuple::Value(I64(i2))) => Tuple::Value(I64(i1 + i2)), - (Add, Tuple::Value(I128(i1)), Tuple::Value(I128(i2))) => Tuple::Value(I128(i1 + i2)), - (Add, Tuple::Value(ISize(i1)), Tuple::Value(ISize(i2))) => Tuple::Value(ISize(i1 + i2)), - (Add, Tuple::Value(U8(i1)), Tuple::Value(U8(i2))) => Tuple::Value(U8(i1 + i2)), - (Add, Tuple::Value(U16(i1)), Tuple::Value(U16(i2))) => Tuple::Value(U16(i1 + i2)), - (Add, Tuple::Value(U32(i1)), Tuple::Value(U32(i2))) => Tuple::Value(U32(i1 + i2)), - (Add, Tuple::Value(U64(i1)), Tuple::Value(U64(i2))) => Tuple::Value(U64(i1 + i2)), - (Add, Tuple::Value(U128(i1)), Tuple::Value(U128(i2))) => Tuple::Value(U128(i1 + i2)), - (Add, Tuple::Value(USize(i1)), Tuple::Value(USize(i2))) => Tuple::Value(USize(i1 + i2)), + (Add, Tuple::Value(I8(i1)), Tuple::Value(I8(i2))) => Tuple::Value(I8(i1.wrapping_add(i2))), + (Add, Tuple::Value(I16(i1)), Tuple::Value(I16(i2))) => Tuple::Value(I16(i1.wrapping_add(i2))), + (Add, Tuple::Value(I32(i1)), Tuple::Value(I32(i2))) => Tuple::Value(I32(i1.wrapping_add(i2))), + (Add, Tuple::Value(I64(i1)), Tuple::Value(I64(i2))) => Tuple::Value(I64(i1.wrapping_add(i2))), + (Add, Tuple::Value(I128(i1)), Tuple::Value(I128(i2))) => Tuple::Value(I128(i1.wrapping_add(i2))), + (Add, Tuple::Value(ISize(i1)), Tuple::Value(ISize(i2))) => Tuple::Value(ISize(i1.wrapping_add(i2))), + (Add, Tuple::Value(U8(i1)), Tuple::Value(U8(i2))) => Tuple::Value(U8(i1.wrapping_add(i2))), + (Add, Tuple::Value(U16(i1)), Tuple::Value(U16(i2))) => Tuple::Value(U16(i1.wrapping_add(i2))), + (Add, Tuple::Value(U32(i1)), Tuple::Value(U32(i2))) => Tuple::Value(U32(i1.wrapping_add(i2))), + (Add, Tuple::Value(U64(i1)), Tuple::Value(U64(i2))) => Tuple::Value(U64(i1.wrapping_add(i2))), + (Add, 
Tuple::Value(U128(i1)), Tuple::Value(U128(i2))) => Tuple::Value(U128(i1.wrapping_add(i2))), + (Add, Tuple::Value(USize(i1)), Tuple::Value(USize(i2))) => Tuple::Value(USize(i1.wrapping_add(i2))), (Add, Tuple::Value(F32(i1)), Tuple::Value(F32(i2))) => Tuple::Value(F32(i1 + i2)), (Add, Tuple::Value(F64(i1)), Tuple::Value(F64(i2))) => Tuple::Value(F64(i1 + i2)), (Add, Tuple::Value(String(s1)), Tuple::Value(String(s2))) => Tuple::Value(String(format!("{}{}", s1, s2))), @@ -305,18 +305,18 @@ impl RuntimeEnvironment { (Add, b1, b2) => panic!("Cannot perform ADD on {:?} and {:?}", b1, b2), // Subtraction - (Sub, Tuple::Value(I8(i1)), Tuple::Value(I8(i2))) => Tuple::Value(I8(i1 - i2)), - (Sub, Tuple::Value(I16(i1)), Tuple::Value(I16(i2))) => Tuple::Value(I16(i1 - i2)), - (Sub, Tuple::Value(I32(i1)), Tuple::Value(I32(i2))) => Tuple::Value(I32(i1 - i2)), - (Sub, Tuple::Value(I64(i1)), Tuple::Value(I64(i2))) => Tuple::Value(I64(i1 - i2)), - (Sub, Tuple::Value(I128(i1)), Tuple::Value(I128(i2))) => Tuple::Value(I128(i1 - i2)), - (Sub, Tuple::Value(ISize(i1)), Tuple::Value(ISize(i2))) => Tuple::Value(ISize(i1 - i2)), - (Sub, Tuple::Value(U8(i1)), Tuple::Value(U8(i2))) => Tuple::Value(U8(i1 - i2)), - (Sub, Tuple::Value(U16(i1)), Tuple::Value(U16(i2))) => Tuple::Value(U16(i1 - i2)), - (Sub, Tuple::Value(U32(i1)), Tuple::Value(U32(i2))) => Tuple::Value(U32(i1 - i2)), - (Sub, Tuple::Value(U64(i1)), Tuple::Value(U64(i2))) => Tuple::Value(U64(i1 - i2)), - (Sub, Tuple::Value(U128(i1)), Tuple::Value(U128(i2))) => Tuple::Value(U128(i1 - i2)), - (Sub, Tuple::Value(USize(i1)), Tuple::Value(USize(i2))) => Tuple::Value(USize(i1 - i2)), + (Sub, Tuple::Value(I8(i1)), Tuple::Value(I8(i2))) => Tuple::Value(I8(i1.wrapping_sub(i2))), + (Sub, Tuple::Value(I16(i1)), Tuple::Value(I16(i2))) => Tuple::Value(I16(i1.wrapping_sub(i2))), + (Sub, Tuple::Value(I32(i1)), Tuple::Value(I32(i2))) => Tuple::Value(I32(i1.wrapping_sub(i2))), + (Sub, Tuple::Value(I64(i1)), Tuple::Value(I64(i2))) => Tuple::Value(I64(i1.wrapping_sub(i2))), + (Sub, Tuple::Value(I128(i1)), Tuple::Value(I128(i2))) => Tuple::Value(I128(i1.wrapping_sub(i2))), + (Sub, Tuple::Value(ISize(i1)), Tuple::Value(ISize(i2))) => Tuple::Value(ISize(i1.wrapping_sub(i2))), + (Sub, Tuple::Value(U8(i1)), Tuple::Value(U8(i2))) => Tuple::Value(U8(i1.wrapping_sub(i2))), + (Sub, Tuple::Value(U16(i1)), Tuple::Value(U16(i2))) => Tuple::Value(U16(i1.wrapping_sub(i2))), + (Sub, Tuple::Value(U32(i1)), Tuple::Value(U32(i2))) => Tuple::Value(U32(i1.wrapping_sub(i2))), + (Sub, Tuple::Value(U64(i1)), Tuple::Value(U64(i2))) => Tuple::Value(U64(i1.wrapping_sub(i2))), + (Sub, Tuple::Value(U128(i1)), Tuple::Value(U128(i2))) => Tuple::Value(U128(i1.wrapping_sub(i2))), + (Sub, Tuple::Value(USize(i1)), Tuple::Value(USize(i2))) => Tuple::Value(USize(i1.wrapping_sub(i2))), (Sub, Tuple::Value(F32(i1)), Tuple::Value(F32(i2))) => Tuple::Value(F32(i1 - i2)), (Sub, Tuple::Value(F64(i1)), Tuple::Value(F64(i2))) => Tuple::Value(F64(i1 - i2)), (Sub, Tuple::Value(DateTime(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(DateTime(i1 - i2)), @@ -650,6 +650,9 @@ impl RuntimeEnvironment { (Tuple::Value(String(s)), T::F32) => s.parse().ok().map(|i| Tuple::Value(F32(i))), (Tuple::Value(String(s)), T::F64) => s.parse().ok().map(|i| Tuple::Value(F64(i))), + (Tuple::Value(F32(f)), T::F64) => Some(Tuple::Value(F64(f as f64))), + (Tuple::Value(F64(f)), T::F32) => Some(Tuple::Value(F32(f as f32))), + (Tuple::Value(I8(i)), T::String) => Some(Tuple::Value(String(i.to_string()))), (Tuple::Value(I16(i)), T::String) => 
Some(Tuple::Value(String(i.to_string()))), (Tuple::Value(I32(i)), T::String) => Some(Tuple::Value(String(i.to_string()))), diff --git a/core/src/runtime/provenance/common/output_diff_prob_with_proofs.rs b/core/src/runtime/provenance/common/output_diff_prob_with_proofs.rs index a1dd18f..3c5bd3c 100644 --- a/core/src/runtime/provenance/common/output_diff_prob_with_proofs.rs +++ b/core/src/runtime/provenance/common/output_diff_prob_with_proofs.rs @@ -9,7 +9,13 @@ impl std::fmt::Debug for OutputDiffProbWithProofs { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple("") .field(&self.probability) - .field(&self.gradient.iter().map(|(id, weight)| (id, weight)).collect::>()) + .field( + &self + .gradient + .iter() + .map(|(id, weight)| (id, weight)) + .collect::>(), + ) .finish() } } @@ -18,7 +24,13 @@ impl std::fmt::Display for OutputDiffProbWithProofs { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple("") .field(&self.probability) - .field(&self.gradient.iter().map(|(id, weight)| (id, weight)).collect::>()) + .field( + &self + .gradient + .iter() + .map(|(id, weight)| (id, weight)) + .collect::>(), + ) .finish() } } diff --git a/core/src/runtime/provenance/differentiable/diff_top_k_proofs_debug.rs b/core/src/runtime/provenance/differentiable/diff_top_k_proofs_debug.rs index 7f7c99d..78021c5 100644 --- a/core/src/runtime/provenance/differentiable/diff_top_k_proofs_debug.rs +++ b/core/src/runtime/provenance/differentiable/diff_top_k_proofs_debug.rs @@ -112,11 +112,9 @@ impl Provenance for DiffTopKProofsDebugProvenan .map(|clause| { clause .iter() - .map(|literal| { - match literal { - Literal::Pos(id) => (true, *id), - Literal::Neg(id) => (false, *id), - } + .map(|literal| match literal { + Literal::Pos(id) => (true, *id), + Literal::Neg(id) => (false, *id), }) .collect() }) diff --git a/core/tests/compiler/adt.rs b/core/tests/compiler/adt.rs index 024a6f9..c2bb1ba 100644 --- a/core/tests/compiler/adt.rs +++ b/core/tests/compiler/adt.rs @@ -170,3 +170,14 @@ fn adt_add_dynamic_entity_2() { vec![(1i32, 10i32), (2, 3)], ); } + +#[test] +fn parse_destructor_1() { + expect_compile( + r#" + type Expr = Const(i32) | Add(Expr, Expr) + rel eval(Const(i), i) + rel eval(Add(e1, e2), i1 + i2) = eval(e1, i1) and eval(e2, i2) + "#, + ) +} diff --git a/core/tests/integrate/adt.rs b/core/tests/integrate/adt.rs index 30cfe04..0001050 100644 --- a/core/tests/integrate/adt.rs +++ b/core/tests/integrate/adt.rs @@ -18,6 +18,20 @@ fn adt_arith_formula_eval_1() { ) } +#[test] +fn adt_arith_formula_eval_w_destruct_1() { + expect_interpret_result( + r#" + type Expr = Const(i32) | Add(Expr, Expr) + const MY_EXPR = Add(Const(5), Add(Const(3), Const(6))) + rel eval(Const(y), y) + rel eval(Add(x1, x2), y1 + y2) = eval(x1, y1) and eval(x2, y2) + rel result(y) = eval(MY_EXPR, y) + "#, + ("result", vec![(14i32,)]), + ) +} + #[test] fn adt_list_1() { expect_interpret_result( @@ -35,6 +49,40 @@ fn adt_list_1() { ) } +#[test] +fn adt_list_w_destruct_1() { + expect_interpret_result( + r#" + type List = Nil() | Cons(i32, List) + + const MY_LIST = Cons(1, Cons(2, Cons(3, Nil()))) + + rel list_sum(Nil(), 0) + rel list_sum(Cons(hd, tl), hd + s) = list_sum(tl, s) + + rel result(y) = list_sum(MY_LIST, y) + "#, + ("result", vec![(6i32,)]), + ) +} + +#[test] +fn destruct_in_body_atom_1() { + expect_interpret_result( + r#" + type List = Nil() | Cons(i32, List) + + const EMPTY = Nil() + + rel print(Nil(), "[]") + rel print(Cons(x, Nil()), $format("[{}]", x)) + + rel 
empty_array_string(x) = print(Nil(), x) + "#, + ("empty_array_string", vec![("[]".to_string(),)]), + ) +} + #[test] fn adt_binary_tree_1() { expect_interpret_result( @@ -53,6 +101,23 @@ fn adt_binary_tree_1() { ) } +#[test] +fn adt_binary_tree_w_destruct_1() { + expect_interpret_result( + r#" + type Tree = Nil() | Node(i32, Tree, Tree) + + rel tree_depth(Nil(), 0) + rel tree_depth(Node(_, lt, rt), $max(ld, rd) + 1) = tree_depth(lt, ld) and tree_depth(rt, rd) + + const MY_TREE = Node(1, Node(2, Nil(), Node(3, Nil(), Nil())), Node(4, Nil(), Nil())) + + rel result(y) = tree_depth(MY_TREE, y) + "#, + ("result", vec![(3i32,)]), + ) +} + const RE_PROGRAM: &'static str = r#" type RE = Char(char) | Nil() | Con(RE, RE) | Or(RE, RE) | Star(RE) rel match(r, s, e) = case r is Star(r1), match(r1, s, m), match(r, m, e) "#; +const RE_PROGRAM_W_DESTRUCT: &'static str = r#" + type RE = Char(char) | Nil() | Con(RE, RE) | Or(RE, RE) | Star(RE) + + rel match(Nil(), i, i) = string_chars(s, i, _), input_string(s) + rel match(Char(c), i, i + 1) = input_string(s), string_chars(s, i, c) + rel match(Con(r1, r2), s, e) = match(r1, s, m), match(r2, m, e) + rel match(Or(r1, r2), s, e) = match(r1, s, e) + rel match(Or(r1, r2), s, e) = match(r2, s, e) + rel match(Star(r1), i, i) = string_chars(s, i, _), input_string(s) + rel match(Star(r1), s, e) = match(r1, s, e) + rel match(Star(r1), s, e) = match(r1, s, m), match(Star(r1), m, e) +"#; + #[test] fn adt_regex_1() { expect_interpret_result( @@ -96,15 +174,45 @@ fn adt_regex_2() { ) } +#[test] +fn adt_regex_w_destruct_1() { + expect_interpret_result( + &format!( + "{RE_PROGRAM_W_DESTRUCT}\n{}", + r#" + const MY_RE = Con(Char('a'), Char('b')) + rel input_string("ab") + rel result() = match(MY_RE, 0, 2) + "#, + ), + ("result", vec![()]), + ) +} + +#[test] +fn adt_regex_w_destruct_2() { + expect_interpret_result( + &format!( + "{RE_PROGRAM_W_DESTRUCT}\n{}", + r#" + const MY_RE = Con(Star(Char('a')), Char('b')) + rel input_string("aaaaaaaab") + rel result() = match(MY_RE, 0, 9) + "#, + ), + ("result", vec![()]), + ) +} + const CLEVR_PROGRAM: &'static str = r#" type Color = RED | GREEN | BLUE type Size = LARGE | SMALL type SpatialRela = LEFT | RIGHT type Expr = Scene() | Color(Color, Expr) | Size(Size, Expr) | Rela(SpatialRela, Expr, Expr) | RelaInv(SpatialRela, Expr, Expr) - rel eval(e, output_obj) = case e is Scene(), input_obj_ids(output_obj) - rel eval(e, output_obj) = case e is Color(c, e1), eval(e1, output_obj), input_obj_color(output_obj, c) - rel eval(e, output_obj) = case e is Size(s, e1), eval(e1, output_obj), input_obj_size(output_obj, s) + rel eval(e, o) = case e is Scene(), input_obj_ids(o) + rel eval(e, o) = case e is Color(c, e1), eval(e1, o), input_obj_color(o, c) + rel eval(e, o) = case e is Size(s, e1), eval(e1, o), input_obj_size(o, s) rel eval(e, o2) = case e is Rela(r, e1, e2), eval(e1, o1), eval(e2, o2), input_obj_rela(r, o1, o2) rel eval(e, o1) = case e is RelaInv(r, e1, e2), eval(e1, o1), eval(e2, o2), input_obj_rela(r, o1, o2) "#; @@ -136,9 +244,9 @@ const EQSAT_1_PROGRAM: &'static str = r#" | Add(Expr, Expr) // A relation `to_string` for visualizing - rel to_string(p, i as String) = case p is Const(i) - rel to_string(p, v) = case p is Var(v) - rel to_string(p, $format("({} + {})", s1, s2)) = case p is Add(p1, p2) and to_string(p1, s1) and to_string(p2, s2) + rel to_string(Const(i), i as String) + rel to_string(Var(v), v) + rel to_string(Add(p1, p2), $format("({} + {})", s1, s2)) = to_string(p1, s1) and
to_string(p2, s2) // Relation for expression rel expr(p) = case p is Const(_) or case p is Var(_) or case p is Add(_, _) @@ -152,9 +260,9 @@ const EQSAT_1_PROGRAM: &'static str = r#" rel equivalent(p, p1) = case p is Add(p1, Const(0)) // Definition of weight - rel weight(p, 1) = case p is Const(_) - rel weight(p, 1) = case p is Var(_) - rel weight(p, w1 + w2 + 1) = case p is Add(p1, p2) and weight(p1, w1) and weight(p2, w2) + rel weight(Const(_), 1) + rel weight(Var(_), 1) + rel weight(Add(p1, p2), w1 + w2 + 1) = weight(p1, w1) and weight(p2, w2) // Compute equivalent programs rel equiv_programs(sp) = input_program(p) and equivalent(p, sp) diff --git a/core/tests/integrate/mod.rs b/core/tests/integrate/mod.rs index 8eae8f8..2ed930d 100644 --- a/core/tests/integrate/mod.rs +++ b/core/tests/integrate/mod.rs @@ -12,4 +12,5 @@ mod io; mod iter; mod prob; mod sampling; +mod tag; mod time; diff --git a/core/tests/integrate/tag.rs b/core/tests/integrate/tag.rs new file mode 100644 index 0000000..71551ea --- /dev/null +++ b/core/tests/integrate/tag.rs @@ -0,0 +1,134 @@ +use scallop_core::runtime::provenance; +use scallop_core::testing::*; + +#[test] +fn test_rule_constant_tag_simple_1() { + use provenance::add_mult_prob::*; + let prov_ctx = AddMultProbProvenance::default(); + expect_interpret_result_with_tag( + r#" + rel my_num(5) + rel 0.5::fall_off(n) = my_num(n) + "#, + prov_ctx, + ("fall_off", vec![(0.5, (5i32,))]), + AddMultProbProvenance::soft_cmp, + ) +} + +#[test] +fn test_rule_constant_integer_tag_simple_1() { + use provenance::add_mult_prob::*; + let prov_ctx = AddMultProbProvenance::default(); + expect_interpret_result_with_tag( + r#" + rel my_num(5) + rel 1::fall_off(n) = my_num(n) + "#, + prov_ctx, + ("fall_off", vec![(1.0, (5i32,))]), + AddMultProbProvenance::soft_cmp, + ) +} + +#[test] +fn test_multiple_rule_constant_tag_simple_1() { + use provenance::add_mult_prob::*; + let prov_ctx = AddMultProbProvenance::default(); + expect_interpret_result_with_tag( + r#" + rel my_num(5) + rel 0.5::fall_off(n / 2) = my_num(n) + rel 0.2::fall_off(n + 1) = my_num(n) + "#, + prov_ctx, + ("fall_off", vec![(0.5, (2i32,)), (0.2, (6i32,))]), + AddMultProbProvenance::soft_cmp, + ) +} + +#[test] +fn test_expr_tag_direct_propagate_1() { + use provenance::add_mult_prob::*; + let prov_ctx = AddMultProbProvenance::default(); + expect_interpret_result_with_tag( + r#" + rel my_prob(0.5) + rel p::fall_off() = my_prob(p) + "#, + prov_ctx, + ("fall_off", vec![(0.5, ())]), + AddMultProbProvenance::soft_cmp, + ) +} + +#[test] +fn test_expr_tag_simple_1() { + use provenance::add_mult_prob::*; + let prov_ctx = AddMultProbProvenance::default(); + expect_interpret_result_with_tag( + r#" + rel my_num(5.0) + rel 1.0/n::fall_off() = my_num(n) + "#, + prov_ctx, + ("fall_off", vec![(0.2, ())]), + AddMultProbProvenance::soft_cmp, + ) +} + +#[test] +fn test_expr_tag_simple_2() { + use provenance::min_max_prob::*; + let prov_ctx = MinMaxProbProvenance::default(); + expect_interpret_result_with_tag( + r#" + rel edge = {(0, 1), (1, 2)} + rel path(x, y, 1.0) = edge(x, y) + rel path(x, z, l + 1.0) = path(x, y, l) and edge(y, z) + rel (1.0 / l)::path_prob(x, y) = path(x, y, l) + "#, + prov_ctx, + ("path_prob", vec![(1.0, (0, 1)), (1.0, (1, 2)), (0.5, (0, 2))]), + MinMaxProbProvenance::cmp, + ) +} + +#[test] +fn test_expr_tag_type_error() { + // `1.0 / n` will fail because `n` is of type `i32` + expect_compile_failure( + r#" + type my_num(i32) + rel my_num(5) + rel (1.0 / n)::fall_off() = my_num(n) + "#, + |err| 
err.to_string().contains("type"), + ) +} + +#[test] +fn test_expr_tag_unbound_error() { + expect_compile_failure( + r#" + rel my_num() + rel (1.0 / n)::fall_off() = my_num() + "#, + |err| err.to_string().contains("bound"), + ) +} + +#[test] +fn test_expr_tag_cannot_be_datetime() { + expect_compile_failure( + r#" + rel my_time(t"2024-01-01") + rel t::fall_off() = my_time(t) + "#, + |err| { + err + .to_string() + .contains("A value of type `DateTime` cannot be casted into a dynamic tag") + }, + ) +} diff --git a/core/tests/tests.rs b/core/tests/tests.rs index 6378d05..38f4713 100644 --- a/core/tests/tests.rs +++ b/core/tests/tests.rs @@ -1,4 +1,4 @@ mod compiler; mod integrate; mod runtime; -mod utils; +mod unit; diff --git a/core/tests/utils/mod.rs b/core/tests/unit/mod.rs similarity index 100% rename from core/tests/utils/mod.rs rename to core/tests/unit/mod.rs diff --git a/core/tests/utils/value.rs b/core/tests/unit/value.rs similarity index 100% rename from core/tests/utils/value.rs rename to core/tests/unit/value.rs diff --git a/etc/codegen/Cargo.toml b/etc/codegen/Cargo.toml index 66fc2ea..6a333d5 100644 --- a/etc/codegen/Cargo.toml +++ b/etc/codegen/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scallop-codegen" -version = "0.2.2" +version = "0.2.4" authors = ["Ziyang Li "] edition = "2018" diff --git a/etc/scallop-cli/setup.cfg b/etc/scallop-cli/setup.cfg index 27facf3..0c4d0ed 100644 --- a/etc/scallop-cli/setup.cfg +++ b/etc/scallop-cli/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = scallop -version = 0.2.2 +version = 0.2.4 author = Ziyang Li author_email = liby99@seas.upenn.edu description = Scallop CLI diff --git a/etc/scallop-wasm/Cargo.toml b/etc/scallop-wasm/Cargo.toml index ef77a34..021473a 100644 --- a/etc/scallop-wasm/Cargo.toml +++ b/etc/scallop-wasm/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scallop-wasm" -version = "0.2.2" +version = "0.2.4" authors = ["Ziyang Li"] edition = "2018" diff --git a/etc/scallopy-ext/setup.cfg b/etc/scallopy-ext/setup.cfg index d01a6a4..2625fb5 100644 --- a/etc/scallopy-ext/setup.cfg +++ b/etc/scallopy-ext/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = scallopy_ext -version = 0.2.2 +version = 0.2.4 author = Ziyang Li author_email = liby99@seas.upenn.edu description = Scallopy Extension diff --git a/etc/scallopy/Cargo.toml b/etc/scallopy/Cargo.toml index d32a2d9..41afa98 100644 --- a/etc/scallopy/Cargo.toml +++ b/etc/scallopy/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scallopy" -version = "0.2.2" +version = "0.2.4" edition = "2018" [lib] diff --git a/etc/scallopy/examples/sparse_gradient.py b/etc/scallopy/examples/sparse_gradient.py new file mode 100644 index 0000000..361f1b6 --- /dev/null +++ b/etc/scallopy/examples/sparse_gradient.py @@ -0,0 +1,34 @@ +import torch +import scallopy + +ctx = scallopy.ScallopContext(provenance="difftopkproofs") +ctx.add_relation("digit_1", int, range(10)) +ctx.add_relation("digit_2", int, range(10)) +ctx.add_rule("sum_2(a + b) = digit_1(a) and digit_2(b)") +ctx.add_rule("mult_2(a * b) = digit_1(a) and digit_2(b)") + +loss_fn = torch.nn.BCELoss() +forward = ctx.forward_function( + "sum_2", list(range(19)), sparse_jacobian=True) + +# Construct the digit +digit_1_base = torch.randn((16, 10), requires_grad=True) +digit_1 = torch.softmax(digit_1_base, dim=1) +digit_2_base = torch.randn((16, 10), requires_grad=True) +digit_2 = torch.softmax(digit_2_base, dim=1) + +# Call scallop and obtain loss +sum_2 = forward(digit_1=digit_1, digit_2=digit_2) +gt = torch.tensor([[1.0] + [0.0] * 18] * 16) +l = loss_fn(sum_2, gt) + +# 
Ensure that there is no gradient +assert digit_1_base.grad == None +assert digit_2_base.grad == None + +# Perform backward +l.backward() + +# Ensure that there is some gradient +assert any(p != 0.0 for distr in digit_1_base.grad for p in distr) +assert any(p != 0.0 for distr in digit_2_base.grad for p in distr) diff --git a/etc/scallopy/scallopy/context.py b/etc/scallopy/scallopy/context.py index 5a242d2..351a5aa 100644 --- a/etc/scallopy/scallopy/context.py +++ b/etc/scallopy/scallopy/context.py @@ -330,6 +330,7 @@ def forward_function( jit: bool = False, jit_name: str = "", recompile: bool = False, + sparse_jacobian: bool = False, ) -> Callable: """ Generate a forward function for PyTorch module. @@ -357,7 +358,18 @@ def forward_function( else: raise Exception("`forward_function` can only be called on context with differentiable provenance") # Forward function - return InternalScallopForwardFunction(self, output, output_mapping, output_mappings, dispatch, debug_provenance, retain_graph, jit, jit_name, recompile) + return InternalScallopForwardFunction( + self, + output, + output_mapping, + output_mappings, + dispatch, + debug_provenance, + retain_graph, + jit, + jit_name, + recompile, + sparse_jacobian,) def _refresh_training_eval_state(self, training): if self._train_k is not None or self._test_k is not None: diff --git a/etc/scallopy/scallopy/forward.py b/etc/scallopy/scallopy/forward.py index 7344f28..2dec918 100644 --- a/etc/scallopy/scallopy/forward.py +++ b/etc/scallopy/scallopy/forward.py @@ -39,6 +39,7 @@ def __init__( jit_name: str = "", jit_recompile: bool = False, dispatch: str = "parallel", + sparse_jacobian: bool = False, monitors: List[str] = [], ): super(ScallopForwardFunction, self).__init__() @@ -111,6 +112,7 @@ def __init__( jit=jit, jit_name=jit_name, recompile=jit_recompile, + sparse_jacobian=sparse_jacobian, ) def __call__(self, *pos_args, **kw_args): @@ -137,6 +139,7 @@ def __init__( jit: bool = False, jit_name: str = "", recompile: bool = False, + sparse_jacobian: bool = False, ): super(InternalScallopForwardFunction, self).__init__() @@ -148,6 +151,7 @@ def __init__( self.jit = jit self.jit_name = jit_name self.recompile = recompile + self.sparse_jacobian = sparse_jacobian self.fn_counter = self.FORWARD_FN_COUNTER # Preprocess the dispatch @@ -278,15 +282,18 @@ def __call__( self, disjunctions: Dict[str, List[List[List[int]]]] = {}, output_relations: Optional[List[Union[str, List[str]]]] = None, + output_mappings: Union[List, Dict[str, List]] = None, **input_facts: Dict[str, Union[torch_importer.Tensor, List]], ) -> Union[torch_importer.Tensor, Tuple[List[Tuple], torch_importer.Tensor]]: """ Invoke the forward function with the given facts - The facts and disjunctions need to be batched + The `facts` and `disjunctions` need to be batched; we assume the batch size + to be B - output_relations can be one of the following format - - None, if outputs are provided + `output_relations` can be one of the following formats + - None, if outputs are provided when constructing the ForwardFunction + - a single relation name, if all data-points in the batch share the same output relation + - [rela 1, rela 2, ..., rela B], if we want each data-point to produce different relations """ if self.jit: return self._call_with_static_ctx( @@ -296,6 +303,7 @@ return self._call_with_dynamic_ctx( disjunctions=disjunctions, output_relations=output_relations, + output_mappings=output_mappings, input_facts=input_facts) def _call_with_static_ctx( @@ -313,18 +321,18 @@ # Execute static scallop program for each task from python input_tags,
output_results = [], [] for task_id in range(batch_size): - (task_input_tags, task_output_results) = self._run_single_static(task_id, all_inputs) + (task_input_tags, task_output_results) = self._run_single_static(task_id, all_inputs, self.output_mappings) input_tags.append(task_input_tags) output_results.append(task_output_results) elif self.dispatch == "parallel": # Directly dispatch all the inputs to rust, and execute with parallel - (input_tags, output_results) = self._run_batch_static(batch_size, all_inputs, parallel=True) + (input_tags, output_results) = self._run_batch_static(batch_size, all_inputs, self.output_mappings, parallel=True) else: # Directly dispatch all the inputs to rust - (input_tags, output_results) = self._run_batch_static(batch_size, all_inputs, parallel=False) + (input_tags, output_results) = self._run_batch_static(batch_size, all_inputs, self.output_mappings, parallel=False) # Process the output - return self._process_output(batch_size, input_tags, output_results) + return self._process_output(batch_size, input_tags, output_results, self.output_mappings) def _get_k(self): if self.ctx._train_k is not None or self.ctx._test_k is not None: @@ -336,8 +344,9 @@ def _call_with_dynamic_ctx( self, disjunctions: Optional[Dict], - output_relations: Optional[List[Union[str, List[str]]]], + output_relations: Optional[Union[str, List[Union[str, List[str]]]]], input_facts: Dict[str, Union[torch_importer.Tensor, List]], + output_mappings: Optional[Union[List, Dict[str, List]]] = None, ): self.ctx._refresh_training_eval_state(self.training) # Set train/eval self.ctx._internal.set_non_incremental() @@ -350,12 +359,62 @@ all_inputs = self._process_all_input_facts(batch_size, input_facts, disjunctions) # Process the output into a list of output relations - if output_relations is None: output_relations = [self.outputs] * batch_size - else: output_relations = [rs if type(rs) == list else [rs] for rs in output_relations] - if any([len(rs) == 0 for rs in output_relations]): + if output_relations is None: + current_output_relations = [self.outputs] * batch_size + elif type(output_relations) == str: + current_output_relations = [[output_relations]] * batch_size + elif type(output_relations) == list: + current_output_relations = [rs if type(rs) == list else [rs] for rs in output_relations] + + # Making sure that output relations are well-formed + if any([len(rs) == 0 for rs in current_output_relations]): raise Exception(f"There exists a data-point with 0 output relations") - if len(output_relations) != batch_size: - raise Exception(f"Number of output relations ({len(output_relations)}) does not match the batch size ({batch_size})") + if len(current_output_relations) != batch_size: + raise Exception(f"Number of output relations ({len(current_output_relations)}) does not match the batch size ({batch_size})") + + # Process the output_mappings + # the set of output relations for the first data-point in the batch + first_output_relations = current_output_relations[0] + if output_mappings is not None: + # First initialize the `current_output_mappings` to be a copy of the existing output mapping + # Later on, if a temporary output mapping specific to this batch is provided, we will + # use it to overwrite the corresponding entries in this copy + current_output_mappings = {k: v for (k, v) in self.output_mappings.items()} + + # Compute some statistics of output relations to make sure that the input is
well-formed + set_of_output_relations = set([r for rs in current_output_relations for r in rs]) + num_output_relations = len(set_of_output_relations) + + # Check if there is only a single output relation + if num_output_relations == 1: + output_relation = first_output_relations[0] + if type(output_mappings) == list: + # the output mappings could be a single list, which would by default be the output mapping + # for the only output relation that is specified + assert len(output_mappings) > 0, f"Expect the `output_mappings` to be non-empty" + current_output_mappings[output_relation] = self._process_one_output_mapping(output_mappings) + elif type(output_mappings) == dict: + # the output_mappings struct could be a dictionary. In this case it could either be empty or contain + # exactly one output mapping, which is for the only output relation. + for (rela_name, rela_output_mapping) in output_mappings.items(): + assert rela_name in set_of_output_relations, f"The provided output mapping {rela_name} is not among the set of output relations" + assert len(rela_output_mapping) > 0, f"Expect the output mapping for the `{rela_name}` relation to be non-empty" + current_output_mappings[rela_name] = self._process_one_output_mapping(rela_output_mapping) + else: + assert False, f"Unknown format of `output_mappings` when calling Scallop forward. Expecting list or dict." + else: + # we make sure that the output relations are the same for every remaining data-point in the batch + is_uniform_output = all([rs == first_output_relations for rs in current_output_relations[1:]]) + assert is_uniform_output, f"We expect the output relations to be the same across all data-points in a batch" + + # we make sure that output mappings are provided for all output relations + assert type(output_mappings) == dict, f"Expect the `output_mappings` variable to be a dict for a batch with multiple output relations" + for (rela_name, rela_output_mapping) in output_mappings.items(): + assert rela_name in set_of_output_relations, f"The provided output mapping {rela_name} is not among the set of output relations" + assert len(rela_output_mapping) > 0, f"Expect the output mapping for the `{rela_name}` relation to be non-empty" + current_output_mappings[rela_name] = self._process_one_output_mapping(rela_output_mapping) + else: + current_output_mappings = self.output_mappings # Check task dispatcher if self.dispatch == "single": @@ -363,20 +422,20 @@ input_tags = [] output_results = [] for task_id in range(batch_size): - (task_input_tags, task_output_results) = self._run_single(task_id, all_inputs, output_relations[task_id]) + (task_input_tags, task_output_results) = self._run_single(task_id, all_inputs, current_output_relations[task_id], current_output_mappings) input_tags.append(task_input_tags) output_results.append(task_output_results) elif self.dispatch == "serial": # Directly dispatch all the inputs to rust - (input_tags, output_results) = self._run_batch(batch_size, all_inputs, output_relations, parallel=False) + (input_tags, output_results) = self._run_batch(batch_size, all_inputs, current_output_relations, current_output_mappings, parallel=False) elif self.dispatch == "parallel": # Dispatch all the inputs to rust and call rayon as the parallelism backend - (input_tags, output_results) = self._run_batch(batch_size, all_inputs, output_relations, parallel=True) + (input_tags, output_results) = self._run_batch(batch_size, all_inputs, current_output_relations, current_output_mappings, parallel=True) else: raise
Exception(f"Unknown dispatch type `{self.dispatch}`") # Process the output - return self._process_output(batch_size, input_tags, output_results) + return self._process_output(batch_size, input_tags, output_results, current_output_mappings) def _compute_and_check_batch_size(self, inputs: Dict[str, Union[torch_importer.Tensor, List]]) -> int: """ @@ -458,7 +517,7 @@ def _process_input_facts(self, rela, rela_facts, disjunctions) -> List[Tuple]: # Add the facts return facts - def _run_single(self, task_id, all_inputs, output_relations): + def _run_single(self, task_id, all_inputs, output_relations, output_mappings): """ Run a single task (identified by `task_id`) @@ -490,12 +549,12 @@ def _run_single(self, task_id, all_inputs, output_relations): else: cs = [temp_ctx._internal.relation(r) for r in output_relations] # Process the collection to get the output results - output_results = [self._process_single_output(output_relations[i], c) for (i, c) in enumerate(cs)] + output_results = [self._process_single_output(output_relations[i], c, output_mappings) for (i, c) in enumerate(cs)] # Return return (input_tags, output_results) - def _run_batch(self, batch_size, all_inputs, output_relations, parallel: bool): + def _run_batch(self, batch_size, all_inputs, output_relations, output_mappings, parallel: bool): """ Run a batch of tasks """ @@ -503,10 +562,10 @@ def _run_batch(self, batch_size, all_inputs, output_relations, parallel: bool): input_tags, output_results = [], [] for task_id in range(batch_size): input_tags.append(results[task_id][0].input_tags()) - output_results.append([self._process_single_output(output_relations[task_id][i], c) for (i, c) in enumerate(results[task_id])]) + output_results.append([self._process_single_output(output_relations[task_id][i], c, output_mappings) for (i, c) in enumerate(results[task_id])]) return (input_tags, output_results) - def _run_single_static(self, task_id, all_inputs): + def _run_single_static(self, task_id, all_inputs, output_mappings): """ Run a batch of tasks using """ @@ -527,12 +586,12 @@ def _run_single_static(self, task_id, all_inputs): # Get the collection for the target output collections = [temp_ctx.relation(rel_name) for rel_name in self.outputs] - output_results = [self._process_single_output(self.outputs[i], c) for (i, c) in enumerate(collections)] + output_results = [self._process_single_output(self.outputs[i], c, output_mappings) for (i, c) in enumerate(collections)] # Return return (input_tags, output_results) - def _run_batch_static(self, batch_size, all_inputs, parallel): + def _run_batch_static(self, batch_size, all_inputs, output_mappings, parallel): """ Run a batch of tasks using statically compiled module """ @@ -547,29 +606,44 @@ def _run_batch_static(self, batch_size, all_inputs, parallel): input_tags, output_results = [], [] for task_id in range(batch_size): input_tags.append(result[task_id][0]) - output_results.append([self._process_single_output(self.outputs[i], c) for (i, c) in enumerate(result[task_id][1])]) + output_results.append([self._process_single_output(self.outputs[i], c, output_mappings) for (i, c) in enumerate(result[task_id][1])]) # Return return (input_tags, output_results) - def _process_single_output(self, relation_name, internal_collection): + def _process_single_output(self, relation_name, internal_collection, output_mappings): + """ + Given a raw output collection from internal Scallop module, process the output with a given output mapping + """ internal_result_dict = { tup: tag for (tag, tup) in 
internal_collection } - if relation_name in self.output_mappings and self.output_mappings[relation_name] is not None: - return [internal_result_dict[t] if t in internal_result_dict else None for t in self.output_mappings[relation_name][1]] + if relation_name in output_mappings and output_mappings[relation_name] is not None: + return [internal_result_dict[t] if t in internal_result_dict else None for t in output_mappings[relation_name][1]] else: return internal_result_dict - def _process_output(self, batch_size, input_tags, output_results): - if len(self.outputs) == 1: - return self._process_one_output_wrapper(0, self.outputs[0], batch_size, input_tags, output_results) - elif len(self.outputs) > 1: - return {rel_name: self._process_one_output_wrapper(i, rel_name, batch_size, input_tags, output_results) for (i, rel_name) in enumerate(self.outputs)} - else: - [] + def _process_output(self, batch_size, input_tags, output_results, output_mappings): + """ + Given all the outputs from the internal Scallop module, process the outputs + """ - # First make sure that the outputs are well-defined + if len(self.outputs) == 0: + outputs = list(output_mappings.keys()) + else: + outputs = self.outputs + assert type(outputs) == list, f"Expect outputs to be a list" + assert len(outputs) > 0, f"Expect non-empty output from a forward function; however, the current `outputs` is empty" + assert all(type(output_rel) == str for output_rel in outputs), f"Expect each element of the output array to be a string (relation name)" + + # Then, depending on the number of output relations, process the output for each output relation + if len(outputs) == 1: + return self._process_one_output_wrapper(0, outputs[0], batch_size, input_tags, output_results, output_mappings) + elif len(outputs) > 1: + return {rel_name: self._process_one_output_wrapper(i, rel_name, batch_size, input_tags, output_results, output_mappings) for (i, rel_name) in enumerate(outputs)} + + def _process_one_output_wrapper(self, rel_index, rel_name, batch_size, input_tags, output_results, output_mappings): + if output_mappings[rel_name] is not None: + (single_element, output_mapping) = output_mappings[rel_name] return self._process_one_output(batch_size, input_tags, [r[rel_index] for r in output_results], single_element, output_mapping) else: return self._process_one_output(batch_size, input_tags, [r[rel_index] for r in output_results], False, None) @@ -743,20 +817,50 @@ def pad_input(l): # mat_i: Input matrix mat_i = torch_importer.torch.stack([torch_importer.torch.stack(pad_input(l)) for l in input_tags]) - # mat_w: Weight matrix - mat_w = self._torch_tensor_apply((torch_importer.torch.zeros(batch_size, num_outputs, num_inputs))) - for (batch_id, task_result) in enumerate(output_batch): # batch_size - for (output_id, output_tag) in enumerate(task_result): # output_size - if output_tag is not None: - deriv = output_tag[1] # The 1-st element of the differentiable result is always the derivative - for (input_id, weight) in deriv: - mat_w[batch_id, output_id, input_id] = weight + # Check whether we want to use a sparse jacobian + if self.sparse_jacobian: + # Populate the indices and values to later construct the sparse matrix + indices, values = [], [] + for batch_id, task_result in enumerate(output_batch): # batch_size + for output_id, output_tag in
enumerate(task_result): # output_size + if output_tag is not None: + deriv = output_tag[1] # The 1st element of the differentiable result is always the derivative + for (input_id, weight) in deriv: + indices.append([batch_id, output_id, input_id]) + values.append(weight) + + # Convert indices and values into tensors + indices = torch_importer.torch.tensor(indices).t() # Transpose to get the correct shape for sparse_coo_tensor + values = torch_importer.torch.tensor(values) + + # Create the sparse tensor + mat_w = self._torch_tensor_apply( # making sure that the tensor is on the intended device + torch_importer.torch.sparse_coo_tensor(indices, values, (batch_size, num_outputs, num_inputs))) + else: + # mat_w: Weight matrix + mat_w = self._torch_tensor_apply(torch_importer.torch.zeros(batch_size, num_outputs, num_inputs)) + for (batch_id, task_result) in enumerate(output_batch): # batch_size + for (output_id, output_tag) in enumerate(task_result): # output_size + if output_tag is not None: + deriv = output_tag[1] # The 1st element of the differentiable result is always the derivative + for (input_id, weight) in deriv: + mat_w[batch_id, output_id, input_id] = weight # backward hook retain_graph = self.retain_graph def hook(grad): if mat_i.requires_grad: - mat_f = torch_importer.torch.einsum("ikj,ik->ij", mat_w, grad) + if self.sparse_jacobian: + # An operation equivalent to the einsum below, but under the sparse setting + grad_expanded = grad.unsqueeze(-1) + mult = mat_w * grad_expanded + mat_f_sparse = torch_importer.torch.sparse.sum(mult, dim=1) + mat_f = mat_f_sparse.to_dense() + else: + # Chain-rule: multiply the jacobian (mat_w) with the gradient that is back-propagated to the Scallop module + mat_f = torch_importer.torch.einsum("ikj,ik->ij", mat_w, grad) + + # Apply backward; potentially retaining the graph + mat_i.backward(mat_f, retain_graph=retain_graph) return hook diff --git a/etc/scallopy/scallopy/output_mapping.py b/etc/scallopy/scallopy/output_mapping.py new file mode 100644 index 0000000..8f63e2c --- /dev/null +++ b/etc/scallopy/scallopy/output_mapping.py @@ -0,0 +1,35 @@ +from .utils import _mapping_tuple, _map_entity_tuple_to_str_tuple + +class OutputMapping: + def __init__(self, output_mapping): + if type(output_mapping) == list: + self.is_none = False + self.singleton = False + self.mapping = {0: [_mapping_tuple(t) for t in output_mapping]} + self.shape = (len(self.mapping[0]),) + elif type(output_mapping) == tuple: + self.is_none = False + self.singleton = True + self.mapping = {0: [_mapping_tuple(output_mapping)]} + self.shape = (1,) + elif type(output_mapping) == range: + self.is_none = False + self.singleton = False + self.mapping = {0: [_mapping_tuple(t) for t in list(output_mapping)]} + self.shape = (len(self.mapping[0]),) + elif type(output_mapping) == dict: + num_dim = len(output_mapping) + for i in range(num_dim): + assert i in output_mapping, f"Non-existent dimension {i} in output mapping" + self.mapping = {key: [_mapping_tuple(t) for t in tuples] for (key, tuples) in output_mapping.items()} + self.is_none = False + self.singleton = False + self.shape = tuple(len(self.mapping[i]) for i in range(num_dim)) + elif output_mapping is None: + self.is_none = True + else: + raise Exception(f"Unknown output mapping type `{type(output_mapping)}`") + + def dim(self): + assert not self.is_none, "Cannot obtain dimension from a `None` output mapping" + return len(self.shape) diff --git a/etc/scallopy/src/collection.rs b/etc/scallopy/src/collection.rs index ea1dfc1..bdddff4 100644 --- a/etc/scallopy/src/collection.rs +++ b/etc/scallopy/src/collection.rs @@ 
diff --git a/etc/scallopy/scallopy/output_mapping.py b/etc/scallopy/scallopy/output_mapping.py
new file mode 100644
index 0000000..8f63e2c
--- /dev/null
+++ b/etc/scallopy/scallopy/output_mapping.py
@@ -0,0 +1,35 @@
+from .utils import _mapping_tuple, _map_entity_tuple_to_str_tuple
+
+class OutputMapping:
+  def __init__(self, output_mapping):
+    if type(output_mapping) == list:
+      self.is_none = False
+      self.singleton = False
+      self.mapping = {0: [_mapping_tuple(t) for t in output_mapping]}
+      self.shape = (len(self.mapping[0]),)
+    elif type(output_mapping) == tuple:
+      self.is_none = False
+      self.singleton = True
+      self.mapping = {0: [_mapping_tuple(output_mapping)]}
+      self.shape = (1,)
+    elif type(output_mapping) == range:
+      self.is_none = False
+      self.singleton = False
+      self.mapping = {0: [_mapping_tuple(t) for t in list(output_mapping)]}
+      self.shape = (len(self.mapping[0]),)
+    elif type(output_mapping) == dict:
+      num_dim = len(output_mapping)
+      for i in range(num_dim):
+        assert i in output_mapping, f"Non-existent dimension {i} in output mapping"
+      self.mapping = {key: [_mapping_tuple(t) for t in tuples] for (key, tuples) in output_mapping.items()}
+      self.is_none = False
+      self.singleton = False
+      self.shape = tuple(len(self.mapping[i]) for i in range(num_dim))
+    elif output_mapping is None:
+      self.is_none = True
+    else:
+      raise Exception(f"Unknown output mapping type `{type(output_mapping)}`")
+
+  def dim(self):
+    assert not self.is_none, "Cannot obtain dimension from a `None` output mapping"
+    return len(self.shape)
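A rough sketch of how the four accepted mapping types relate to the resulting shape, assuming the class is importable as scallopy.output_mapping.OutputMapping (the exact re-export path is an assumption):

from scallopy.output_mapping import OutputMapping

# A list of tuples maps one dimension; the shape is its length
m1 = OutputMapping([(i,) for i in range(19)])
assert m1.shape == (19,) and m1.dim() == 1

# A single tuple denotes a singleton output
m2 = OutputMapping((3,))
assert m2.singleton and m2.shape == (1,)

# A range behaves like the corresponding list
m3 = OutputMapping(range(10))
assert m3.shape == (10,)

# None disables mapping; dim() is undefined in that case
m4 = OutputMapping(None)
assert m4.is_none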
diff --git a/etc/scallopy/src/collection.rs b/etc/scallopy/src/collection.rs
index ea1dfc1..bdddff4 100644
--- a/etc/scallopy/src/collection.rs
+++ b/etc/scallopy/src/collection.rs
@@ -196,6 +196,10 @@ impl Collection {
     self.collection.len()
   }
 
+  fn __len__(slf: PyRef<Self>) -> usize {
+    slf.collection.len()
+  }
+
   fn __iter__(slf: PyRef<Self>) -> CollectionIterator {
     CollectionIterator {
       env: slf.env.clone(),
diff --git a/etc/scallopy/src/external_tag.rs b/etc/scallopy/src/external_tag.rs
index 7d5e13e..046b511 100644
--- a/etc/scallopy/src/external_tag.rs
+++ b/etc/scallopy/src/external_tag.rs
@@ -46,7 +46,9 @@ impl ExtTagVec for Vec<ExtTag> {
   fn into_none_prepended_vec(self) -> Vec<Py<PyAny>> {
     let none: Option<Py<PyAny>> = None;
-    std::iter::once(Python::with_gil(|py| none.to_object(py))).chain(self.into_iter().map(|v| v.tag)).collect()
+    std::iter::once(Python::with_gil(|py| none.to_object(py)))
+      .chain(self.into_iter().map(|v| v.tag))
+      .collect()
   }
 }
diff --git a/etc/scallopy/src/foreign_predicate.rs b/etc/scallopy/src/foreign_predicate.rs
index 42b5a2a..ba64da8 100644
--- a/etc/scallopy/src/foreign_predicate.rs
+++ b/etc/scallopy/src/foreign_predicate.rs
@@ -153,7 +153,10 @@ impl ForeignPredicate for PythonForeignPredicate {
     // Turn the result back to Scallop values
     if let Some(result) = maybe_result {
       let output_tuple_type = self.output_tuple_type();
-      let elements: Vec<(&PyAny, &PyAny)> = result.extract(py).expect("Cannot extract into list of elements");
+      let elements: Vec<(&PyAny, &PyAny)> = result.extract(py).expect(&format!(
+        "Cannot extract into list of elements during evaluation of {}",
+        self.name
+      ));
       let internal: Vec<_> = elements
         .into_iter()
         .filter_map(|(py_tag, py_tup)| {
diff --git a/etc/scallopy/src/provenance.rs b/etc/scallopy/src/provenance.rs
index def1dbe..320cb0e 100644
--- a/etc/scallopy/src/provenance.rs
+++ b/etc/scallopy/src/provenance.rs
@@ -1,5 +1,6 @@
 use std::sync::Arc;
 
+use pyo3::exceptions::*;
 use pyo3::prelude::*;
 use pyo3::types::*;
 
@@ -342,16 +343,22 @@ impl PythonProvenance for diff_top_bottom_k_clauses::DiffTopBottomKClausesProvenance {
 
 impl PythonProvenance for diff_top_k_proofs_debug::DiffTopKProofsDebugProvenance {
   fn process_py_tag(tag: &PyAny) -> PyResult<Option<Self::InputTag>> {
-    let tag_disj_id: (&PyAny, usize, Option<usize>) = tag.extract()?;
-    if let Some(prob) = tag_disj_id.0.extract()? {
+    let tag_tuple: &PyTuple = tag.extract()?;
+    let prob: &PyAny = tag_tuple.get_item(0)?;
+    if let Some(prob) = prob.extract()? {
+      let tag_disj_id: (&PyAny, usize, Option<usize>) = tag.extract()?;
       let tag: ExtTag = tag_disj_id.0.into();
       let id: usize = tag_disj_id.1.into();
-      Ok(Some(Self::InputTag {
-        prob,
-        id,
-        external_tag: Some(tag),
-        exclusion: tag_disj_id.2,
-      }))
+      if id == 0 {
+        Err(PyErr::new::<PyValueError, _>("The input ID to the diff-top-k-proofs-debug provenance cannot be 0; consider starting the IDs from 1."))
+      } else {
+        Ok(Some(Self::InputTag {
+          prob,
+          id,
+          external_tag: Some(tag),
+          exclusion: tag_disj_id.2,
+        }))
+      }
     } else {
       Ok(None)
     }
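Per the change above, tags fed to the debug provenance are (probability, id, optional exclusion) triples, and an id of 0 is now rejected with a Python error. A hedged sketch of what this looks like from the Python side; the provenance string "difftopkproofsdebug", the relation, and the fact values are all assumptions made for illustration, not taken from the patch:

import scallopy

# Assumed provenance name for DiffTopKProofsDebugProvenance
ctx = scallopy.ScallopContext(provenance="difftopkproofsdebug")
ctx.add_relation("digit", int)

# Each tag is (probability, input ID, optional mutual-exclusion ID);
# input IDs must start from 1, since an ID of 0 now raises an error
ctx.add_facts("digit", [
  ((0.9, 1, None), (0,)),
  ((0.1, 2, None), (1,)),
])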
diff --git a/etc/scallopy/tests/forward.py b/etc/scallopy/tests/forward.py
index 21d3328..967cd04 100644
--- a/etc/scallopy/tests/forward.py
+++ b/etc/scallopy/tests/forward.py
@@ -78,6 +78,37 @@ def test_multi_result(self):
     self.assertEqual(sum_2.shape, (16, 19))
     self.assertEqual(mult_2.shape, (16, 100))
 
+  @unittest.expectedFailure
+  def test_forward_without_output_mapping(self):
+    forward = self.ctx.forward_function()
+    digit_1 = torch.softmax(torch.randn((16, 10)), dim=1)
+    digit_2 = torch.softmax(torch.randn((16, 10)), dim=1)
+    _ = forward(digit_1=digit_1, digit_2=digit_2)
+
+  def test_forward_with_output_mapping(self):
+    forward = self.ctx.forward_function("sum_2")
+    digit_1 = torch.softmax(torch.randn((16, 10)), dim=1)
+    digit_2 = torch.softmax(torch.randn((16, 10)), dim=1)
+    result_tensor = forward(
+      digit_1=digit_1, digit_2=digit_2,
+      output_mappings=[(i,) for i in range(19)])
+    self.assertEqual(result_tensor.shape, (16, 19))
+
+  def test_forward_with_output_mapping_2(self):
+    forward = self.ctx.forward_function()
+    digit_1 = torch.softmax(torch.randn((16, 10)), dim=1)
+    digit_2 = torch.softmax(torch.randn((16, 10)), dim=1)
+    results = forward(
+      digit_1=digit_1,
+      digit_2=digit_2,
+      output_relations=[["sum_2", "mult_2"]] * 16,
+      output_mappings={
+        "sum_2": [(i,) for i in range(19)],
+        "mult_2": [(i,) for i in range(82)],
+      })
+    self.assertEqual(results["sum_2"].shape, (16, 19))
+    self.assertEqual(results["mult_2"].shape, (16, 82))
+
   def test_multi_result_single_dispatch(self):
     forward = self.ctx.forward_function(output_mappings={"sum_2": list(range(19)), "mult_2": list(range(100))}, dispatch="single")
     digit_1 = torch.softmax(torch.randn((16, 10)), dim=1)
diff --git a/etc/scallopy/tests/sparse_forward.py b/etc/scallopy/tests/sparse_forward.py
new file mode 100644
index 0000000..f1776be
--- /dev/null
+++ b/etc/scallopy/tests/sparse_forward.py
@@ -0,0 +1,40 @@
+import torch
+import scallopy
+import unittest
+
+class TestSparseDigitForward(unittest.TestCase):
+  def setUp(self):
+    self.ctx = scallopy.ScallopContext(provenance="difftopkproofs")
+    self.ctx.add_relation("digit_1", int, range(10))
+    self.ctx.add_relation("digit_2", int, range(10))
+    self.ctx.add_rule("sum_2(a + b) = digit_1(a) and digit_2(b)")
+    self.ctx.add_rule("mult_2(a * b) = digit_1(a) and digit_2(b)")
+
+  def test_backward_with_sparse(self):
+    loss_fn = torch.nn.BCELoss()
+    forward = self.ctx.forward_function(
+      "sum_2",
+      list(range(19)),
+      sparse_jacobian=True)
+
+    # Construct the digit distributions
+    digit_1_base = torch.randn((16, 10), requires_grad=True)
+    digit_1 = torch.softmax(digit_1_base, dim=1)
+    digit_2_base = torch.randn((16, 10), requires_grad=True)
+    digit_2 = torch.softmax(digit_2_base, dim=1)
+
+    # Call Scallop and obtain the loss
+    sum_2 = forward(digit_1=digit_1, digit_2=digit_2)
+    gt = torch.tensor([[1.0] + [0.0] * 18] * 16)
+    l = loss_fn(sum_2, gt)
+
+    # Ensure that there is no gradient yet
+    assert digit_1_base.grad is None
+    assert digit_2_base.grad is None
+
+    # Perform backward
+    l.backward()
+
+    # Ensure that there is some gradient
+    assert any(p != 0.0 for distr in digit_1_base.grad for p in distr)
+    assert any(p != 0.0 for distr in digit_2_base.grad for p in distr)
diff --git a/etc/scallopy/tests/test.py b/etc/scallopy/tests/test.py
index 1f05baf..f7ec023 100644
--- a/etc/scallopy/tests/test.py
+++ b/etc/scallopy/tests/test.py
@@ -8,6 +8,7 @@
 from entity import *
 from failure import *
 from forward import *
+from sparse_forward import *
 from foreign_attribute import *
 from foreign_function import *
 from foreign_predicate import *
diff --git a/etc/sclc/Cargo.toml b/etc/sclc/Cargo.toml
index ed1af25..350df36 100644
--- a/etc/sclc/Cargo.toml
+++ b/etc/sclc/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "sclc-core"
-version = "0.2.2"
+version = "0.2.4"
 authors = ["Ziyang Li "]
 edition = "2018"
diff --git a/etc/scli/Cargo.toml b/etc/scli/Cargo.toml
index 2862a33..1e83a1a 100644
--- a/etc/scli/Cargo.toml
+++ b/etc/scli/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "scli"
-version = "0.2.2"
+version = "0.2.4"
 authors = ["Ziyang Li "]
 edition = "2018"
diff --git a/etc/sclrepl/Cargo.toml b/etc/sclrepl/Cargo.toml
index 72f2def..baa2c9d 100644
--- a/etc/sclrepl/Cargo.toml
+++ b/etc/sclrepl/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "sclrepl"
-version = "0.2.2"
+version = "0.2.4"
 authors = ["Ziyang Li "]
 edition = "2018"
diff --git a/experiments/pacman_maze/run.py b/experiments/pacman_maze/run.py
index e2f0e83..1f4b1fa 100644
--- a/experiments/pacman_maze/run.py
+++ b/experiments/pacman_maze/run.py
@@ -343,7 +343,7 @@ def run(self):
   parser.add_argument("--show-run", action="store_true")
   parser.add_argument("--show-train-run", action="store_true")
   parser.add_argument("--show-test-run", action="store_true")
-  parser.add_argument("--show-run-interval", type=int, default=0.001)
+  parser.add_argument("--show-run-interval", type=float, default=0.001)
   parser.add_argument("--window-size", type=int, default=200)
   parser.add_argument("--overlay-prediction", action="store_true")
   parser.add_argument("--easy", action="store_true")
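The --show-run-interval change fixes a latent argparse issue: the type callable is only applied to strings coming from the command line, so the float default happened to work, but any user-supplied fractional value was rejected by type=int. A quick standalone illustration (not from the repository):

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--show-run-interval", type=int, default=0.001)

# The non-string default bypasses `type`, so this silently yields the float 0.001
print(parser.parse_args([]).show_run_interval)  # 0.001

# But an explicit fractional value fails with "invalid int value: '0.05'"
# parser.parse_args(["--show-run-interval", "0.05"])  # raises SystemExit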
diff --git a/experiments/pacman_maze/run_random.py b/experiments/pacman_maze/run_random.py
index 45d75b7..9923bbe 100644
--- a/experiments/pacman_maze/run_random.py
+++ b/experiments/pacman_maze/run_random.py
@@ -1,5 +1,6 @@
 from argparse import ArgumentParser
 from tqdm import tqdm
+import cv2
 import random
 
 from arena import AvoidingArena
@@ -19,6 +20,8 @@ def test_random_model():
   for episode_i in iterator:
     state = arena.reset()
     for _ in range(args.num_steps):
+      if args.show_run:
+        show_image(arena.render())
       action = model(state)
       _, done, reward, _ = arena.step(action)
       if done:
@@ -30,6 +33,11 @@ def test_random_model():
       success_rate = (success / (episode_i + 1)) * 100.0
       iterator.set_description(f"[Test] {success}/{episode_i + 1} ({success_rate:.2f}%)")
 
+def show_image(raw_image):
+  cv2.namedWindow("Current Arena", cv2.WINDOW_NORMAL)
+  cv2.resizeWindow("Current Arena", args.window_size, args.window_size)
+  cv2.imshow("Current Arena", raw_image)
+  cv2.waitKey(int(args.show_run_interval * 1000))
 
 if __name__ == "__main__":
   parser = ArgumentParser()
@@ -40,6 +48,11 @@
   parser.add_argument("--num-enemies", type=int, default=5)
   parser.add_argument("--num-episodes", type=int, default=1000)
   parser.add_argument("--num-steps", type=int, default=30)
+  parser.add_argument("--show-run", action="store_true")
+  parser.add_argument("--show-run-interval", type=float, default=0.001)
+  parser.add_argument("--window-size", type=int, default=200)
   args = parser.parse_args()
+  args.show_run_interval = max(0.001, args.show_run_interval) # Minimum of 1ms
+
   test_random_model()
diff --git a/experiments/pacman_maze/scl/arena_answer.scl b/experiments/pacman_maze/scl/arena_answer.scl
new file mode 100644
index 0000000..3c0268b
--- /dev/null
+++ b/experiments/pacman_maze/scl/arena_answer.scl
@@ -0,0 +1,54 @@
+// Static input facts
+type grid_node(x: usize, y: usize)
+
+// Input from neural networks
+type actor(x: usize, y: usize)
+type goal(x: usize, y: usize)
+type enemy(x: usize, y: usize)
+
+// Possible actions to take
+type Action = UP | RIGHT | DOWN | LEFT
+
+// =========== YOUR CODE START HERE ===========
+
+// ** Problem 1: safe_node **
+// (x, y) is a safe node if it is a grid node and does not contain an enemy
+type safe_node(x: usize, y: usize)
+rel safe_node(x, y) = grid_node(x, y) and not enemy(x, y)
+
+// ** Problem 2: edge **
+// There is a (safe) edge between safe nodes (x1, y1) and (x2, y2) if
+// taking the action `a` can move the actor from (x1, y1) to (x2, y2)
+type edge(x1: usize, y1: usize, x2: usize, y2: usize, a: Action)
+rel edge(x, y, x, y + 1, UP) = safe_node(x, y) and safe_node(x, y + 1)
+rel edge(x, y, x + 1, y, RIGHT) = safe_node(x, y) and safe_node(x + 1, y)
+rel edge(x, y, x, y - 1, DOWN) = safe_node(x, y) and safe_node(x, y - 1)
+rel edge(x, y, x - 1, y, LEFT) = safe_node(x, y) and safe_node(x - 1, y)
+
+// ** Problem 3: path **
+// There is a (safe) path between safe nodes (x1, y1) and (x2, y2) if
+// there is a series of safe edges connecting the two nodes.
+// Note that self-path is also a safe path.
+type path(x1: usize, y1: usize, x2: usize, y2: usize)
+rel path(x, y, x, y) = safe_node(x, y)
+rel path(x1, y1, x2, y2) = edge(x1, y1, x2, y2, _)
+rel path(x1, y1, x3, y3) = path(x1, y1, x2, y2) and edge(x2, y2, x3, y3, _)
+
+// ** Problem 4: next_position **
+// Given the current actor position, taking the action `a` would move the
+// actor to the position (x, y)
+type next_position(a: Action, x: usize, y: usize)
+rel next_position(a, xp, yp) = actor(x, y) and edge(x, y, xp, yp, a)
+
+// ** Problem 5: next_action **
+// We pick the action `a` as the next action if, after moving to the next
+// position with `a`, we have a safe path from the next position to the goal
+type next_action(a: Action)
+rel next_action(a) = next_position(a, xn, yn) and goal(xg, yg) and path(xn, yn, xg, yg)
+
+// =========== YOUR CODE END HERE ===========
+
+// Constraint violation; please keep these as is
+rel too_many_goal() = n := count(x, y: goal(x, y)), n > 1
+rel too_many_enemy() = n := count(x, y: enemy(x, y)), n > 5
+rel violation() = too_many_goal() or too_many_enemy()
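The rules above encode reachability over safe edges. For intuition only, here is an ordinary Python analogue of the path and next_action logic on a tiny hypothetical 3x3 arena; it is not part of the repository:

# Hypothetical 3x3 arena with one enemy at (1, 1)
grid = {(x, y) for x in range(3) for y in range(3)}
enemies = {(1, 1)}
safe = grid - enemies  # the safe_node relation

ACTIONS = {"UP": (0, 1), "RIGHT": (1, 0), "DOWN": (0, -1), "LEFT": (-1, 0)}

def reachable(src):
  # Transitive closure of the edge relation, i.e. the `path` rules
  seen, frontier = {src}, [src]
  while frontier:
    x, y = frontier.pop()
    for dx, dy in ACTIONS.values():
      nxt = (x + dx, y + dy)
      if nxt in safe and nxt not in seen:
        seen.add(nxt)
        frontier.append(nxt)
  return seen

actor, goal = (0, 0), (2, 2)
next_action = [
  a for a, (dx, dy) in ACTIONS.items()
  if (actor[0] + dx, actor[1] + dy) in safe               # next_position is safe
  and goal in reachable((actor[0] + dx, actor[1] + dy))   # safe path to the goal
]
print(next_action)  # ['UP', 'RIGHT']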
diff --git a/experiments/pacman_maze/scl/arena_todo.scl b/experiments/pacman_maze/scl/arena_todo.scl
new file mode 100644
index 0000000..e75f550
--- /dev/null
+++ b/experiments/pacman_maze/scl/arena_todo.scl
@@ -0,0 +1,49 @@
+// Static input facts
+type grid_node(x: usize, y: usize)
+
+// Input from neural networks
+type actor(x: usize, y: usize)
+type goal(x: usize, y: usize)
+type enemy(x: usize, y: usize)
+
+// Possible actions to take
+type Action = UP | RIGHT | DOWN | LEFT
+
+// =========== YOUR CODE START HERE ===========
+
+// ** Problem 1: safe_node **
+// (x, y) is a safe node if it is a grid node and does not contain an enemy
+type safe_node(x: usize, y: usize)
+/* YOUR CODE HERE */
+
+// ** Problem 2: edge **
+// There is a (safe) edge between safe nodes (x1, y1) and (x2, y2) if
+// taking the action `a` can move the actor from (x1, y1) to (x2, y2)
+type edge(x1: usize, y1: usize, x2: usize, y2: usize, a: Action)
+/* YOUR CODE HERE */
+
+// ** Problem 3: path **
+// There is a (safe) path between safe nodes (x1, y1) and (x2, y2) if
+// there is a series of safe edges connecting the two nodes.
+// Note that self-path is also a safe path.
+type path(x1: usize, y1: usize, x2: usize, y2: usize)
+/* YOUR CODE HERE */
+
+// ** Problem 4: next_position **
+// Given the current actor position, taking the action `a` would move the
+// actor to the position (x, y)
+type next_position(a: Action, x: usize, y: usize)
+/* YOUR CODE HERE */
+
+// ** Problem 5: next_action **
+// We pick the action `a` as the next action if, after moving to the next
+// position with `a`, we have a safe path from the next position to the goal
+type next_action(a: Action)
+/* YOUR CODE HERE */
+
+// =========== YOUR CODE END HERE ===========
+
+// Constraint violation; please keep these as is
+rel too_many_goal() = n := count(x, y: goal(x, y)), n > 1
+rel too_many_enemy() = n := count(x, y: enemy(x, y)), n > 5
+rel violation() = too_many_goal() or too_many_enemy()
diff --git a/experiments/pacman_maze/scl/arena_w_constraint.scl b/experiments/pacman_maze/scl/arena_w_constraint.scl
index 42e72b7..a661d10 100644
--- a/experiments/pacman_maze/scl/arena_w_constraint.scl
+++ b/experiments/pacman_maze/scl/arena_w_constraint.scl
@@ -1,3 +1,5 @@
+type Action = UP | RIGHT | DOWN | LEFT
+
 // Input from neural networks
 type grid_node(x: usize, y: usize)
 type curr_position(x: usize, y: usize)
@@ -6,10 +8,10 @@ type is_enemy(x: usize, y: usize)
 
 // Basic connectivity
 rel node(x, y) = grid_node(x, y), not is_enemy(x, y)
-rel edge(x, y, x, yp, 0) = node(x, y), node(x, yp), yp == y + 1 // Up
-rel edge(x, y, xp, y, 1) = node(x, y), node(xp, y), xp == x + 1 // Right
-rel edge(x, y, x, yp, 2) = node(x, y), node(x, yp), yp == y - 1 // Down
-rel edge(x, y, xp, y, 3) = node(x, y), node(xp, y), xp == x - 1 // Left
+rel edge(x, y, x, yp, UP) = node(x, y), node(x, yp), yp == y + 1 // Up
+rel edge(x, y, xp, y, RIGHT) = node(x, y), node(xp, y), xp == x + 1 // Right
+rel edge(x, y, x, yp, DOWN) = node(x, y), node(x, yp), yp == y - 1 // Down
+rel edge(x, y, xp, y, LEFT) = node(x, y), node(xp, y), xp == x - 1 // Left
 
 // Path for connectivity; will condition on no enemy on the path
 rel path(x, y, x, y) = node(x, y)