diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index b92507e..99afb98 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -1,30 +1,24 @@ -name: Rust +name: Cargo Build & Test on: push: - branches: [ "master" ] + branches: [ master ] pull_request: - branches: [ "master" ] + branches: [ master ] -env: + +env: CARGO_TERM_COLOR: always jobs: - build: - + build_and_test: + name: Rust project - latest runs-on: ubuntu-latest - + strategy: + matrix: + toolchain: + - nightly steps: - - uses: actions/checkout@v3 - - uses: actions-rs/toolchain@v1 - with: - toolchain: ${{ steps.component.outputs.toolchain }} - override: true - - id: component - uses: actions-rs/components-nightly@v1 - with: - component: clippy - - name: Build - run: cargo build --verbose - - name: Run tests - run: cargo test --verbose + - uses: actions/checkout@v3 + - run: rustup update ${{ matrix.toolchain }} && rustup default ${{ matrix.toolchain }} + - run: cargo test --verbose --release diff --git a/core/src/common/expr.rs b/core/src/common/expr.rs index 3a4db46..7ffd270 100644 --- a/core/src/common/expr.rs +++ b/core/src/common/expr.rs @@ -1,3 +1,5 @@ +use std::iter::FromIterator; + use super::binary_op::BinaryOp; use super::tuple_access::TupleAccessor; use super::unary_op::UnaryOp; @@ -232,6 +234,12 @@ where } } +impl FromIterator for Expr { + fn from_iter>(iter: I) -> Self { + Expr::Tuple(iter.into_iter().collect()) + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct BinaryExpr { pub op: BinaryOp, diff --git a/core/src/common/foreign_predicates/mod.rs b/core/src/common/foreign_predicates/mod.rs index 616774e..4183c05 100644 --- a/core/src/common/foreign_predicates/mod.rs +++ b/core/src/common/foreign_predicates/mod.rs @@ -11,18 +11,18 @@ use super::value_type::*; mod float_eq; mod range; -mod string_chars; mod soft_cmp; mod soft_eq; mod soft_gt; mod soft_lt; mod soft_neq; +mod string_chars; pub use float_eq::*; pub use range::*; -pub use string_chars::*; pub use soft_cmp::*; pub use soft_eq::*; pub use soft_gt::*; pub use soft_lt::*; pub use soft_neq::*; +pub use string_chars::*; diff --git a/core/src/compiler/back/ast.rs b/core/src/compiler/back/ast.rs index 80ed32a..b4e707b 100644 --- a/core/src/compiler/back/ast.rs +++ b/core/src/compiler/back/ast.rs @@ -92,7 +92,7 @@ pub struct Rule { impl Rule { pub fn head_predicate(&self) -> &String { - &self.head.predicate + &self.head.predicate() } pub fn body_literals(&self) -> impl Iterator { @@ -105,21 +105,95 @@ impl Rule { } #[derive(Clone, Debug, PartialEq)] -pub struct Head { - pub predicate: String, - pub args: Vec, +pub enum Head { + /// A simple atom as the head + Atom(Atom), + + /// A disjunction of atoms as the head; all atoms should have the same predicate + Disjunction(Vec), } impl Head { - pub fn new(predicate: String, args: Vec) -> Self { - Self { predicate, args } + pub fn atom(predicate: String, args: Vec) -> Self { + Self::Atom(Atom::new(predicate, args)) } - pub fn variable_args(&self) -> impl Iterator { - self.args.iter().filter_map(|a| match a { - Term::Variable(v) => Some(v), + pub fn predicate(&self) -> &String { + match self { + Self::Atom(a) => &a.predicate, + Self::Disjunction(disj) => &disj[0].predicate, + } + } + + pub fn get_atom(&self) -> Option<&Atom> { + match self { + Self::Atom(a) => Some(a), _ => None, - }) + } + } + + pub fn variable_args(&self) -> Vec<&Variable> { + match self { + Self::Atom(a) => a.variable_args().collect(), + Self::Disjunction(disj) => disj.iter().flat_map(|a| a.variable_args()).collect(), + } + } + + /// Substitute the atom's arguments with the given term rewrite function + pub fn substitute Term + Copy>(&self, f: F) -> Self { + match self { + Self::Atom(a) => Self::Atom(a.substitute(f)), + Self::Disjunction(d) => Self::Disjunction(d.iter().map(|a| a.substitute(f)).collect()), + } + } + + /// Get the variable patterns of the head + /// + /// Atomic head has only one pattern; + /// Disjunctive head could have multiple patterns + pub fn has_multiple_patterns(&self) -> bool { + match self { + Self::Atom(_) => { + // Atomic head has only one pattern + false + }, + Self::Disjunction(disj) => { + // Extract the pattern of the first atom in the disjunction + let first_pattern = disj[0] + .args + .iter() + .map(|t| match t { + Term::Variable(v) => v.name.clone(), + Term::Constant(_) => String::new(), + }) + .collect::>(); + + // Check if the first pattern is satisfied by all other atoms + for a in disj.iter().skip(1) { + let satisfies_pattern = a.args + .iter() + .enumerate() + .all(|(i, t)| { + if let Some(p) = first_pattern.get(i) { + match t { + Term::Variable(v) => p == &v.name, + Term::Constant(_) => p.is_empty(), + } + } else { + false + } + }); + + // If not satisfied, then the head has multiple patterns + if !satisfies_pattern { + return true; + } + } + + // If all atoms satisfy the first pattern, then the head has only one pattern + false + }, + } } } @@ -130,7 +204,7 @@ pub struct Conjunction { } /// A term is the argument of a literal -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum Term { Variable(Variable), Constant(Constant), @@ -309,6 +383,14 @@ impl Atom { self.args.iter().any(|a| a.is_constant()) } + /// Get the constant arguments + pub fn constant_args(&self) -> impl Iterator { + self.args.iter().filter_map(|a| match a { + Term::Constant(c) => Some(c), + _ => None, + }) + } + /// Create a partition of the atom's arguments into constant and variable pub fn const_var_partition(&self) -> (Vec<(usize, &Constant)>, Vec<(usize, &Variable)>) { let (constants, variables): (Vec<_>, Vec<_>) = self.args.iter().enumerate().partition(|(_, t)| t.is_constant()); @@ -328,6 +410,14 @@ impl Atom { .collect(); (constants, variables) } + + /// Substitute the atom's arguments with the given term rewrite function + pub fn substitute Term>(&self, f: F) -> Self { + Self { + predicate: self.predicate.clone(), + args: self.args.iter().map(|a| f(a)).collect(), + } + } } #[derive(Clone, Debug, PartialEq)] diff --git a/core/src/compiler/back/b2r.rs b/core/src/compiler/back/b2r.rs index 07d17a7..a1d3911 100644 --- a/core/src/compiler/back/b2r.rs +++ b/core/src/compiler/back/b2r.rs @@ -34,9 +34,13 @@ struct NegativeDataflow { dataflow: ram::Dataflow, } +/// Property of a dataflow #[derive(Default, Clone)] struct DataflowProp { + /// Whether the dataflow needs to be sorted need_sorted: bool, + + /// Whether the dataflow is negative is_negative: bool, } @@ -270,7 +274,7 @@ impl Program { let output = self.outputs.get(pred).cloned().unwrap_or(OutputOption::Hidden); // Check immutability, i.e., the relation is not updated by rules - let immutable = self.rules.iter().find_position(|r| &r.head.predicate == pred).is_none(); + let immutable = self.rules.iter().find_position(|r| r.head.predicate() == pred).is_none(); // The Final Relation let ram_relation = ram::Relation { @@ -300,31 +304,85 @@ impl Program { } fn plan_to_ram_update(&self, ctx: &mut B2RContext, head: &Head, plan: &Plan) -> ram::Update { - let (head_goal, need_projection) = self.head_variable_tuple(head); - let subgoal = head_goal.dedup(); + // Check if the dataflow needs projection and update the dataflow + let dataflow = match head { + Head::Atom(head_atom) => { + let (head_goal, need_projection) = self.head_atom_variable_tuple(head_atom); + let subgoal = head_goal.dedup(); + + // Generate the dataflow + let dataflow = self.plan_to_ram_dataflow(ctx, &subgoal, plan, false.into()); + + // Project the dataflow if needed + let dataflow = if need_projection { + dataflow.project(self.projection_to_atom_head(&subgoal, head_atom)) + } else if head_goal != subgoal { + dataflow.project(subgoal.projection(&head_goal)) + } else { + dataflow + }; - // Generate the dataflow - let dataflow = self.plan_to_ram_dataflow(ctx, &subgoal, plan, false.into()); + // Check if the head predicate is a magic-set; if so we wrap an overwrite_one dataflow around + // NOTE: only head atom predicate can be magic-set predicate + let dataflow = if self.is_magic_set_predicate(&head.predicate()) == Some(true) { + dataflow.overwrite_one() + } else { + dataflow + }; - // Check if the dataflow needs projection and update the dataflow - let dataflow = if need_projection { - ram::Dataflow::project(dataflow, self.projection_to_head(&subgoal, head)) - } else if head_goal != subgoal { - ram::Dataflow::project(dataflow, subgoal.projection(&head_goal)) - } else { - dataflow - }; + dataflow + } + Head::Disjunction(head_atoms) => { + if !head.has_multiple_patterns() { + let head_var_goal = VariableTuple::from_vars(head_atoms[0].variable_args().cloned(), false); - // Check if the head predicate is a magic-set; if so we wrap an overwrite_one dataflow around - let dataflow = if self.is_magic_set_predicate(&head.predicate) == Some(true) { - ram::Dataflow::overwrite_one(dataflow) - } else { - dataflow + // Generate the sub-dataflow + let sub_dataflow = self.plan_to_ram_dataflow(ctx, &head_var_goal, plan, false.into()); + + // Get all the constants in the head atoms + let constants: Vec<_> = head_atoms + .iter() + .map(|head_atom| { + use std::iter::FromIterator; + Tuple::from_iter(head_atom.constant_args().cloned()) + }) + .collect(); + + // Disjunction dataflow + let disj_dataflow = sub_dataflow.exclusion(constants); + + // Projection + let (mut var_counter, mut const_counter) = (0, 0); + let projection: Expr = head_atoms[0] + .args + .iter() + .map(|arg| { + match arg { + Term::Variable(_) => { + let result = Expr::access((0, var_counter)); + var_counter += 1; + result + } + Term::Constant(_) => { + let result = Expr::access((1, const_counter)); + const_counter += 1; + result + } + } + }) + .collect(); + + // Wrap the disjunction dataflow with a projection + disj_dataflow.project(projection) + } else { + unimplemented!("Disjunction with more than one pattern is not supported yet.") + } + } }; // Return the update ram::Update { - target: head.predicate.clone(), + target: head.predicate().clone(), dataflow, } } @@ -908,11 +966,11 @@ impl Program { } } - fn head_variable_tuple(&self, head: &Head) -> (VariableTuple, bool) { + fn head_atom_variable_tuple(&self, head: &Atom) -> (VariableTuple, bool) { let rel = self.relation_of_predicate(&head.predicate).unwrap(); if let Some(agg_attr) = rel.attributes.aggregate_body_attr() { // For an aggregate sub-relation - let head_args = head.variable_args().cloned().collect::>(); + let head_args = head.variable_args().into_iter().cloned().collect::>(); let num_group_by = agg_attr.num_group_by_vars; let num_args = agg_attr.num_arg_vars; @@ -944,7 +1002,7 @@ impl Program { (var_tuple, false) } else if let Some(agg_group_by_attr) = rel.attributes.aggregate_group_by_attr() { let num_group_by = agg_group_by_attr.num_join_group_by_vars; - let var_args = head.variable_args().cloned().collect::>(); + let var_args = head.variable_args().into_iter().cloned().collect::>(); let joined = VariableTuple::from_vars((&var_args[..num_group_by]).into_iter().cloned(), true); let others = VariableTuple::from_vars((&var_args[num_group_by..]).into_iter().cloned(), true); @@ -956,24 +1014,23 @@ impl Program { let top = head .args .iter() - .filter_map(|arg| { - let v = match arg { - Term::Variable(v) => v.clone(), - _ => { - need_projection = true; - return None; - } - }; - Some(VariableTuple::Value(v)) + .filter_map(|arg| match arg { + Term::Variable(v) => { + Some(VariableTuple::Value(v.clone())) + }, + _ => { + need_projection = true; + None + } }) .collect(); (VariableTuple::Tuple(top), need_projection) } } - pub fn projection_to_head(&self, var_tuple: &VariableTuple, head: &Head) -> Expr { + pub fn projection_to_atom_head(&self, var_tuple: &VariableTuple, head_atom: &Atom) -> Expr { Expr::Tuple( - head + head_atom .args .iter() .map(|a| match a { diff --git a/core/src/compiler/back/optimizations/constant_propagation.rs b/core/src/compiler/back/optimizations/constant_propagation.rs index f22ec16..6877f3a 100644 --- a/core/src/compiler/back/optimizations/constant_propagation.rs +++ b/core/src/compiler/back/optimizations/constant_propagation.rs @@ -159,10 +159,7 @@ pub fn constant_prop(rule: &mut Rule) { } // Apply substitution to the head - let new_head = Head { - predicate: rule.head.predicate.clone(), - args: rule.head.args.iter().map(substitute_term).collect(), - }; + let new_head = rule.head.substitute(substitute_term); // Update the rule rule.body.args = new_literals; diff --git a/core/src/compiler/back/optimizations/demand_transform.rs b/core/src/compiler/back/optimizations/demand_transform.rs index 2c94728..98a4f65 100644 --- a/core/src/compiler/back/optimizations/demand_transform.rs +++ b/core/src/compiler/back/optimizations/demand_transform.rs @@ -85,7 +85,7 @@ fn transform_on_demand_rule(rule: &Rule, adornment: &Adornment) -> Rule { // Create demand atom let demand_atom = Atom { predicate: adornment.demand_predicate.clone(), - args: adornment.pattern.get_bounded_args(&rule.head.args), + args: adornment.pattern.get_bounded_args(&rule.head.get_atom().unwrap().args), }; // Append it to the new rule @@ -169,10 +169,7 @@ fn generate_demand_rule(base: &Vec, goal: &Atom, adm: &Adornment) -> Op } else { let rule = Rule { attributes: Attributes::new(), - head: Head { - predicate: adm.demand_predicate.clone(), - args: adm.pattern.get_bounded_args(&goal.args), - }, + head: Head::atom(adm.demand_predicate.clone(), adm.pattern.get_bounded_args(&goal.args)), body: Conjunction { args: base }, }; Some(rule) diff --git a/core/src/compiler/back/optimizations/empty_rule_to_fact.rs b/core/src/compiler/back/optimizations/empty_rule_to_fact.rs index 2260e29..8a4b010 100644 --- a/core/src/compiler/back/optimizations/empty_rule_to_fact.rs +++ b/core/src/compiler/back/optimizations/empty_rule_to_fact.rs @@ -5,25 +5,32 @@ use super::super::*; pub fn empty_rule_to_fact(rules: &mut Vec, facts: &mut Vec) { rules.retain(|rule| { if rule.body.args.is_empty() { - // Create fact - let fact = Fact { - tag: DynamicInputTag::None, - predicate: rule.head.predicate.clone(), - args: rule - .head - .args - .iter() - .map(|arg| match arg { - Term::Constant(c) => c.clone(), - Term::Variable(v) => panic!("[Internal Error] Invalid head variable `{}` in an empty rule", v.name), - }) - .collect(), - }; + match &rule.head { + Head::Atom(head_atom) => { + // Create fact + let fact = Fact { + tag: DynamicInputTag::None, + predicate: head_atom.predicate.clone(), + args: head_atom + .args + .iter() + .map(|arg| match arg { + Term::Constant(c) => c.clone(), + Term::Variable(v) => panic!("[Internal Error] Invalid head variable `{}` in an empty rule", v.name), + }) + .collect(), + }; - // Add the fact to the set of facts - facts.push(fact); + // Add the fact to the set of facts + facts.push(fact); - false + false + } + Head::Disjunction(_) => { + // TODO: Handle disjunctions + true + } + } } else { true } diff --git a/core/src/compiler/back/optimizations/equality_propagation.rs b/core/src/compiler/back/optimizations/equality_propagation.rs index 097354b..137e51c 100644 --- a/core/src/compiler/back/optimizations/equality_propagation.rs +++ b/core/src/compiler/back/optimizations/equality_propagation.rs @@ -125,10 +125,7 @@ pub fn propagate_equality(rule: &mut Rule) { } // Apply substitution to the head - let new_head = Head { - predicate: rule.head.predicate.clone(), - args: rule.head.args.iter().map(substitute_term).collect(), - }; + let new_head = rule.head.substitute(substitute_term); // Update the rule into this new rule *rule = Rule { diff --git a/core/src/compiler/back/pretty.rs b/core/src/compiler/back/pretty.rs index 7d6eccf..747f3b8 100644 --- a/core/src/compiler/back/pretty.rs +++ b/core/src/compiler/back/pretty.rs @@ -121,14 +121,19 @@ impl Display for Rule { impl Display for Head { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - f.write_fmt(format_args!("{}(", self.predicate))?; - for (i, arg) in self.args.iter().enumerate() { - f.write_fmt(format_args!("{}", arg))?; - if i < self.args.len() - 1 { - f.write_str(", ")?; + match self { + Self::Atom(a) => a.fmt(f), + Self::Disjunction(atoms) => { + f.write_str("{")?; + for (i, atom) in atoms.iter().enumerate() { + if i > 0 { + f.write_str("; ")?; + } + atom.fmt(f)?; + } + f.write_str("}") } } - f.write_str(")") } } diff --git a/core/src/compiler/back/query_plan.rs b/core/src/compiler/back/query_plan.rs index e948c6a..60042f9 100644 --- a/core/src/compiler/back/query_plan.rs +++ b/core/src/compiler/back/query_plan.rs @@ -30,7 +30,7 @@ impl<'a> QueryPlanContext<'a> { ) -> Self { // First create an empty context let mut ctx = Self { - head_vars: rule.head.variable_args().cloned().collect(), + head_vars: rule.head.variable_args().into_iter().cloned().collect(), reduces: vec![], pos_atoms: vec![], neg_atoms: vec![], diff --git a/core/src/compiler/front/analyzers/boundness/boundness_analysis.rs b/core/src/compiler/front/analyzers/boundness/boundness_analysis.rs index 70b1a70..3570815 100644 --- a/core/src/compiler/front/analyzers/boundness/boundness_analysis.rs +++ b/core/src/compiler/front/analyzers/boundness/boundness_analysis.rs @@ -38,13 +38,16 @@ impl BoundnessAnalysis { if !*inferred { // Make sure the demand attribute is affecting boundness analysis, // through some of the head expressions being bounded - let bounded_exprs = if let Some((pattern, _)) = demand_attrs.get(rule.head().predicate()) { - rule - .head() - .iter_arguments() - .zip(pattern.chars()) - .filter_map(|(a, b)| if b == 'b' { Some(a.clone()) } else { None }) - .collect() + let bounded_exprs = if let Some(head_atom) = rule.head().atom() { + if let Some((pattern, _)) = demand_attrs.get(head_atom.predicate()) { + head_atom + .iter_arguments() + .zip(pattern.chars()) + .filter_map(|(a, b)| if b == 'b' { Some(a.clone()) } else { None }) + .collect() + } else { + vec![] + } } else { vec![] }; diff --git a/core/src/compiler/front/analyzers/boundness/context.rs b/core/src/compiler/front/analyzers/boundness/context.rs index 61bcc33..9573844 100644 --- a/core/src/compiler/front/analyzers/boundness/context.rs +++ b/core/src/compiler/front/analyzers/boundness/context.rs @@ -13,7 +13,7 @@ pub struct RuleContext { impl RuleContext { pub fn from_rule(rule: &Rule) -> Self { - let head_vars = collect_vars_in_atom(rule.head()); + let head_vars = collect_vars_in_head(rule.head()); let body = DisjunctionContext::from_formula(rule.body()); Self { head_vars, body } } @@ -291,6 +291,13 @@ impl AggregationContext { } } +fn collect_vars_in_head(head: &RuleHead) -> Vec<(String, Loc)> { + match &head.node { + RuleHeadNode::Atom(atom) => collect_vars_in_atom(atom), + RuleHeadNode::Disjunction(d) => d.iter().map(collect_vars_in_atom).flatten().collect(), + } +} + fn collect_vars_in_atom(atom: &Atom) -> Vec<(String, Loc)> { atom.iter_arguments().map(collect_vars_in_expr).flatten().collect() } diff --git a/core/src/compiler/front/analyzers/demand_attr.rs b/core/src/compiler/front/analyzers/demand_attr.rs index 349c09e..ff652c5 100644 --- a/core/src/compiler/front/analyzers/demand_attr.rs +++ b/core/src/compiler/front/analyzers/demand_attr.rs @@ -6,6 +6,7 @@ use super::type_inference; #[derive(Clone, Debug)] pub struct DemandAttributeAnalysis { pub demand_attrs: HashMap, + pub disjunctive_predicates: HashSet, pub errors: Vec, } @@ -13,6 +14,7 @@ impl DemandAttributeAnalysis { pub fn new() -> Self { Self { demand_attrs: HashMap::new(), + disjunctive_predicates: HashSet::new(), errors: Vec::new(), } } @@ -35,7 +37,27 @@ impl DemandAttributeAnalysis { } } + pub fn set_disjunctive(&mut self, pred: &String, loc: &AstNodeLocation) { + if self.demand_attrs.contains_key(pred) { + self.errors.push(DemandAttributeError::DisjunctivePredicateWithDemandAttribute { + pred: pred.clone(), + loc: loc.clone(), + }); + } else { + self.disjunctive_predicates.insert(pred.clone()); + } + } + pub fn process_attribute(&mut self, pred: &str, attr: &Attribute) { + // Check if the predicate occurs in a disjunctive head + if self.disjunctive_predicates.contains(pred) { + self.errors.push(DemandAttributeError::DisjunctivePredicateWithDemandAttribute { + pred: pred.to_string(), + loc: attr.location().clone(), + }); + } + + // Check the pattern if attr.name() == "demand" { if attr.num_pos_args() == 1 { let value = attr.pos_arg(0).unwrap(); @@ -89,7 +111,17 @@ impl NodeVisitor for DemandAttributeAnalysis { } fn visit_rule_decl(&mut self, rule_decl: &ast::RuleDecl) { - self.process_attributes(rule_decl.rule().head().predicate(), rule_decl.attributes()); + if rule_decl.rule().head().is_disjunction() { + for predicate in rule_decl.rule().head().iter_predicates() { + self.set_disjunctive(predicate, rule_decl.rule().head().location()); + return; // early stopping because this is an error + } + } + + // Otherwise, we add the demand attribute + if let Some(atom) = rule_decl.rule().head().atom() { + self.process_attributes(atom.predicate(), rule_decl.attributes()); + } } } @@ -121,6 +153,10 @@ pub enum DemandAttributeError { InvalidPattern { loc: AstNodeLocation, }, + DisjunctivePredicateWithDemandAttribute { + pred: String, + loc: AstNodeLocation, + }, } impl FrontCompileErrorTrait for DemandAttributeError { @@ -173,6 +209,9 @@ impl FrontCompileErrorTrait for DemandAttributeError { Self::InvalidPattern { loc } => { format!("Invalid demand pattern\n{}", loc.report(src)) } + Self::DisjunctivePredicateWithDemandAttribute { pred, loc } => { + format!("The predicate `{}` being annotated by `demand` but occurs in a disjunctive rule head\n{}", pred, loc.report(src)) + } } } } diff --git a/core/src/compiler/front/analyzers/head_relation.rs b/core/src/compiler/front/analyzers/head_relation.rs index 520450a..de605a9 100644 --- a/core/src/compiler/front/analyzers/head_relation.rs +++ b/core/src/compiler/front/analyzers/head_relation.rs @@ -53,7 +53,9 @@ impl NodeVisitor for HeadRelationAnalysis { } fn visit_rule(&mut self, rd: &ast::Rule) { - self.declared_relations.insert(rd.head().predicate().to_string()); + for predicate in rd.head().iter_predicates() { + self.declared_relations.insert(predicate.to_string()); + } } fn visit_query(&mut self, qd: &ast::Query) { diff --git a/core/src/compiler/front/analyzers/hidden_relation.rs b/core/src/compiler/front/analyzers/hidden_relation.rs index 41a41ed..7efe16e 100644 --- a/core/src/compiler/front/analyzers/hidden_relation.rs +++ b/core/src/compiler/front/analyzers/hidden_relation.rs @@ -41,6 +41,8 @@ impl NodeVisitor for HiddenRelationAnalysis { } fn visit_rule_decl(&mut self, rule_decl: &RuleDecl) { - self.process_attributes(rule_decl.rule().head().predicate(), rule_decl.attributes()) + for predicate in rule_decl.rule().head().iter_predicates() { + self.process_attributes(predicate, rule_decl.attributes()) + } } } diff --git a/core/src/compiler/front/analyzers/type_inference/type_inference.rs b/core/src/compiler/front/analyzers/type_inference/type_inference.rs index f350fcc..b6f6c95 100644 --- a/core/src/compiler/front/analyzers/type_inference/type_inference.rs +++ b/core/src/compiler/front/analyzers/type_inference/type_inference.rs @@ -516,15 +516,15 @@ impl NodeVisitor for TypeInference { } fn visit_rule(&mut self, rule: &Rule) { - let pred = rule.head().predicate(); - - // Check if the relation is a foreign predicate - if self.foreign_predicate_type_registry.contains_predicate(pred) { - self.errors.push(TypeInferenceError::CannotRedefineForeignPredicate { - pred: pred.to_string(), - loc: rule.location().clone(), - }); - return; + for pred in rule.head().iter_predicates() { + // Check if a head predicate is a foreign predicate + if self.foreign_predicate_type_registry.contains_predicate(pred) { + self.errors.push(TypeInferenceError::CannotRedefineForeignPredicate { + pred: pred.to_string(), + loc: rule.location().clone(), + }); + return; + } } // Otherwise, create a rule inference context diff --git a/core/src/compiler/front/ast/expr.rs b/core/src/compiler/front/ast/expr.rs index 9303844..e6296ab 100644 --- a/core/src/compiler/front/ast/expr.rs +++ b/core/src/compiler/front/ast/expr.rs @@ -29,6 +29,11 @@ impl Expr { })) } + /// Create a constant boolean expression + pub fn boolean(b: bool) -> Self { + Self::Constant(ConstantNode::Boolean(b).into()) + } + /// Create an expression which is a constant of boolean true value pub fn boolean_true() -> Self { Self::Constant(ConstantNode::Boolean(true).into()) diff --git a/core/src/compiler/front/ast/formula.rs b/core/src/compiler/front/ast/formula.rs index 0bba8ad..7b829a5 100644 --- a/core/src/compiler/front/ast/formula.rs +++ b/core/src/compiler/front/ast/formula.rs @@ -337,6 +337,7 @@ impl ReduceOperator { #[derive(Clone, Debug, PartialEq)] #[doc(hidden)] pub struct ForallExistsReduceNode { + pub negate: bool, pub operator: ReduceOperator, pub bindings: Vec, pub body: Box, @@ -348,6 +349,10 @@ pub struct ForallExistsReduceNode { pub type ForallExistsReduce = AstNode; impl ForallExistsReduce { + pub fn is_negated(&self) -> bool { + self.node.negate + } + pub fn operator(&self) -> &ReduceOperator { &self.node.operator } diff --git a/core/src/compiler/front/ast/relation_decl.rs b/core/src/compiler/front/ast/relation_decl.rs index f9b44e2..d573bb7 100644 --- a/core/src/compiler/front/ast/relation_decl.rs +++ b/core/src/compiler/front/ast/relation_decl.rs @@ -173,7 +173,11 @@ impl RuleDecl { } pub fn rule_tag_predicate(&self) -> String { - format!("rt#{}#{}", self.rule().head().predicate(), self.id()) + if let Some(head_atom) = self.rule().head().atom() { + format!("rt#{}#{}", head_atom.predicate(), self.id()) + } else { + unimplemented!("Rule head is not an atom") + } } } diff --git a/core/src/compiler/front/ast/rule.rs b/core/src/compiler/front/ast/rule.rs index 425d771..fc9ea53 100644 --- a/core/src/compiler/front/ast/rule.rs +++ b/core/src/compiler/front/ast/rule.rs @@ -3,12 +3,12 @@ use super::*; #[derive(Clone, Debug, PartialEq)] #[doc(hidden)] pub struct RuleNode { - pub head: Atom, + pub head: RuleHead, pub body: Formula, } impl RuleNode { - pub fn new(head: Atom, body: Formula) -> Self { + pub fn new(head: RuleHead, body: Formula) -> Self { Self { head, body } } } @@ -16,7 +16,7 @@ impl RuleNode { pub type Rule = AstNode; impl Rule { - pub fn head(&self) -> &Atom { + pub fn head(&self) -> &RuleHead { &self.node.head } @@ -40,3 +40,59 @@ impl Into> for Rule { )] } } + +#[derive(Clone, Debug, PartialEq)] +#[doc(hidden)] +pub enum RuleHeadNode { + Atom(Atom), + Disjunction(Vec), +} + +pub type RuleHead = AstNode; + +impl RuleHead { + pub fn is_atomic(&self) -> bool { + match &self.node { + RuleHeadNode::Atom(_) => true, + RuleHeadNode::Disjunction(_) => false, + } + } + + pub fn is_disjunction(&self) -> bool { + match &self.node { + RuleHeadNode::Atom(_) => false, + RuleHeadNode::Disjunction(_) => true, + } + } + + pub fn atom(&self) -> Option<&Atom> { + match &self.node { + RuleHeadNode::Atom(atom) => Some(atom), + RuleHeadNode::Disjunction(_) => None, + } + } + + pub fn iter_predicates(&self) -> Vec<&String> { + match &self.node { + RuleHeadNode::Atom(atom) => vec![atom.predicate()], + RuleHeadNode::Disjunction(atoms) => atoms.iter().map(|atom| atom.predicate()).collect(), + } + } + + pub fn iter_arguments(&self) -> Vec<&Expr> { + match &self.node { + RuleHeadNode::Atom(atom) => atom.iter_arguments().collect(), + RuleHeadNode::Disjunction(atoms) => atoms + .iter() + .flat_map(|atom| atom.iter_arguments()) + .collect(), + } + } +} + +impl From for RuleHead { + fn from(atom: Atom) -> Self { + let loc = atom.location().clone_without_id(); + Self::new(loc, RuleHeadNode::Atom(atom)) + } +} diff --git a/core/src/compiler/front/ast/utils.rs b/core/src/compiler/front/ast/utils.rs index 0a1d291..4d4b54c 100644 --- a/core/src/compiler/front/ast/utils.rs +++ b/core/src/compiler/front/ast/utils.rs @@ -47,6 +47,7 @@ impl std::hash::Hash for AstNodeLocation { } impl AstNodeLocation { + /// When cloning a location, we want to keep everything but not the id. pub fn clone_without_id(&self) -> Self { Self { offset_span: self.offset_span.clone(), @@ -56,6 +57,7 @@ impl AstNodeLocation { } } + /// Create a location from a single offset span. pub fn from_offset_span(start: usize, end: usize) -> Self { Self { offset_span: Span::new(start, end), diff --git a/core/src/compiler/front/f2b/f2b.rs b/core/src/compiler/front/f2b/f2b.rs index f3e0e97..cd3b05a 100644 --- a/core/src/compiler/front/f2b/f2b.rs +++ b/core/src/compiler/front/f2b/f2b.rs @@ -189,50 +189,109 @@ impl FrontContext { } fn rule_decl_to_back_rules(&self, rd: &front::RuleDecl, temp_relations: &mut Vec) -> Vec { + let rule_loc = rd.rule().location(); + match &rd.rule().head().node { + front::RuleHeadNode::Atom(head) => { + self.atomic_rule_decl_to_back_rules(rule_loc, head, temp_relations) + } + front::RuleHeadNode::Disjunction(head_atoms) => { + self.disjunctive_rule_decl_to_back_rules(rule_loc, head_atoms, temp_relations) + } + } + } + + fn atomic_rule_decl_to_back_rules( + &self, + rule_loc: &AstNodeLocation, + head: &front::Atom, + temp_relations: &mut Vec, + ) -> Vec { let analysis = self.analysis.borrow(); // Basic information - let src_rule = rd.rule().clone(); - let pred = rd.rule().head().predicate(); + let pred = head.predicate(); let attributes = back::Attributes::new(); // Collect information for flattening let mut flatten_expr = FlattenExprContext::new(&analysis.type_inference, &self.foreign_predicate_registry); - flatten_expr.walk_atom(src_rule.head()); + flatten_expr.walk_atom(head); // Create the flattened expression that the head needs - let head_exprs = rd - .rule() - .head() + let head_exprs = head .iter_arguments() .map(|a| flatten_expr.collect_flattened_literals(a.location())) .flatten() .collect::>(); // Create the head that will be shared across all back rules - let args = rd - .rule() - .head() + let args = head .iter_arguments() .map(|a| flatten_expr.get_expr_term(a)) .collect(); - let head = back::Head { - predicate: pred.clone(), - args, - }; + let head = back::Head::atom(pred.clone(), args); + + // Get the back rules + let boundness_analysis = &self.analysis.borrow().boundness_analysis; + let rule_ctx = boundness_analysis.get_rule_context(rule_loc).unwrap(); + self.formula_to_back_rules( + &mut flatten_expr, + rule_loc, + attributes, + pred.clone(), + rule_ctx, + head, + head_exprs, + temp_relations, + ) + } + + fn disjunctive_rule_decl_to_back_rules( + &self, + rule_loc: &AstNodeLocation, + head_atoms: &[front::Atom], + temp_relations: &mut Vec, + ) -> Vec { + let analysis = self.analysis.borrow(); + + // Basic information + let pred = head_atoms[0].predicate(); + let attributes = back::Attributes::new(); + + // Collect information for flattening + let mut flatten_expr = FlattenExprContext::new(&analysis.type_inference, &self.foreign_predicate_registry); + for head in head_atoms { + flatten_expr.walk_atom(head); + } + + // Create the flattened expression that the head needs + let head_exprs = head_atoms + .iter() + .flat_map(|a| a.iter_arguments()) + .flat_map(|a| flatten_expr.collect_flattened_literals(a.location())) + .collect::>(); + + // Create the head that will be shared across all back rules + let back_head_atoms = head_atoms + .iter() + .map(|a| { + let args = a + .iter_arguments() + .map(|a| flatten_expr.get_expr_term(a)) + .collect(); + back::Atom::new(a.predicate().clone(), args) + }) + .collect(); + let head = back::Head::Disjunction(back_head_atoms); // Get the back rules + let boundness_analysis = &self.analysis.borrow().boundness_analysis; + let rule_ctx = boundness_analysis.get_rule_context(rule_loc).unwrap(); self.formula_to_back_rules( &mut flatten_expr, - src_rule.location(), + rule_loc, attributes, pred.clone(), - self - .analysis - .borrow() - .boundness_analysis - .get_rule_context(src_rule.location()) - .unwrap(), + rule_ctx, head, head_exprs, temp_relations, @@ -362,7 +421,7 @@ impl FrontContext { temp_relations.push(group_by_relation); // Create temporary rule(s) for group_by - let group_by_rule_head = back::Head::new(group_by_predicate.clone(), group_by_terms.clone()); + let group_by_rule_head = back::Head::atom(group_by_predicate.clone(), group_by_terms.clone()); let group_by_rules = self.formula_to_back_rules( flatten_expr, src_rule_loc, @@ -420,7 +479,7 @@ impl FrontContext { temp_relations.push(body_relation); // Get the rules for body - let body_head = back::Head::new(body_predicate.clone(), body_terms.clone()); + let body_head = back::Head::atom(body_predicate.clone(), body_terms.clone()); let body_rules = self.formula_to_back_rules( flatten_expr, src_rule_loc, diff --git a/core/src/compiler/front/grammar.lalrpop b/core/src/compiler/front/grammar.lalrpop index 88dbaa6..69466a5 100644 --- a/core/src/compiler/front/grammar.lalrpop +++ b/core/src/compiler/front/grammar.lalrpop @@ -615,22 +615,24 @@ ForallExistsReduceOpNode: ReduceOperatorNode = { ForallExistsReduceOp = Spanned; ForallExistsReduceNode: ForallExistsReduceNode = { - "(" ")" => { + "(" ")" => { ForallExistsReduceNode { + negate: negate.is_some(), operator: op, bindings: vec![], body: Box::new(f), group_by: g, } }, - "(" > ":" ")" => { + "(" > ":" ")" => { ForallExistsReduceNode { + negate: negate.is_some(), operator: op, bindings: bs, body: Box::new(f), group_by: g, } - } + }, } ForallExistsReduce = Spanned; @@ -838,8 +840,19 @@ UnitExpr: Expr = { => Expr::Call(c), } +RuleHeadNode: RuleHeadNode = { + => { + RuleHeadNode::Atom(a) + }, + "{" > "}" => { + RuleHeadNode::Disjunction(atoms) + }, +} + +RuleHead: RuleHead = Spanned; + RuleNode: RuleNode = { - DefineSymbol => { + DefineSymbol => { RuleNode { head, body } } } diff --git a/core/src/compiler/front/pretty.rs b/core/src/compiler/front/pretty.rs index 1fb930f..9429671 100644 --- a/core/src/compiler/front/pretty.rs +++ b/core/src/compiler/front/pretty.rs @@ -320,6 +320,24 @@ impl Display for Rule { } } +impl Display for RuleHead { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + match &self.node { + RuleHeadNode::Atom(a) => a.fmt(f), + RuleHeadNode::Disjunction(d) => { + f.write_str("{")?; + for (i, a) in d.iter().enumerate() { + if i > 0 { + f.write_str(", ")?; + } + a.fmt(f)?; + } + f.write_str("}") + }, + } + } +} + impl Display for Formula { fn fmt(&self, f: &mut Formatter<'_>) -> Result { match self { @@ -414,6 +432,9 @@ impl Display for Reduce { impl Display for ForallExistsReduce { fn fmt(&self, f: &mut Formatter<'_>) -> Result { + if self.is_negated() { + f.write_str("not ")?; + } self.operator().fmt(f)?; f.write_fmt(format_args!( "({}: {})", diff --git a/core/src/compiler/front/transformations/atomic_query.rs b/core/src/compiler/front/transformations/atomic_query.rs index 33933e8..b8feee2 100644 --- a/core/src/compiler/front/transformations/atomic_query.rs +++ b/core/src/compiler/front/transformations/atomic_query.rs @@ -55,7 +55,7 @@ impl NodeVisitorMut for TransformAtomicQuery { args: vec![vec![Formula::Atom(body_atom.into())], eq_constraints].concat(), }; let rule = RuleNode { - head: head_atom.into(), + head: Atom::default(head_atom).into(), body: Formula::Conjunction(conj.into()), }; self.to_add_rules.push(rule.into()); diff --git a/core/src/compiler/front/transformations/desugar_forall_exists.rs b/core/src/compiler/front/transformations/desugar_forall_exists.rs index d10d4bc..3b9752c 100644 --- a/core/src/compiler/front/transformations/desugar_forall_exists.rs +++ b/core/src/compiler/front/transformations/desugar_forall_exists.rs @@ -26,6 +26,9 @@ impl NodeVisitorMut for DesugarForallExists { fn visit_formula(&mut self, formula: &mut Formula) { match formula { Formula::ForallExistsReduce(r) => { + // Get the goal + let goal = !r.is_negated(); + // Generate a boolean variable let boolean_var_name = format!("r#desugar#{}", r.loc.id.unwrap()); let boolean_var_identifier: Identifier = IdentifierNode::new(boolean_var_name).into(); @@ -49,7 +52,7 @@ impl NodeVisitorMut for DesugarForallExists { let constraint = Constraint::default_with_expr(Expr::binary( BinaryOp::default_eq(), Expr::Variable(boolean_var.clone()), - Expr::boolean_true(), + Expr::boolean(goal), )); let constraint_formula = Formula::Constraint(constraint); diff --git a/core/src/compiler/front/transformations/non_constant_fact_to_rule.rs b/core/src/compiler/front/transformations/non_constant_fact_to_rule.rs index c63ad11..ffc71ea 100644 --- a/core/src/compiler/front/transformations/non_constant_fact_to_rule.rs +++ b/core/src/compiler/front/transformations/non_constant_fact_to_rule.rs @@ -28,7 +28,7 @@ impl NodeVisitorMut for TransformNonConstantFactToRule { // Transform this into a rule. First generate the head atom: // all the non-constant arguments will be replaced by a variable - let head: Atom = AtomNode { + let head_atom: Atom = AtomNode { predicate: head.node.predicate.clone(), args: head .iter_arguments() @@ -45,6 +45,7 @@ impl NodeVisitorMut for TransformNonConstantFactToRule { .collect(), } .into(); + let head: RuleHead = head_atom.into(); // For each non-constant variable, we create a equality constraint let eq_consts = non_const_var_expr_pairs diff --git a/core/src/compiler/front/visitor.rs b/core/src/compiler/front/visitor.rs index 667b245..dceeec3 100644 --- a/core/src/compiler/front/visitor.rs +++ b/core/src/compiler/front/visitor.rs @@ -35,6 +35,7 @@ pub trait NodeVisitor { node_visitor_func_def!(visit_query, Query); node_visitor_func_def!(visit_tag, Tag); node_visitor_func_def!(visit_rule, Rule); + node_visitor_func_def!(visit_rule_head, RuleHead); node_visitor_func_def!(visit_atom, Atom); node_visitor_func_def!(visit_neg_atom, NegAtom); node_visitor_func_def!(visit_attribute, Attribute); @@ -231,7 +232,7 @@ pub trait NodeVisitor { fn walk_fact_decl(&mut self, fact_decl: &FactDecl) { self.visit_fact_decl(fact_decl); - self.visit_location(&fact_decl.loc); + self.visit_location(fact_decl.location()); self.walk_tag(&fact_decl.node.tag); self.walk_atom(&fact_decl.node.atom); } @@ -263,10 +264,23 @@ pub trait NodeVisitor { fn walk_rule(&mut self, rule: &Rule) { self.visit_rule(rule); self.visit_location(&rule.loc); - self.walk_atom(&rule.node.head); + self.walk_rule_head(&rule.node.head); self.walk_formula(&rule.node.body); } + fn walk_rule_head(&mut self, rule_head: &RuleHead) { + self.visit_rule_head(rule_head); + self.visit_location(rule_head.location()); + match &rule_head.node { + RuleHeadNode::Atom(a) => self.walk_atom(a), + RuleHeadNode::Disjunction(d) => { + for atom in d { + self.walk_atom(atom); + } + }, + } + } + fn walk_formula(&mut self, formula: &Formula) { self.visit_formula(formula); match formula { @@ -528,6 +542,7 @@ macro_rules! impl_node_visitor_tuple { node_visitor_visit_node!(visit_query, Query, ($($id),*)); node_visitor_visit_node!(visit_tag, Tag, ($($id),*)); node_visitor_visit_node!(visit_rule, Rule, ($($id),*)); + node_visitor_visit_node!(visit_rule_head, RuleHead, ($($id),*)); node_visitor_visit_node!(visit_atom, Atom, ($($id),*)); node_visitor_visit_node!(visit_neg_atom, NegAtom, ($($id),*)); node_visitor_visit_node!(visit_attribute, Attribute, ($($id),*)); diff --git a/core/src/compiler/front/visitor_mut.rs b/core/src/compiler/front/visitor_mut.rs index 62b6f29..b3498fd 100644 --- a/core/src/compiler/front/visitor_mut.rs +++ b/core/src/compiler/front/visitor_mut.rs @@ -35,6 +35,7 @@ pub trait NodeVisitorMut { node_visitor_mut_func_def!(visit_query, Query); node_visitor_mut_func_def!(visit_tag, Tag); node_visitor_mut_func_def!(visit_rule, Rule); + node_visitor_mut_func_def!(visit_rule_head, RuleHead); node_visitor_mut_func_def!(visit_atom, Atom); node_visitor_mut_func_def!(visit_neg_atom, NegAtom); node_visitor_mut_func_def!(visit_attribute, Attribute); @@ -263,10 +264,23 @@ pub trait NodeVisitorMut { fn walk_rule(&mut self, rule: &mut Rule) { self.visit_rule(rule); self.visit_location(&mut rule.loc); - self.walk_atom(&mut rule.node.head); + self.walk_rule_head(&mut rule.node.head); self.walk_formula(&mut rule.node.body); } + fn walk_rule_head(&mut self, rule_head: &mut RuleHead) { + self.visit_rule_head(rule_head); + self.visit_location(rule_head.location_mut()); + match &mut rule_head.node { + RuleHeadNode::Atom(a) => self.walk_atom(a), + RuleHeadNode::Disjunction(d) => { + for atom in d { + self.walk_atom(atom); + } + }, + } + } + fn walk_formula(&mut self, formula: &mut Formula) { self.visit_formula(formula); match formula { @@ -526,6 +540,7 @@ macro_rules! impl_node_visitor_mut_tuple { node_visitor_mut_visit_node!(visit_query, Query, ($($id),*)); node_visitor_mut_visit_node!(visit_tag, Tag, ($($id),*)); node_visitor_mut_visit_node!(visit_rule, Rule, ($($id),*)); + node_visitor_mut_visit_node!(visit_rule_head, RuleHead, ($($id),*)); node_visitor_mut_visit_node!(visit_atom, Atom, ($($id),*)); node_visitor_mut_visit_node!(visit_neg_atom, NegAtom, ($($id),*)); node_visitor_mut_visit_node!(visit_attribute, Attribute, ($($id),*)); diff --git a/core/src/compiler/ram/ast.rs b/core/src/compiler/ram/ast.rs index 2665098..3eec9b9 100644 --- a/core/src/compiler/ram/ast.rs +++ b/core/src/compiler/ram/ast.rs @@ -179,22 +179,35 @@ pub struct Update { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub enum Dataflow { + // Base relation Unit(TupleType), + UntaggedVec(Vec), + Relation(String), + + // Unary operations + Project(Box, Expr), + Filter(Box, Expr), + Find(Box, Tuple), + + // Binary operations Union(Box, Box), Join(Box, Box), Intersect(Box, Box), Product(Box, Box), Antijoin(Box, Box), Difference(Box, Box), - Project(Box, Expr), - Filter(Box, Expr), - Find(Box, Tuple), + + // Aggregation + Reduce(Reduce), + + // Tag operations OverwriteOne(Box), + Exclusion(Box, Box), + + // Foreign predicates ForeignPredicateGround(String, Vec), ForeignPredicateConstraint(Box, String, Vec), ForeignPredicateJoin(Box, String, Vec), - Reduce(Reduce), - Relation(String), } impl Dataflow { @@ -246,6 +259,10 @@ impl Dataflow { Self::OverwriteOne(Box::new(self)) } + pub fn exclusion(self, right: Vec) -> Self { + Self::Exclusion(Box::new(self), Box::new(Self::UntaggedVec(right))) + } + pub fn foreign_predicate_constraint(self, predicate: String, args: Vec) -> Self { Self::ForeignPredicateConstraint(Box::new(self), predicate, args) } @@ -280,10 +297,12 @@ impl Dataflow { | Self::Find(d, _) | Self::OverwriteOne(d) | Self::ForeignPredicateConstraint(d, _, _) - | Self::ForeignPredicateJoin(d, _, _) => d.source_relations(), + | Self::ForeignPredicateJoin(d, _, _) + | Self::Exclusion(d, _) => d.source_relations(), Self::Reduce(r) => std::iter::once(r.source_relation()).collect(), Self::Relation(r) => std::iter::once(r).collect(), - Self::ForeignPredicateGround(_, _) => HashSet::new(), + Self::ForeignPredicateGround(_, _) + | Self::UntaggedVec(_) => HashSet::new(), } } } diff --git a/core/src/compiler/ram/dependency.rs b/core/src/compiler/ram/dependency.rs index 2fe8601..ccb4f07 100644 --- a/core/src/compiler/ram/dependency.rs +++ b/core/src/compiler/ram/dependency.rs @@ -68,7 +68,8 @@ impl Update { impl Dataflow { fn collect_dependency(&self, preds: &mut HashSet) { match self { - Self::Unit(_) => {} + Self::Unit(_) + | Self::UntaggedVec(_) => {} Self::Relation(r) => { preds.insert(r.clone()); } @@ -81,6 +82,9 @@ impl Dataflow { Self::OverwriteOne(d) => { d.collect_dependency(preds); } + Self::Exclusion(d, _) => { + d.collect_dependency(preds); + } Self::Find(d, _) => { d.collect_dependency(preds); } diff --git a/core/src/compiler/ram/optimizations/project_cascade.rs b/core/src/compiler/ram/optimizations/project_cascade.rs index a057a41..667dd88 100644 --- a/core/src/compiler/ram/optimizations/project_cascade.rs +++ b/core/src/compiler/ram/optimizations/project_cascade.rs @@ -57,11 +57,13 @@ fn project_cascade_on_dataflow(d0: &mut Dataflow) -> bool { Dataflow::Filter(d, _) => project_cascade_on_dataflow(&mut **d), Dataflow::Find(d, _) => project_cascade_on_dataflow(&mut **d), Dataflow::OverwriteOne(d) => project_cascade_on_dataflow(&mut **d), + Dataflow::Exclusion(d, _) => project_cascade_on_dataflow(&mut **d), Dataflow::ForeignPredicateConstraint(d, _, _) => project_cascade_on_dataflow(&mut **d), Dataflow::ForeignPredicateJoin(d, _, _) => project_cascade_on_dataflow(&mut **d), Dataflow::ForeignPredicateGround(_, _) | Dataflow::Unit(_) | Dataflow::Relation(_) - | Dataflow::Reduce(_) => false, + | Dataflow::Reduce(_) + | Dataflow::UntaggedVec(_) => false, } } diff --git a/core/src/compiler/ram/pretty.rs b/core/src/compiler/ram/pretty.rs index 9253076..142ece8 100644 --- a/core/src/compiler/ram/pretty.rs +++ b/core/src/compiler/ram/pretty.rs @@ -71,57 +71,40 @@ impl Dataflow { let next_indent = base_indent + indent_size; let padding = vec![' '; next_indent].into_iter().collect::(); match self { + // Base relations Self::Unit(t) => f.write_fmt(format_args!("Unit({})", t)), + Self::UntaggedVec(v) => f.write_fmt(format_args!("Vec([{}])", v.iter().map(|t| format!("{}", t)).collect::>().join(", "))), Self::Relation(r) => f.write_fmt(format_args!("Relation {}", r)), - Self::Reduce(r) => { - let group_by_predicate = match &r.group_by { - ReduceGroupByType::Join(group_by_predicate) => format!(" where {}", group_by_predicate), - ReduceGroupByType::Implicit => format!(" implicit group"), - _ => format!(""), - }; - f.write_fmt(format_args!( - "Aggregation {}({}{})", - r.op, r.predicate, group_by_predicate - )) - } - Self::ForeignPredicateGround(pred, args) => { - let args = args.iter().map(|a| format!("{:?}", a)).collect::>(); - f.write_fmt(format_args!("ForeignPredicateGround[{}({})]", pred, args.join(", "))) - } - Self::ForeignPredicateConstraint(d, pred, args) => { - let args = args.iter().map(|a| format!("{:?}", a)).collect::>(); - f.write_fmt(format_args!("ForeignPredicateConstraint[{}({})]\n{}", pred, args.join(", "), padding))?; - d.pretty_print(f, next_indent, indent_size) - } - Self::ForeignPredicateJoin(d, pred, args) => { - let args = args.iter().map(|a| format!("{:?}", a)).collect::>(); - f.write_fmt(format_args!("ForeignPredicateJoin[{}({})]\n{}", pred, args.join(", "), padding))?; + + // Unary operations + Self::Project(d, project) => { + f.write_fmt(format_args!("Project[{:?}]\n{}", project, padding))?; d.pretty_print(f, next_indent, indent_size) } - Self::OverwriteOne(d) => { - f.write_fmt(format_args!("OverwriteOne\n{}", padding))?; + Self::Filter(d, filter) => { + f.write_fmt(format_args!("Filter[{:?}]\n{}", filter, padding))?; d.pretty_print(f, next_indent, indent_size) } Self::Find(d, tuple) => { f.write_fmt(format_args!("Find[{}]\n{}", tuple, padding))?; d.pretty_print(f, next_indent, indent_size) } - Self::Filter(d, filter) => { - f.write_fmt(format_args!("Filter[{:?}]\n{}", filter, padding))?; - d.pretty_print(f, next_indent, indent_size) - } - Self::Project(d, project) => { - f.write_fmt(format_args!("Project[{:?}]\n{}", project, padding))?; - d.pretty_print(f, next_indent, indent_size) + + // Binary operations + Self::Union(d1, d2) => { + f.write_fmt(format_args!("Union\n{}", padding))?; + d1.pretty_print(f, next_indent, indent_size)?; + f.write_fmt(format_args!("\n{}", padding))?; + d2.pretty_print(f, next_indent, indent_size) } - Self::Difference(d1, d2) => { - f.write_fmt(format_args!("Difference\n{}", padding))?; + Self::Join(d1, d2) => { + f.write_fmt(format_args!("Join\n{}", padding))?; d1.pretty_print(f, next_indent, indent_size)?; f.write_fmt(format_args!("\n{}", padding))?; d2.pretty_print(f, next_indent, indent_size) } - Self::Antijoin(d1, d2) => { - f.write_fmt(format_args!("Antijoin\n{}", padding))?; + Self::Intersect(d1, d2) => { + f.write_fmt(format_args!("Intersect\n{}", padding))?; d1.pretty_print(f, next_indent, indent_size)?; f.write_fmt(format_args!("\n{}", padding))?; d2.pretty_print(f, next_indent, indent_size) @@ -132,24 +115,57 @@ impl Dataflow { f.write_fmt(format_args!("\n{}", padding))?; d2.pretty_print(f, next_indent, indent_size) } - Self::Intersect(d1, d2) => { - f.write_fmt(format_args!("Intersect\n{}", padding))?; + Self::Antijoin(d1, d2) => { + f.write_fmt(format_args!("Antijoin\n{}", padding))?; d1.pretty_print(f, next_indent, indent_size)?; f.write_fmt(format_args!("\n{}", padding))?; d2.pretty_print(f, next_indent, indent_size) } - Self::Join(d1, d2) => { - f.write_fmt(format_args!("Join\n{}", padding))?; + Self::Difference(d1, d2) => { + f.write_fmt(format_args!("Difference\n{}", padding))?; d1.pretty_print(f, next_indent, indent_size)?; f.write_fmt(format_args!("\n{}", padding))?; d2.pretty_print(f, next_indent, indent_size) } - Self::Union(d1, d2) => { - f.write_fmt(format_args!("Union\n{}", padding))?; + + // Aggregation + Self::Reduce(r) => { + let group_by_predicate = match &r.group_by { + ReduceGroupByType::Join(group_by_predicate) => format!(" where {}", group_by_predicate), + ReduceGroupByType::Implicit => format!(" implicit group"), + _ => format!(""), + }; + f.write_fmt(format_args!( + "Aggregation {}({}{})", + r.op, r.predicate, group_by_predicate + )) + } + + Self::OverwriteOne(d) => { + f.write_fmt(format_args!("OverwriteOne\n{}", padding))?; + d.pretty_print(f, next_indent, indent_size) + } + Self::Exclusion(d1, d2) => { + f.write_fmt(format_args!("Exclusion\n{}", padding))?; d1.pretty_print(f, next_indent, indent_size)?; f.write_fmt(format_args!("\n{}", padding))?; d2.pretty_print(f, next_indent, indent_size) } + + Self::ForeignPredicateGround(pred, args) => { + let args = args.iter().map(|a| format!("{:?}", a)).collect::>(); + f.write_fmt(format_args!("ForeignPredicateGround[{}({})]", pred, args.join(", "))) + } + Self::ForeignPredicateConstraint(d, pred, args) => { + let args = args.iter().map(|a| format!("{:?}", a)).collect::>(); + f.write_fmt(format_args!("ForeignPredicateConstraint[{}({})]\n{}", pred, args.join(", "), padding))?; + d.pretty_print(f, next_indent, indent_size) + } + Self::ForeignPredicateJoin(d, pred, args) => { + let args = args.iter().map(|a| format!("{:?}", a)).collect::>(); + f.write_fmt(format_args!("ForeignPredicateJoin[{}({})]\n{}", pred, args.join(", "), padding))?; + d.pretty_print(f, next_indent, indent_size) + } } } } diff --git a/core/src/compiler/ram/ram2rs.rs b/core/src/compiler/ram/ram2rs.rs index c150bd6..9aab626 100644 --- a/core/src/compiler/ram/ram2rs.rs +++ b/core/src/compiler/ram/ram2rs.rs @@ -297,6 +297,32 @@ impl ast::Dataflow { let ty = tuple_type_to_rs_type(tuple_type); quote! { iter.unit::<#ty>(iter.is_first_iteration()) } } + Self::UntaggedVec(_) => unimplemented!(), + Self::Relation(r) => { + let rel_ident = relation_name_to_rs_field_name(r); + let stratum_id = rel_to_strat_map[r]; + if stratum_id == curr_strat_id { + quote! { &#rel_ident } + } else { + let stratum_result = format_ident!("stratum_{}_result", stratum_id); + quote! { dataflow::collection(&#stratum_result.#rel_ident, iter.is_first_iteration()) } + } + } + Self::Project(d1, expr) => { + let rs_d1 = d1.to_rs_dataflow(curr_strat_id, rel_to_strat_map); + let rs_expr = expr_to_rs_expr(expr); + quote! { dataflow::project(#rs_d1, |t| #rs_expr) } + } + Self::Filter(d1, expr) => { + let rs_d1 = d1.to_rs_dataflow(curr_strat_id, rel_to_strat_map); + let rs_expr = expr_to_rs_expr(expr); + quote! { dataflow::filter(#rs_d1, |t| #rs_expr) } + } + Self::Find(d1, tuple) => { + let rs_d1 = d1.to_rs_dataflow(curr_strat_id, rel_to_strat_map); + let rs_tuple = tuple_to_rs_tuple(tuple); + quote! { dataflow::find(#rs_d1, #rs_tuple) } + } Self::Union(d1, d2) => { let rs_d1 = d1.to_rs_dataflow(curr_strat_id, rel_to_strat_map); let rs_d2 = d2.to_rs_dataflow(curr_strat_id, rel_to_strat_map); @@ -327,28 +353,6 @@ impl ast::Dataflow { let rs_d2 = d2.to_rs_dataflow(curr_strat_id, rel_to_strat_map); quote! { iter.difference(#rs_d1, #rs_d2) } } - Self::Project(d1, expr) => { - let rs_d1 = d1.to_rs_dataflow(curr_strat_id, rel_to_strat_map); - let rs_expr = expr_to_rs_expr(expr); - quote! { dataflow::project(#rs_d1, |t| #rs_expr) } - } - Self::Filter(d1, expr) => { - let rs_d1 = d1.to_rs_dataflow(curr_strat_id, rel_to_strat_map); - let rs_expr = expr_to_rs_expr(expr); - quote! { dataflow::filter(#rs_d1, |t| #rs_expr) } - } - Self::Find(d1, tuple) => { - let rs_d1 = d1.to_rs_dataflow(curr_strat_id, rel_to_strat_map); - let rs_tuple = tuple_to_rs_tuple(tuple); - quote! { dataflow::find(#rs_d1, #rs_tuple) } - } - Self::OverwriteOne(d1) => { - let rs_d1 = d1.to_rs_dataflow(curr_strat_id, rel_to_strat_map); - quote! { dataflow::overwrite_one(#rs_d1) } - } - Self::ForeignPredicateGround(_, _) => unimplemented!(), - Self::ForeignPredicateConstraint(_, _, _) => unimplemented!(), - Self::ForeignPredicateJoin(_, _, _) => unimplemented!(), Self::Reduce(r) => { let get_col = |r| { let rel_ident = relation_name_to_rs_field_name(r); @@ -388,16 +392,14 @@ impl ast::Dataflow { } } } - Self::Relation(r) => { - let rel_ident = relation_name_to_rs_field_name(r); - let stratum_id = rel_to_strat_map[r]; - if stratum_id == curr_strat_id { - quote! { &#rel_ident } - } else { - let stratum_result = format_ident!("stratum_{}_result", stratum_id); - quote! { dataflow::collection(&#stratum_result.#rel_ident, iter.is_first_iteration()) } - } + Self::OverwriteOne(d1) => { + let rs_d1 = d1.to_rs_dataflow(curr_strat_id, rel_to_strat_map); + quote! { dataflow::overwrite_one(#rs_d1) } } + Self::Exclusion(_, _) => unimplemented!(), + Self::ForeignPredicateGround(_, _) => unimplemented!(), + Self::ForeignPredicateConstraint(_, _, _) => unimplemented!(), + Self::ForeignPredicateJoin(_, _, _) => unimplemented!(), } } } diff --git a/core/src/runtime/dynamic/dataflow/batching/batch.rs b/core/src/runtime/dynamic/dataflow/batching/batch.rs index 11b22df..4cab6b2 100644 --- a/core/src/runtime/dynamic/dataflow/batching/batch.rs +++ b/core/src/runtime/dynamic/dataflow/batching/batch.rs @@ -7,6 +7,7 @@ use super::super::*; #[derive(Clone)] pub enum DynamicBatch<'a, Prov: Provenance> { Vec(std::slice::Iter<'a, DynamicElement>), + UntaggedVec(&'a Prov, std::slice::Iter<'a, Tuple>), SourceVec(std::vec::IntoIter>), DynamicRelationStable(DynamicRelationStableBatch<'a, Prov>), DynamicRelationRecent(DynamicRelationRecentBatch<'a, Prov>), @@ -21,6 +22,7 @@ pub enum DynamicBatch<'a, Prov: Provenance> { Antijoin(DynamicAntijoinBatch<'a, Prov>), ForeignPredicateConstraint(ForeignPredicateConstraintBatch<'a, Prov>), ForeignPredicateJoin(ForeignPredicateJoinBatch<'a, Prov>), + Exclusion(DynamicExclusionBatch<'a, Prov>), } impl<'a, Prov: Provenance> DynamicBatch<'a, Prov> { @@ -28,6 +30,10 @@ impl<'a, Prov: Provenance> DynamicBatch<'a, Prov> { Self::Vec(v.iter()) } + pub fn untagged_vec(ctx: &'a Prov, v: std::slice::Iter<'a, Tuple>) -> Self { + Self::UntaggedVec(ctx, v) + } + pub fn source_vec(v: Vec>) -> Self { Self::SourceVec(v.into_iter()) } @@ -106,6 +112,7 @@ impl<'a, Prov: Provenance> Iterator for DynamicBatch<'a, Prov> { fn next(&mut self) -> Option { match self { Self::Vec(iter) => iter.next().map(Clone::clone), + Self::UntaggedVec(ctx, iter) => iter.next().map(|t| DynamicElement::new(t.clone(), ctx.one())), Self::SourceVec(iter) => iter.next(), Self::DynamicRelationStable(b) => b.next(), Self::DynamicRelationRecent(b) => b.next(), @@ -120,6 +127,7 @@ impl<'a, Prov: Provenance> Iterator for DynamicBatch<'a, Prov> { Self::Antijoin(a) => a.next(), Self::ForeignPredicateConstraint(b) => b.next(), Self::ForeignPredicateJoin(b) => b.next(), + Self::Exclusion(e) => e.next(), } } } diff --git a/core/src/runtime/dynamic/dataflow/batching/binary.rs b/core/src/runtime/dynamic/dataflow/batching/binary.rs index 9ab7f53..2da258b 100644 --- a/core/src/runtime/dynamic/dataflow/batching/binary.rs +++ b/core/src/runtime/dynamic/dataflow/batching/binary.rs @@ -7,6 +7,7 @@ pub enum BatchBinaryOp<'a, Prov: Provenance> { Product(ProductOp<'a, Prov>), Difference(DifferenceOp<'a, Prov>), Antijoin(AntijoinOp<'a, Prov>), + Exclusion(ExclusionOp<'a, Prov>), } impl<'a, Prov: Provenance> BatchBinaryOp<'a, Prov> { @@ -17,6 +18,7 @@ impl<'a, Prov: Provenance> BatchBinaryOp<'a, Prov> { Self::Product(p) => p.apply(b1, b2), Self::Difference(d) => d.apply(b1, b2), Self::Antijoin(a) => a.apply(b1, b2), + Self::Exclusion(e) => e.apply(b1, b2), } } } diff --git a/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs b/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs index 316e82c..7d37597 100644 --- a/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs +++ b/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs @@ -9,9 +9,11 @@ use super::*; pub enum DynamicDataflow<'a, Prov: Provenance> { StableUnit(DynamicStableUnitDataflow<'a, Prov>), RecentUnit(DynamicRecentUnitDataflow<'a, Prov>), - Vec(&'a Vec>), + Vec(&'a DynamicElements), + UntaggedVec(DynamicUntaggedVec<'a, Prov>), DynamicStableCollection(DynamicStableCollectionDataflow<'a, Prov>), DynamicRecentCollection(DynamicRecentCollectionDataflow<'a, Prov>), + DynamicExclusion(DynamicExclusionDataflow<'a, Prov>), DynamicRelation(DynamicRelationDataflow<'a, Prov>), OverwriteOne(DynamicOverwriteOneDataflow<'a, Prov>), Project(DynamicProjectDataflow<'a, Prov>), @@ -30,10 +32,14 @@ pub enum DynamicDataflow<'a, Prov: Provenance> { } impl<'a, Prov: Provenance> DynamicDataflow<'a, Prov> { - pub fn vec(vec: &'a Vec>) -> Self { + pub fn vec(vec: &'a DynamicElements) -> Self { Self::Vec(vec) } + pub fn untagged_vec(ctx: &'a Prov, vec: &'a Vec) -> Self { + Self::UntaggedVec(DynamicUntaggedVec::new(ctx, vec)) + } + pub fn recent_unit(ctx: &'a Prov, tuple_type: TupleType) -> Self { Self::RecentUnit(DynamicRecentUnitDataflow::new(ctx, tuple_type)) } @@ -179,25 +185,44 @@ impl<'a, Prov: Provenance> DynamicDataflow<'a, Prov> { }) } + pub fn dynamic_exclusion(self, other: Self, ctx: &'a Prov) -> Self { + Self::DynamicExclusion(DynamicExclusionDataflow::new(self, other, ctx)) + } + pub fn iter_stable(&self, runtime: &'a RuntimeEnvironment) -> DynamicBatches<'a, Prov> { match self { + // Static relations Self::StableUnit(i) => i.iter_stable(runtime), Self::RecentUnit(i) => i.iter_stable(runtime), Self::Vec(_) => DynamicBatches::Empty, + Self::UntaggedVec(i) => i.iter_stable(runtime), + + // Dynamic relations Self::DynamicStableCollection(dc) => dc.iter_stable(runtime), Self::DynamicRecentCollection(dc) => dc.iter_stable(runtime), Self::DynamicRelation(dr) => dr.iter_stable(runtime), - Self::OverwriteOne(d) => d.iter_stable(runtime), + + // Unary operations Self::Project(p) => p.iter_stable(runtime), Self::Filter(f) => f.iter_stable(runtime), Self::Find(f) => f.iter_stable(runtime), - Self::Intersect(i) => i.iter_stable(runtime), + + // Binary operations + Self::Union(u) => u.iter_stable(runtime), Self::Join(j) => j.iter_stable(runtime), Self::Product(p) => p.iter_stable(runtime), - Self::Union(u) => u.iter_stable(runtime), + Self::Intersect(i) => i.iter_stable(runtime), Self::Difference(d) => d.iter_stable(runtime), Self::Antijoin(a) => a.iter_stable(runtime), + + // Aggregation operations Self::Aggregate(a) => a.iter_stable(runtime), + + // Tag operations + Self::OverwriteOne(d) => d.iter_stable(runtime), + Self::DynamicExclusion(d) => d.iter_stable(runtime), + + // Foreign predicates Self::ForeignPredicateGround(d) => d.iter_stable(runtime), Self::ForeignPredicateConstraint(d) => d.iter_stable(runtime), Self::ForeignPredicateJoin(d) => d.iter_stable(runtime), @@ -206,23 +231,38 @@ impl<'a, Prov: Provenance> DynamicDataflow<'a, Prov> { pub fn iter_recent(&self, runtime: &'a RuntimeEnvironment) -> DynamicBatches<'a, Prov> { match self { + // Static relations Self::StableUnit(i) => i.iter_recent(runtime), Self::RecentUnit(i) => i.iter_recent(runtime), Self::Vec(v) => DynamicBatches::single(DynamicBatch::vec(v)), + Self::UntaggedVec(i) => i.iter_recent(runtime), + + // Dynamic relations Self::DynamicStableCollection(dc) => dc.iter_recent(runtime), Self::DynamicRecentCollection(dc) => dc.iter_recent(runtime), Self::DynamicRelation(dr) => dr.iter_recent(runtime), - Self::OverwriteOne(d) => d.iter_recent(runtime), + + // Unary operations Self::Project(p) => p.iter_recent(runtime), Self::Filter(f) => f.iter_recent(runtime), Self::Find(f) => f.iter_recent(runtime), - Self::Intersect(i) => i.iter_recent(runtime), + + // Binary operations + Self::Union(u) => u.iter_recent(runtime), Self::Join(j) => j.iter_recent(runtime), Self::Product(p) => p.iter_recent(runtime), - Self::Union(u) => u.iter_recent(runtime), + Self::Intersect(i) => i.iter_recent(runtime), Self::Difference(d) => d.iter_recent(runtime), Self::Antijoin(a) => a.iter_recent(runtime), + + // Aggregation operations Self::Aggregate(a) => a.iter_recent(runtime), + + // Tag operations + Self::OverwriteOne(d) => d.iter_recent(runtime), + Self::DynamicExclusion(d) => d.iter_recent(runtime), + + // Foreign predicates Self::ForeignPredicateGround(d) => d.iter_recent(runtime), Self::ForeignPredicateConstraint(d) => d.iter_recent(runtime), Self::ForeignPredicateJoin(d) => d.iter_recent(runtime), diff --git a/core/src/runtime/dynamic/dataflow/dynamic_exclusion.rs b/core/src/runtime/dynamic/dataflow/dynamic_exclusion.rs new file mode 100644 index 0000000..1235891 --- /dev/null +++ b/core/src/runtime/dynamic/dataflow/dynamic_exclusion.rs @@ -0,0 +1,153 @@ +use std::collections::*; + +use crate::common::tuple::*; +use crate::common::input_tag::*; +use crate::runtime::provenance::*; +use crate::utils::*; + +use super::*; + +type VisitedExclusionMap = ::RcCell>; + +#[derive(Clone)] +pub struct DynamicExclusionDataflow<'a, Prov: Provenance> { + // The left dataflow which generates base tuples + pub left: Box>, + + // The right dataflow which generates tuples for exclusion + // Current assumption is that the right dataflow is an `UntaggedVec` + pub right: Box>, + + // The provenance context + pub ctx: &'a Prov, + + // A map of tuples that have been visited to their mutual exclusion IDs + visited: VisitedExclusionMap, +} + +impl<'a, Prov: Provenance> DynamicExclusionDataflow<'a, Prov> { + pub fn new(left: DynamicDataflow<'a, Prov>, right: DynamicDataflow<'a, Prov>, ctx: &'a Prov) -> Self { + Self { + left: Box::new(left), + right: Box::new(right), + ctx, + visited: RcFamily::new_rc_cell(HashMap::new()), + } + } + + pub fn iter_recent(&self, env: &'a RuntimeEnvironment) -> DynamicBatches<'a, Prov> { + let left = self.left.iter_recent(env); + let right = self.right.iter_recent(env); + let op = ExclusionOp::new(env, self.ctx, RcFamily::clone_rc_cell(&self.visited)); + DynamicBatches::binary(left, right, BatchBinaryOp::Exclusion(op)) + } + + pub fn iter_stable(&self, env: &'a RuntimeEnvironment) -> DynamicBatches<'a, Prov> { + let left = self.left.iter_stable(env); + let right = self.right.iter_recent(env); + let op = ExclusionOp::new(env, self.ctx, RcFamily::clone_rc_cell(&self.visited)); + DynamicBatches::binary(left, right, BatchBinaryOp::Exclusion(op)) + } +} + +#[derive(Clone)] +pub struct ExclusionOp<'a, Prov: Provenance> { + pub runtime: &'a RuntimeEnvironment, + pub ctx: &'a Prov, + pub visited_exclusion_map: VisitedExclusionMap, +} + +impl<'a, Prov: Provenance> ExclusionOp<'a, Prov> { + pub fn new(runtime: &'a RuntimeEnvironment, ctx: &'a Prov, visited_exclusion_map: VisitedExclusionMap) -> Self { + Self { runtime, ctx, visited_exclusion_map } + } + + pub fn apply(&self, left: DynamicBatch<'a, Prov>, right: DynamicBatch<'a, Prov>) -> DynamicBatch<'a, Prov> { + DynamicBatch::Exclusion(DynamicExclusionBatch::new( + self.runtime, + self.ctx, + RcFamily::clone_rc_cell(&self.visited_exclusion_map), + left, + right, + )) + } +} + +#[derive(Clone)] +pub struct DynamicExclusionBatch<'a, Prov: Provenance> { + // Basic information + pub runtime: &'a RuntimeEnvironment, + pub ctx: &'a Prov, + pub visited_exclusion_map: VisitedExclusionMap, + + // Batches + pub left: Box>, + pub left_curr: Option>, + pub right_source: Box>, + pub right_clone: Box>, + pub curr_exclusion_id: Option, +} + +impl<'a, Prov: Provenance> DynamicExclusionBatch<'a, Prov> { + pub fn new( + runtime: &'a RuntimeEnvironment, + ctx: &'a Prov, + visited_exclusion_map: VisitedExclusionMap, + mut left: DynamicBatch<'a, Prov>, + right: DynamicBatch<'a, Prov>, + ) -> Self { + let right_clone = right.clone(); + let left_curr = left.next(); + Self { + runtime, + ctx, + visited_exclusion_map, + left: Box::new(left), + left_curr, + right_source: Box::new(right), + right_clone: Box::new(right_clone), + curr_exclusion_id: None, + } + } +} + +impl<'a, Prov: Provenance> Iterator for DynamicExclusionBatch<'a, Prov> { + type Item = DynamicElement; + + fn next(&mut self) -> Option { + loop { + if let Some(left) = &self.left_curr { + // First get an exclusion ID + let exc_id = if let Some(id) = RcFamily::get_rc_cell(&self.visited_exclusion_map, |m| m.get(&left.tuple).cloned()) { + // If the left tuple has been visited, directly pull the exclusion id + id + } else if let Some(id) = self.curr_exclusion_id { + // Or we have already generated a new ID for this tuple + id + } else { + // Otherwise, generate a new ID + let id = self.runtime.allocate_new_exclusion_id(); + RcFamily::get_rc_cell_mut(&self.visited_exclusion_map, |m| m.insert(left.tuple.clone(), id)); + id + }; + + // Then, iterate through the right + if let Some(right) = self.right_clone.next() { + // Create a tuple combining left and right + let tuple: Tuple = (left.tuple.clone(), right.tuple.clone()).into(); + let me_input_tag = Prov::InputTag::from_dynamic_input_tag(&DynamicInputTag::Exclusive(exc_id)); + let me_tag = self.ctx.tagging_optional_fn(me_input_tag); + let tag = self.ctx.mult(&left.tag, &me_tag); + return Some(DynamicElement::new(tuple, tag)); + } else { + // Move on to the next left element and reset other states + self.left_curr = self.left.next(); + self.right_clone = self.right_source.clone(); + self.curr_exclusion_id = None; + } + } else { + return None; + } + } + } +} diff --git a/core/src/runtime/dynamic/dataflow/mod.rs b/core/src/runtime/dynamic/dataflow/mod.rs index 843458b..91cd7ec 100644 --- a/core/src/runtime/dynamic/dataflow/mod.rs +++ b/core/src/runtime/dynamic/dataflow/mod.rs @@ -5,6 +5,7 @@ mod utils; mod antijoin; mod difference; mod dynamic_collection; +mod dynamic_exclusion; mod dynamic_dataflow; mod dynamic_relation; mod filter; @@ -18,6 +19,7 @@ mod project; mod static_relation; mod union; mod unit; +mod untagged_vec; // Imports use crate::runtime::dynamic::*; @@ -33,6 +35,7 @@ use antijoin::*; use difference::*; use dynamic_collection::*; pub use dynamic_dataflow::*; +use dynamic_exclusion::*; use dynamic_relation::*; use filter::*; use find::*; @@ -44,3 +47,4 @@ use product::*; use project::*; use union::*; use unit::*; +use untagged_vec::*; diff --git a/core/src/runtime/dynamic/dataflow/untagged_vec.rs b/core/src/runtime/dynamic/dataflow/untagged_vec.rs new file mode 100644 index 0000000..bafdc69 --- /dev/null +++ b/core/src/runtime/dynamic/dataflow/untagged_vec.rs @@ -0,0 +1,25 @@ +use crate::common::tuple::*; +use crate::runtime::env::*; +use crate::runtime::provenance::*; + +use super::*; + +#[derive(Clone)] +pub struct DynamicUntaggedVec<'a, Prov: Provenance> { + pub ctx: &'a Prov, + pub tuples: &'a Vec, +} + +impl<'a, Prov: Provenance> DynamicUntaggedVec<'a, Prov> { + pub fn new(ctx: &'a Prov, tuples: &'a Vec) -> Self { + Self { ctx, tuples } + } + + pub fn iter_recent(&self, _: &RuntimeEnvironment) -> DynamicBatches<'a, Prov> { + DynamicBatches::single(DynamicBatch::untagged_vec(self.ctx, self.tuples.iter())) + } + + pub fn iter_stable(&self, _: &RuntimeEnvironment) -> DynamicBatches<'a, Prov> { + DynamicBatches::Empty + } +} diff --git a/core/src/runtime/dynamic/iteration.rs b/core/src/runtime/dynamic/iteration.rs index 1d24cf2..de2a212 100644 --- a/core/src/runtime/dynamic/iteration.rs +++ b/core/src/runtime/dynamic/iteration.rs @@ -218,14 +218,14 @@ impl<'a, Prov: Provenance> DynamicIteration<'a, Prov> { } } - fn build_dynamic_update(&'a self, ctx: &'a Prov, update: &Update) -> DynamicUpdate<'a, Prov> { + fn build_dynamic_update(&'a self, ctx: &'a Prov, update: &'a Update) -> DynamicUpdate<'a, Prov> { DynamicUpdate { target: self.unsafe_get_dynamic_relation(&update.target), dataflow: self.build_dynamic_dataflow(ctx, &update.dataflow), } } - fn build_dynamic_dataflow(&'a self, ctx: &'a Prov, dataflow: &Dataflow) -> DynamicDataflow<'a, Prov> { + fn build_dynamic_dataflow(&'a self, ctx: &'a Prov, dataflow: &'a Dataflow) -> DynamicDataflow<'a, Prov> { match dataflow { Dataflow::Unit(t) => { if self.is_first_iteration() { @@ -234,6 +234,9 @@ impl<'a, Prov: Provenance> DynamicIteration<'a, Prov> { DynamicDataflow::stable_unit(ctx, t.clone()) } } + Dataflow::UntaggedVec(v) => { + DynamicDataflow::untagged_vec(ctx, v) + } Dataflow::Relation(c) => { if self.input_dynamic_collections.contains_key(c) { self.build_dynamic_collection(c) @@ -251,6 +254,7 @@ impl<'a, Prov: Provenance> DynamicIteration<'a, Prov> { self.build_dynamic_dataflow(ctx, d).foreign_predicate_join(p.clone(), a.clone(), ctx) } Dataflow::OverwriteOne(d) => self.build_dynamic_dataflow(ctx, d).overwrite_one(ctx), + Dataflow::Exclusion(d1, d2) => self.build_dynamic_dataflow(ctx, d1).dynamic_exclusion(self.build_dynamic_dataflow(ctx, d2), ctx), Dataflow::Filter(d, e) => self.build_dynamic_dataflow(ctx, d).filter(e.clone()), Dataflow::Find(d, k) => self.build_dynamic_dataflow(ctx, d).find(k.clone()), Dataflow::Project(d, e) => self.build_dynamic_dataflow(ctx, d).project(e.clone()), diff --git a/core/src/runtime/dynamic/relation.rs b/core/src/runtime/dynamic/relation.rs index 9f7fc64..8db4a76 100644 --- a/core/src/runtime/dynamic/relation.rs +++ b/core/src/runtime/dynamic/relation.rs @@ -27,7 +27,7 @@ impl DynamicRelation { } } - pub fn insert_untagged(&self, ctx: &mut Prov, data: Vec) + pub fn insert_untagged(&self, ctx: &Prov, data: Vec) where Tup: Into, { @@ -35,7 +35,7 @@ impl DynamicRelation { self.insert_tagged(ctx, elements); } - pub fn insert_untagged_with_monitor(&self, ctx: &mut Prov, data: Vec, m: &M) + pub fn insert_untagged_with_monitor(&self, ctx: &Prov, data: Vec, m: &M) where Tup: Into, M: Monitor, @@ -44,7 +44,7 @@ impl DynamicRelation { self.insert_tagged_with_monitor(ctx, elements, m); } - pub fn insert_one_tagged(&self, ctx: &mut Prov, input_tag: Option>, tuple: Tup) + pub fn insert_one_tagged(&self, ctx: &Prov, input_tag: Option>, tuple: Tup) where Tup: Into, { @@ -64,7 +64,7 @@ impl DynamicRelation { self.insert_tagged_with_monitor(ctx, vec![(input_tag, tuple)], m); } - pub fn insert_dynamically_tagged(&self, ctx: &mut Prov, data: Vec<(DynamicInputTag, Tup)>) + pub fn insert_dynamically_tagged(&self, ctx: &Prov, data: Vec<(DynamicInputTag, Tup)>) where Tup: Into, { @@ -78,7 +78,7 @@ impl DynamicRelation { self.insert_tagged(ctx, elements); } - pub fn insert_dynamically_tagged_with_monitor(&self, ctx: &mut Prov, data: Vec<(DynamicInputTag, Tup)>, m: &M) + pub fn insert_dynamically_tagged_with_monitor(&self, ctx: &Prov, data: Vec<(DynamicInputTag, Tup)>, m: &M) where Tup: Into, M: Monitor, @@ -93,7 +93,7 @@ impl DynamicRelation { self.insert_tagged_with_monitor(ctx, elements, m); } - pub fn insert_tagged(&self, ctx: &mut Prov, data: Vec<(Option>, Tup)>) + pub fn insert_tagged(&self, ctx: &Prov, data: Vec<(Option>, Tup)>) where Tup: Into, { @@ -105,7 +105,7 @@ impl DynamicRelation { self.insert_dataflow_recent(ctx, &dataflow, &RuntimeEnvironment::default()); } - pub fn insert_tagged_with_monitor(&self, ctx: &mut Prov, data: Vec<(Option>, Tup)>, m: &M) + pub fn insert_tagged_with_monitor(&self, ctx: &Prov, data: Vec<(Option>, Tup)>, m: &M) where Tup: Into, M: Monitor, diff --git a/core/src/runtime/env/environment.rs b/core/src/runtime/env/environment.rs index afc2bc0..cbd225a 100644 --- a/core/src/runtime/env/environment.rs +++ b/core/src/runtime/env/environment.rs @@ -9,6 +9,7 @@ use crate::common::foreign_function::*; use crate::common::foreign_predicate::*; use crate::common::tuple::*; use crate::common::value_type::*; +use crate::utils::*; #[derive(Clone, Debug)] pub struct RuntimeEnvironment { @@ -29,6 +30,9 @@ pub struct RuntimeEnvironment { /// Foreign predicate registry pub predicate_registry: ForeignPredicateRegistry, + + /// Mutual exclusion ID allocator + pub exclusion_id_allocator: Arc>, } impl Default for RuntimeEnvironment { @@ -46,6 +50,7 @@ impl RuntimeEnvironment { iter_limit: None, function_registry: ForeignFunctionRegistry::std(), predicate_registry: ForeignPredicateRegistry::std(), + exclusion_id_allocator: Arc::new(Mutex::new(IdAllocator::new())), } } @@ -57,6 +62,7 @@ impl RuntimeEnvironment { iter_limit: None, function_registry: ForeignFunctionRegistry::std(), predicate_registry: ForeignPredicateRegistry::std(), + exclusion_id_allocator: Arc::new(Mutex::new(IdAllocator::new())), } } @@ -71,6 +77,7 @@ impl RuntimeEnvironment { iter_limit: None, function_registry: ffr, predicate_registry: fpr, + exclusion_id_allocator: Arc::new(Mutex::new(IdAllocator::new())), } } @@ -82,6 +89,7 @@ impl RuntimeEnvironment { iter_limit: None, function_registry: ffr, predicate_registry: ForeignPredicateRegistry::std(), + exclusion_id_allocator: Arc::new(Mutex::new(IdAllocator::new())), } } @@ -97,6 +105,10 @@ impl RuntimeEnvironment { self.iter_limit = None; } + pub fn allocate_new_exclusion_id(&self) -> usize { + self.exclusion_id_allocator.lock().unwrap().alloc() + } + pub fn eval(&self, expr: &Expr, tuple: &Tuple) -> Option { match expr { Expr::Tuple(t) => Some(Tuple::Tuple( diff --git a/core/src/runtime/env/mod.rs b/core/src/runtime/env/mod.rs index 44b2ac1..da3cfdd 100644 --- a/core/src/runtime/env/mod.rs +++ b/core/src/runtime/env/mod.rs @@ -1,7 +1,5 @@ mod environment; -mod literals; mod options; pub use environment::*; -pub use literals::*; pub use options::*; diff --git a/core/src/runtime/env/options.rs b/core/src/runtime/env/options.rs index f59dd83..cb848c4 100644 --- a/core/src/runtime/env/options.rs +++ b/core/src/runtime/env/options.rs @@ -3,10 +3,12 @@ use std::sync::*; use rand::rngs::SmallRng; use rand::SeedableRng; -use super::*; use crate::common::constants::*; use crate::common::foreign_function::*; use crate::common::foreign_predicate::*; +use crate::utils::*; + +use super::*; /// The options to create a runtime environment #[derive(Clone, Debug)] @@ -41,6 +43,7 @@ impl RuntimeEnvironmentOptions { iter_limit: self.iter_limit, function_registry: ForeignFunctionRegistry::std(), predicate_registry: ForeignPredicateRegistry::std(), + exclusion_id_allocator: Arc::new(Mutex::new(IdAllocator::new())), } } } diff --git a/core/src/runtime/provenance/discrete/proofs.rs b/core/src/runtime/provenance/discrete/proofs.rs index 1517921..dcfcb6c 100644 --- a/core/src/runtime/provenance/discrete/proofs.rs +++ b/core/src/runtime/provenance/discrete/proofs.rs @@ -20,6 +20,12 @@ impl Proof { facts: p1.facts.iter().chain(p2.facts.iter()).cloned().collect(), } } + + pub fn from_facts>(i: I) -> Self { + Self { + facts: BTreeSet::from_iter(i), + } + } } impl std::fmt::Debug for Proof { @@ -91,6 +97,12 @@ impl Proofs { .collect(), } } + + pub fn from_proofs>(i: I) -> Self { + Self { + proofs: BTreeSet::from_iter(i), + } + } } impl std::fmt::Debug for Proofs { diff --git a/core/src/utils/id_allocator.rs b/core/src/utils/id_allocator.rs index 89b0654..879ab19 100644 --- a/core/src/utils/id_allocator.rs +++ b/core/src/utils/id_allocator.rs @@ -4,6 +4,14 @@ pub struct IdAllocator { } impl IdAllocator { + pub fn new() -> Self { + Self { id: 0 } + } + + pub fn new_with_start(start: usize) -> Self { + Self { id: start } + } + pub fn alloc(&mut self) -> usize { let result = self.id; self.id += 1; diff --git a/core/tests/integrate/basic.rs b/core/tests/integrate/basic.rs index 0da2220..d18a99f 100644 --- a/core/tests/integrate/basic.rs +++ b/core/tests/integrate/basic.rs @@ -736,6 +736,17 @@ fn test_exists_with_where_clause_2() { ) } +#[test] +fn test_not_exists_1() { + expect_interpret_result( + r#" + rel color = {(0, "red"), (1, "green")} + rel result() :- not exists(o: color(o, "blue")) + "#, + ("result", vec![()].into()), + ) +} + #[test] fn type_cast_to_string_1() { expect_interpret_result( @@ -1158,3 +1169,30 @@ fn string_plus_string_1() { ("full_name", vec![("Alice Lee".to_string(),)]), ) } + +#[test] +fn disjunctive_1() { + let prov = proofs::ProofsProvenance::::default(); + + // Pre-generate true tags and false tags + let true_tag = proofs::Proofs::from_proofs(vec![ + proofs::Proof::from_facts(vec![0, 1, 2, 4].into_iter()), + proofs::Proof::from_facts(vec![0, 1, 2, 5].into_iter()), + proofs::Proof::from_facts(vec![0, 1, 3, 4].into_iter()), + ].into_iter()); + let false_tag = proofs::Proofs::from_proofs(vec![ + proofs::Proof::from_facts(vec![0, 1, 3, 5].into_iter()), + ].into_iter()); + + // Test + expect_interpret_result_with_tag( + r#" + rel var = {1, 2} + rel { assign(x, true); assign(x, false) } = var(x) + rel result(a || b) = assign(1, a) and assign(2, b) + "#, + prov, + ("result", vec![(true_tag, (true,)), (false_tag, (false,))]), + proofs::Proofs::eq, + ) +} diff --git a/core/tests/runtime/dataflow/dyn_aggregate.rs b/core/tests/runtime/dataflow/dyn_aggregate.rs index 199baae..3196075 100644 --- a/core/tests/runtime/dataflow/dyn_aggregate.rs +++ b/core/tests/runtime/dataflow/dyn_aggregate.rs @@ -7,8 +7,8 @@ use scallop_core::testing::*; #[test] fn test_dynamic_aggregate_count_1() { - let mut ctx = unit::UnitProvenance::default(); - let mut rt = RuntimeEnvironment::default(); + let ctx = unit::UnitProvenance::default(); + let rt = RuntimeEnvironment::default(); // Relations let mut source_1 = DynamicRelation::::new(); @@ -16,15 +16,15 @@ fn test_dynamic_aggregate_count_1() { let mut target = DynamicRelation::::new(); // Initial - source_1.insert_untagged(&mut ctx, vec![(0i8, 1i8), (1i8, 2i8), (3i8, 4i8), (3i8, 5i8)]); - source_2.insert_untagged(&mut ctx, vec![(1i8, 1i8), (1i8, 2i8), (3i8, 5i8)]); + source_1.insert_untagged(&ctx, vec![(0i8, 1i8), (1i8, 2i8), (3i8, 4i8), (3i8, 5i8)]); + source_2.insert_untagged(&ctx, vec![(1i8, 1i8), (1i8, 2i8), (3i8, 5i8)]); // Iterate until fixpoint while source_1.changed(&ctx) || source_2.changed(&ctx) || target.changed(&ctx) { target.insert_dataflow_recent( &ctx, &DynamicDataflow::from(&source_1).intersect(DynamicDataflow::from(&source_2), &ctx), - &mut rt, + &rt, ) } @@ -41,7 +41,7 @@ fn test_dynamic_aggregate_count_1() { &ctx, ) .into(), - &mut rt, + &rt, ); first_time = false; } diff --git a/core/tests/runtime/dataflow/dyn_exclusion.rs b/core/tests/runtime/dataflow/dyn_exclusion.rs new file mode 100644 index 0000000..6fc12ed --- /dev/null +++ b/core/tests/runtime/dataflow/dyn_exclusion.rs @@ -0,0 +1,41 @@ +use scallop_core::common::expr::*; +use scallop_core::runtime::env::*; +use scallop_core::runtime::dynamic::*; +use scallop_core::runtime::provenance::*; +use scallop_core::testing::*; +use scallop_core::utils::*; + +#[test] +fn test_dynamic_exclusion_1() { + let ctx = proofs::ProofsProvenance::::default(); + let rt = RuntimeEnvironment::default(); + + // Relations + let mut source = DynamicRelation::>::new(); + let mut target = DynamicRelation::>::new(); + source.insert_untagged(&ctx, vec![(0,), (1,)]); + + // Untagged vec for exclusion + let exc = vec![("red".to_string(),).into(), ("blue".to_string(),).into()]; + + // Iterate until fixpoint + let mut first_time = true; + while source.changed(&ctx) || target.changed(&ctx) || first_time { + target.insert_dataflow_recent( + &ctx, + &dataflow::DynamicDataflow::from(&source) + .dynamic_exclusion(dataflow::DynamicDataflow::untagged_vec(&ctx, &exc), &ctx) + .project((Expr::access((0, 0)), Expr::access((1, 0))).into()), + &rt, + ); + first_time = false; + } + + // Inspect the result + expect_collection(&target.complete(&ctx), vec![ + (0, "red".to_string()), + (0, "blue".to_string()), + (1, "red".to_string()), + (1, "blue".to_string()), + ]); +} diff --git a/core/tests/runtime/dataflow/mod.rs b/core/tests/runtime/dataflow/mod.rs index 3d415eb..5339ccd 100644 --- a/core/tests/runtime/dataflow/mod.rs +++ b/core/tests/runtime/dataflow/mod.rs @@ -1,5 +1,6 @@ mod dyn_aggregate; mod dyn_difference; +mod dyn_exclusion; mod dyn_filter; mod dyn_find; mod dyn_foreign_predicate; diff --git a/examples/datalog/boolean_formula_sat.scl b/examples/datalog/boolean_formula_sat.scl new file mode 100644 index 0000000..fd4ceab --- /dev/null +++ b/examples/datalog/boolean_formula_sat.scl @@ -0,0 +1,21 @@ +// Each variable could be assigned either true or false, but not both +rel { assign(x, true); assign(x, false) } = vars(x) + +// There are two variables of interest, A and B +rel vars = {"A", "B"} + +// (A /\ ~A) \/ (B /\ ~B) +rel bf_var = {(1, "A"), (2, "A"), (3, "B"), (4, "B")} +rel bf_not = {(5, 2), (6, 4)} +rel bf_and = {(7, 1, 5), (8, 3, 6)} +rel bf_or = {(9, 7, 8)} +rel bf_root = {9} + +// Evaluation the formula to see if it is satisfiable +rel eval_bf(bf, r) :- bf_var(bf, v), assign(v, r) +rel eval_bf(bf, !r) :- bf_not(bf, c), eval_bf(c, r) +rel eval_bf(bf, lr && rr) :- bf_and(bf, lbf, rbf), eval_bf(lbf, lr), eval_bf(rbf, rr) +rel eval_bf(bf, lr || rr) :- bf_or(bf, lbf, rbf), eval_bf(lbf, lr), eval_bf(rbf, rr) +rel eval(r) :- bf_root(bf), eval_bf(bf, r) + +query eval diff --git a/lib/ram/src/language.rs b/lib/ram/src/language.rs index 463753e..c401186 100644 --- a/lib/ram/src/language.rs +++ b/lib/ram/src/language.rs @@ -57,6 +57,8 @@ pub fn ram_rewrite_rules() -> Vec> { rw!("filter-true"; "(filter ?d true)" => "?d"), rw!("filter-false"; "(filter ?d false)" => "empty"), rw!("project-cascade"; "(project (project ?d ?a) ?b)" => "(project ?d (apply ?b ?a))"), + rw!("product-transpose"; "(product ?a ?b)" => "(project (product ?b ?a) (cons (cons 1 nil) (cons (cons 0 nil) nil))))"), + rw!("join-transpose"; "(join ?a ?b)" => "(project (join ?b ?a) (cons (cons 2 nil) (cons (cons 1 nil) nil))))"), // Tuple level application rewrites rw!("access-nil"; "(apply nil ?a)" => "?a"), @@ -109,10 +111,15 @@ impl CostFunction for RamCostFunction { C: FnMut(Id) -> Self::Cost { let op_cost = match enode { + Ram::Empty => 0, Ram::Filter(_) => 100, Ram::Project(_) => 100, - Ram::Sorted(_) => 100, + Ram::Product(_) => 100, + Ram::Join(_) => 100, + Ram::Sorted(_) => 500, Ram::Apply(_) => 10, + Ram::Cons(_) => 0, + Ram::Nil => 0, _ => 1, }; enode.fold(op_cost, |sum, id| sum + costs(id))