diff --git a/compiler/rustc_data_structures/src/lib.rs b/compiler/rustc_data_structures/src/lib.rs index 7669b78834c3f..4bbf31122e3b2 100644 --- a/compiler/rustc_data_structures/src/lib.rs +++ b/compiler/rustc_data_structures/src/lib.rs @@ -73,6 +73,7 @@ pub mod flock; pub mod fx; pub mod graph; pub mod jobserver; +pub mod logged_unification_table; pub mod macros; pub mod map_in_place; pub mod obligation_forest; @@ -83,6 +84,7 @@ pub mod small_c_str; pub mod snapshot_map; pub mod stable_map; pub mod svh; +pub mod unify_log; pub use ena::snapshot_vec; pub mod sorted_map; pub mod stable_set; @@ -90,6 +92,7 @@ pub mod stable_set; pub mod stable_hasher; mod atomic_ref; pub mod fingerprint; +pub mod modified_set; pub mod profiling; pub mod sharded; pub mod stack; diff --git a/compiler/rustc_data_structures/src/logged_unification_table.rs b/compiler/rustc_data_structures/src/logged_unification_table.rs new file mode 100644 index 0000000000000..902389736cd4c --- /dev/null +++ b/compiler/rustc_data_structures/src/logged_unification_table.rs @@ -0,0 +1,209 @@ +use rustc_index::vec::Idx; + +use crate::modified_set as ms; +use crate::snapshot_vec as sv; +use crate::unify as ut; +use crate::unify_log as ul; + +use ena::undo_log::{Rollback, UndoLogs}; + +pub enum UndoLog { + Relation(sv::UndoLog>), + UnifyLog(ul::Undo), + ModifiedSet(ms::Undo), +} + +impl From>> for UndoLog { + fn from(l: sv::UndoLog>) -> Self { + UndoLog::Relation(l) + } +} + +impl From> for UndoLog { + fn from(l: ul::Undo) -> Self { + UndoLog::UnifyLog(l) + } +} + +impl From> for UndoLog { + fn from(l: ms::Undo) -> Self { + UndoLog::ModifiedSet(l) + } +} + +impl Rollback> for LoggedUnificationTableStorage { + fn reverse(&mut self, undo: UndoLog) { + match undo { + UndoLog::Relation(undo) => self.relations.reverse(undo), + UndoLog::UnifyLog(undo) => self.unify_log.reverse(undo), + UndoLog::ModifiedSet(undo) => self.modified_set.reverse(undo), + } + } +} + +/// Storage for `LoggedUnificationTable` +pub struct LoggedUnificationTableStorage { + relations: ut::UnificationTableStorage, + unify_log: ul::UnifyLog, + modified_set: ms::ModifiedSet, +} + +/// UnificationTableStorage which logs which variables has been unified with a value, allowing watchers +/// to only iterate over the changed variables instead of all variables +pub struct LoggedUnificationTable<'a, K: ut::UnifyKey, I: Idx, L> { + storage: &'a mut LoggedUnificationTableStorage, + undo_log: L, +} + +impl LoggedUnificationTableStorage +where + K: ut::UnifyKey + From, + I: Idx + From, +{ + pub fn new() -> Self { + Self { + relations: Default::default(), + unify_log: ul::UnifyLog::new(), + modified_set: ms::ModifiedSet::new(), + } + } + + pub fn with_log(&mut self, undo_log: L) -> LoggedUnificationTable<'_, K, I, L> { + LoggedUnificationTable { storage: self, undo_log } + } +} + +impl LoggedUnificationTable<'_, K, I, L> +where + K: ut::UnifyKey, + I: Idx, +{ + pub fn len(&self) -> usize { + self.storage.relations.len() + } +} + +impl LoggedUnificationTable<'_, K, I, L> +where + K: ut::UnifyKey + From, + I: Idx + From, + L: UndoLogs> + UndoLogs> + UndoLogs>>, +{ + fn relations( + &mut self, + ) -> ut::UnificationTable, &mut L>> { + ut::UnificationTable::with_log(&mut self.storage.relations, &mut self.undo_log) + } + + pub fn unify(&mut self, a: I, b: I) + where + K::Value: ut::UnifyValue, + { + self.unify_var_var(a, b).unwrap(); + } + + pub fn instantiate(&mut self, vid: I, ty: K::Value) -> K + where + K::Value: ut::UnifyValue, + { + let vid = vid.into(); + let mut relations = self.relations(); + debug_assert!(relations.find(vid) == vid); + relations.union_value(vid, ty); + + vid + } + + pub fn find(&mut self, vid: I) -> K { + self.relations().find(vid) + } + + pub fn unioned(&mut self, l: I, r: I) -> bool { + let mut relations = self.relations(); + relations.find(l) == relations.find(r) + } + + pub fn unify_var_value( + &mut self, + vid: I, + value: K::Value, + ) -> Result<(), ::Error> { + let vid = self.find(vid); + self.relations().unify_var_value(vid, value) + } + + pub fn unify_var_var(&mut self, a: I, b: I) -> Result<(), ::Error> { + let mut relations = self.relations(); + let a = relations.find(a); + let b = relations.find(b); + if a == b { + return Ok(()); + } + + relations.unify_var_var(a, b)?; + + Ok(()) + } + + pub fn union_value(&mut self, vid: I, value: K::Value) + where + K::Value: ut::UnifyValue, + { + let vid = self.find(vid).into(); + self.instantiate(vid, value); + } + + pub fn probe_value(&mut self, vid: I) -> K::Value { + self.relations().probe_value(vid) + } + + #[inline(always)] + pub fn inlined_probe_value(&mut self, vid: I) -> K::Value { + self.relations().inlined_probe_value(vid) + } + + pub fn new_key(&mut self, value: K::Value) -> K { + self.relations().new_key(value) + } + + /// Clears any modifications currently tracked. Usually this can only be done once there are no + /// snapshots active as the modifications may otherwise be needed after a rollback + pub fn clear_modified_set(&mut self) { + self.storage.modified_set.clear(); + } + + /// Registers a watcher on the unifications done in this table + pub fn register_watcher(&mut self) -> ms::Offset { + self.storage.modified_set.register() + } + + /// Deregisters a watcher previously registered in this table + pub fn deregister_watcher(&mut self, offset: ms::Offset) { + self.storage.modified_set.deregister(offset); + } + + /// Watches the variable at `index` allowing any watchers to be notified to unifications with + /// `index` + pub fn watch_variable(&mut self, index: I) { + debug_assert!(index == self.relations().find(index).into()); + self.storage.unify_log.watch_variable(index) + } + + /// Unwatches a previous watch at `index` + pub fn unwatch_variable(&mut self, index: I) { + self.storage.unify_log.unwatch_variable(index) + } + + /// Iterates through all unified variables since the last call to `notify_watcher` + /// passing the unified variable to `f` + pub fn notify_watcher(&mut self, offset: &ms::Offset, mut f: impl FnMut(I)) { + let unify_log = &self.storage.unify_log; + self.storage.modified_set.notify_watcher(&mut self.undo_log, offset, |vid| { + for &unified_vid in unify_log.get(vid) { + f(unified_vid); + } + + f(vid) + }) + } +} diff --git a/compiler/rustc_data_structures/src/modified_set.rs b/compiler/rustc_data_structures/src/modified_set.rs new file mode 100644 index 0000000000000..dcea3b40751cb --- /dev/null +++ b/compiler/rustc_data_structures/src/modified_set.rs @@ -0,0 +1,133 @@ +use std::marker::PhantomData; + +use rustc_index::vec::Idx; + +use ena::undo_log::{Rollback, UndoLogs}; + +#[derive(Copy, Clone, Debug)] +enum UndoInner { + Add, + Drain { index: usize, offset: usize }, +} + +#[derive(Copy, Clone, Debug)] +pub struct Undo(UndoInner, PhantomData); + +/// Tracks which indices have been modified and allows watchers to registered and notified of these +/// changes. +#[derive(Clone, Debug)] +pub struct ModifiedSet { + modified: Vec, + offsets: Vec, +} + +impl Default for ModifiedSet { + fn default() -> Self { + Self { modified: Default::default(), offsets: Vec::new() } + } +} + +impl ModifiedSet { + /// Creates a new `ModifiedSet` + pub fn new() -> Self { + Self::default() + } + + /// Marks `index` as "modified". A subsequent call to `drain` will notify the callback with + /// `index` + pub fn set(&mut self, undo_log: &mut impl UndoLogs>, index: T) { + self.modified.push(index); + undo_log.push(Undo(UndoInner::Add, PhantomData)); + } + + /// Calls `f` with all the indices that have been modified since the last call to + /// `notify_watcher` + pub fn notify_watcher( + &mut self, + undo_log: &mut impl UndoLogs>, + watcher_offset: &Offset, + mut f: impl FnMut(T), + ) { + let offset = &mut self.offsets[watcher_offset.index]; + if *offset < self.modified.len() { + for &index in &self.modified[*offset..] { + f(index); + } + undo_log.push(Undo( + UndoInner::Drain { index: watcher_offset.index, offset: *offset }, + PhantomData, + )); + *offset = self.modified.len(); + } + } + + /// Clears the set of all modifications that have been drained by all watchers + pub fn clear(&mut self) { + let min = self.offsets.iter().copied().min().unwrap_or_else(|| self.modified.len()); + self.modified.drain(..min); + for offset in &mut self.offsets { + *offset -= min; + } + } + + /// Registers a new watcher on this set. + /// + /// NOTE: Watchers must be removed in the reverse order that they were registered + pub fn register(&mut self) -> Offset { + let index = self.offsets.len(); + self.offsets.push(self.modified.len()); + Offset { index, _marker: PhantomData } + } + + /// De-registers a watcher on this set. + /// + /// NOTE: Watchers must be removed in the reverse order that they were registered + pub fn deregister(&mut self, offset: Offset) { + assert_eq!( + offset.index, + self.offsets.len() - 1, + "Watchers must be removed in the reverse order that they were registered" + ); + self.offsets.pop(); + std::mem::forget(offset); + } +} + +impl Rollback> for ModifiedSet { + fn reverse(&mut self, undo: Undo) { + match undo.0 { + UndoInner::Add => { + self.modified.pop(); + } + UndoInner::Drain { index, offset } => { + if let Some(o) = self.offsets.get_mut(index) { + *o = offset; + } + } + } + } +} + +/// A registered offset into a `ModifiedSet`. Tracks how much a watcher has seen so far to avoid +/// being notified of the same event twice. +#[must_use] +pub struct Offset { + index: usize, + _marker: PhantomData, +} + +impl Drop for Offset { + fn drop(&mut self) { + if !std::thread::panicking() { + panic!("Offsets should be deregistered") + } + } +} + +#[must_use] +#[derive(Debug)] +pub struct Snapshot { + modified_len: usize, + undo_log_len: usize, + _marker: PhantomData, +} diff --git a/compiler/rustc_data_structures/src/obligation_forest/mod.rs b/compiler/rustc_data_structures/src/obligation_forest/mod.rs index a5b2df1da5d6d..69ca9f7974f19 100644 --- a/compiler/rustc_data_structures/src/obligation_forest/mod.rs +++ b/compiler/rustc_data_structures/src/obligation_forest/mod.rs @@ -42,7 +42,7 @@ //! now considered to be in error. //! //! When the call to `process_obligations` completes, you get back an `Outcome`, -//! which includes three bits of information: +//! which includes two bits of information: //! //! - `completed`: a list of obligations where processing was fully //! completed without error (meaning that all transitive subobligations @@ -53,13 +53,6 @@ //! all the obligations in `C` have been found completed. //! - `errors`: a list of errors that occurred and associated backtraces //! at the time of error, which can be used to give context to the user. -//! - `stalled`: if true, then none of the existing obligations were -//! *shallowly successful* (that is, no callback returned `Changed(_)`). -//! This implies that all obligations were either errors or returned an -//! ambiguous result, which means that any further calls to -//! `process_obligations` would simply yield back further ambiguous -//! results. This is used by the `FulfillmentContext` to decide when it -//! has reached a steady state. //! //! ### Implementation details //! @@ -74,11 +67,14 @@ use crate::fx::{FxHashMap, FxHashSet}; -use std::cell::Cell; +use std::cell::{Cell, RefCell}; +use std::cmp::Ordering; use std::collections::hash_map::Entry; +use std::collections::BinaryHeap; use std::fmt::Debug; use std::hash; use std::marker::PhantomData; +use std::mem; mod graphviz; @@ -86,19 +82,33 @@ mod graphviz; mod tests; pub trait ForestObligation: Clone + Debug { + /// A key used to avoid evaluating the same obligation twice. type CacheKey: Clone + hash::Hash + Eq + Debug; + /// The variable type used in the obligation when it could not yet be fulfilled. + type Variable: Clone + hash::Hash + Eq + Debug; + /// A type which tracks which variables has been unified. + type WatcherOffset; /// Converts this `ForestObligation` suitable for use as a cache key. /// If two distinct `ForestObligations`s return the same cache key, /// then it must be sound to use the result of processing one obligation - /// (e.g. success for error) for the other obligation + /// (e.g. success for error) for the other obligation. fn as_cache_key(&self) -> Self::CacheKey; + + /// Returns which variables this obligation is currently stalled on. If the slice is empty then + /// the variables stalled on are unknown. + fn stalled_on(&self) -> &[Self::Variable]; } pub trait ObligationProcessor { type Obligation: ForestObligation; type Error: Debug; + fn checked_process_obligation( + &mut self, + obligation: &mut Self::Obligation, + ) -> ProcessResult; + fn process_obligation( &mut self, obligation: &mut Self::Obligation, @@ -115,6 +125,21 @@ pub trait ObligationProcessor { fn process_backedge<'c, I>(&mut self, cycle: I, _marker: PhantomData<&'c Self::Obligation>) where I: Clone + Iterator; + + /// Calls `f` with all the variables that have been unblocked (instantiated) since the last call + /// to `notify_unblocked`. + fn notify_unblocked( + &self, + offset: &::WatcherOffset, + f: impl FnMut(::Variable), + ); + fn register_variable_watcher(&self) -> ::WatcherOffset; + fn deregister_variable_watcher( + &self, + offset: ::WatcherOffset, + ); + fn watch_variable(&self, var: ::Variable); + fn unwatch_variable(&self, var: ::Variable); } /// The result type used by `process_obligation`. @@ -131,26 +156,45 @@ struct ObligationTreeId(usize); type ObligationTreeIdGenerator = std::iter::Map, fn(usize) -> ObligationTreeId>; +/// `usize` indices are used here and throughout this module, rather than +/// `rustc_index::newtype_index!` indices, because this code is hot enough +/// that the `u32`-to-`usize` conversions that would be required are +/// significant, and space considerations are not important. +type NodeIndex = usize; + +enum CacheState { + Active(NodeIndex), + Done, +} + pub struct ObligationForest { /// The list of obligations. In between calls to `process_obligations`, /// this list only contains nodes in the `Pending` or `Waiting` state. - /// - /// `usize` indices are used here and throughout this module, rather than - /// `rustc_index::newtype_index!` indices, because this code is hot enough - /// that the `u32`-to-`usize` conversions that would be required are - /// significant, and space considerations are not important. nodes: Vec>, - /// A cache of predicates that have been successfully completed. - done_cache: FxHashSet, + /// Nodes must be processed in the order that they were added so we give each node a unique, + /// number allowing them to be ordered when processing them. + node_number: u32, + + /// Stores the indices of the nodes currently in the pending state. + pending_nodes: Vec, + + /// Stores the indices of the nodes currently in the success or waiting states. + /// Can also contain `Done` or `Error` nodes as `process_cycles` does not remove a node + /// immediately but instead upon the next time that node is processed. + success_or_waiting_nodes: Vec, + + /// Stores the indices of the nodes currently in the error or done states. + error_or_done_nodes: RefCell>, + + /// Nodes that have been removed and are ready to be reused (pure optimization to reuse + /// allocations) + dead_nodes: Vec, /// A cache of the nodes in `nodes`, indexed by predicate. Unfortunately, /// its contents are not guaranteed to match those of `nodes`. See the /// comments in `process_obligation` for details. - active_cache: FxHashMap, - - /// A vector reused in compress() and find_cycles_from_node(), to avoid allocating new vectors. - reused_node_vec: Vec, + active_cache: FxHashMap, obligation_tree_id_generator: ObligationTreeIdGenerator, @@ -162,38 +206,133 @@ pub struct ObligationForest { /// /// [details]: https://github.com/rust-lang/rust/pull/53255#issuecomment-421184780 error_cache: FxHashMap>, + + /// Stores which nodes would be unblocked once `O::Variable` is unified. + stalled_on: FxHashMap>, + + /// Stores the node indices that are unblocked and should be processed at the next opportunity. + unblocked: BinaryHeap, + + /// Stores nodes which should be processed on the next iteration since the variables they are + /// actually blocked on are unknown. + stalled_on_unknown: Vec, + + /// The offset that this `ObligationForest` has registered. Should be de-registered before + /// dropping this forest. + /// + watcher_offset: Option, + /// We do not want to process any further obligations after the offset has been deregistered as that could mean unified variables are lost, leading to typecheck failures. + /// So we mark this as done and panic if a caller tries to resume processing. + done: bool, + /// Reusable vector for storing unblocked nodes whose watch should be removed. + temp_unblocked_nodes: Vec, + + reused_node_vec: Vec, +} + +/// Helper struct for use with `BinaryHeap` to process nodes in the order that they were added to +/// the forest. +struct Unblocked { + index: NodeIndex, + order: u32, +} + +impl PartialEq for Unblocked { + fn eq(&self, other: &Self) -> bool { + self.order == other.order + } +} +impl Eq for Unblocked {} +impl PartialOrd for Unblocked { + fn partial_cmp(&self, other: &Self) -> Option { + other.order.partial_cmp(&self.order) + } +} +impl Ord for Unblocked { + fn cmp(&self, other: &Self) -> Ordering { + other.order.cmp(&self.order) + } } #[derive(Debug)] -struct Node { +struct Node { obligation: O, state: Cell, + /// A predicate (and its key) can change during processing. If it does we need to register the + /// old predicate so that we can remove or mark it as done if this node errors or is done. + alternative_predicates: Vec, + /// Obligations that depend on this obligation for their completion. They /// must all be in a non-pending state. - dependents: Vec, + dependents: Vec, + + /// Obligations that this obligation depends on for their completion. + reverse_dependents: Vec, /// If true, `dependents[0]` points to a "parent" node, which requires /// special treatment upon error but is otherwise treated the same. /// (It would be more idiomatic to store the parent node in a separate - /// `Option` field, but that slows down the common case of + /// `Option` field, but that slows down the common case of /// iterating over the parent and other descendants together.) has_parent: bool, /// Identifier of the obligation tree to which this node belongs. obligation_tree_id: ObligationTreeId, + + /// Nodes must be processed in the order that they were added so we give each node a unique + /// number allowing them to be ordered when processing them. + node_number: u32, } -impl Node { - fn new(parent: Option, obligation: O, obligation_tree_id: ObligationTreeId) -> Node { +impl Node +where + O: ForestObligation, +{ + fn new( + parent: Option, + obligation: O, + obligation_tree_id: ObligationTreeId, + node_number: u32, + ) -> Node { Node { obligation, state: Cell::new(NodeState::Pending), + alternative_predicates: vec![], dependents: if let Some(parent_index) = parent { vec![parent_index] } else { vec![] }, + reverse_dependents: vec![], has_parent: parent.is_some(), obligation_tree_id, + node_number, } } + + /// Initializes a node, reusing the existing allocations. Used when removing a node from the + /// dead_nodes list + fn reinit( + &mut self, + parent: Option, + obligation: O, + obligation_tree_id: ObligationTreeId, + node_number: u32, + ) { + self.obligation = obligation; + debug_assert!( + self.state.get() == NodeState::Done || self.state.get() == NodeState::Error, + "{:?}", + self.state + ); + self.state.set(NodeState::Pending); + self.alternative_predicates.clear(); + self.dependents.clear(); + self.reverse_dependents.clear(); + if let Some(parent_index) = parent { + self.dependents.push(parent_index); + } + self.has_parent = parent.is_some(); + self.obligation_tree_id = obligation_tree_id; + self.node_number = node_number; + } } /// The state of one node in some tree within the forest. This represents the @@ -251,57 +390,22 @@ enum NodeState { Error, } -/// This trait allows us to have two different Outcome types: -/// - the normal one that does as little as possible -/// - one for tests that does some additional work and checking -pub trait OutcomeTrait { - type Error; - type Obligation; - - fn new() -> Self; - fn mark_not_stalled(&mut self); - fn is_stalled(&self) -> bool; - fn record_completed(&mut self, outcome: &Self::Obligation); - fn record_error(&mut self, error: Self::Error); -} - #[derive(Debug)] pub struct Outcome { + /// Obligations that were completely evaluated, including all + /// (transitive) subobligations. Only computed if requested. + pub completed: Option>, + /// Backtrace of obligations that were found to be in error. pub errors: Vec>, - - /// If true, then we saw no successful obligations, which means - /// there is no point in further iteration. This is based on the - /// assumption that when trait matching returns `Error` or - /// `Unchanged`, those results do not affect environmental - /// inference state. (Note that if we invoke `process_obligations` - /// with no pending obligations, stalled will be true.) - pub stalled: bool, } -impl OutcomeTrait for Outcome { - type Error = Error; - type Obligation = O; - - fn new() -> Self { - Self { stalled: true, errors: vec![] } - } - - fn mark_not_stalled(&mut self) { - self.stalled = false; - } - - fn is_stalled(&self) -> bool { - self.stalled - } - - fn record_completed(&mut self, _outcome: &Self::Obligation) { - // do nothing - } - - fn record_error(&mut self, error: Self::Error) { - self.errors.push(error) - } +/// Should `process_obligations` compute the `Outcome::completed` field of its +/// result? +#[derive(PartialEq, Copy, Clone)] +pub enum DoCompleted { + No, + Yes, } #[derive(Debug, PartialEq, Eq)] @@ -314,14 +418,36 @@ impl ObligationForest { pub fn new() -> ObligationForest { ObligationForest { nodes: vec![], - done_cache: Default::default(), + pending_nodes: vec![], + success_or_waiting_nodes: vec![], + error_or_done_nodes: RefCell::new(vec![]), + dead_nodes: vec![], active_cache: Default::default(), - reused_node_vec: vec![], obligation_tree_id_generator: (0..).map(ObligationTreeId), + node_number: 0, error_cache: Default::default(), + stalled_on: Default::default(), + unblocked: Default::default(), + stalled_on_unknown: Default::default(), + temp_unblocked_nodes: Default::default(), + watcher_offset: None, + done: false, + reused_node_vec: Vec::new(), } } + /// Returns the `WatcherOffset` regsitered with the notification table. See the field + /// `ObligationForest::watcher_offset`. + pub fn watcher_offset(&self) -> Option<&O::WatcherOffset> { + self.watcher_offset.as_ref() + } + + /// Removes the watcher_offset, allowing it to be deregistered + pub fn take_watcher_offset(&mut self) -> Option { + self.done = true; + self.watcher_offset.take() + } + /// Returns the total number of nodes in the forest that have not /// yet been fully resolved. pub fn len(&self) -> usize { @@ -335,15 +461,26 @@ impl ObligationForest { } // Returns Err(()) if we already know this obligation failed. - fn register_obligation_at(&mut self, obligation: O, parent: Option) -> Result<(), ()> { - if self.done_cache.contains(&obligation.as_cache_key()) { - debug!("register_obligation_at: ignoring already done obligation: {:?}", obligation); - return Ok(()); - } - - match self.active_cache.entry(obligation.as_cache_key()) { + fn register_obligation_at( + &mut self, + obligation: O, + parent: Option, + ) -> Result<(), ()> { + debug_assert!(obligation.stalled_on().is_empty()); + match self.active_cache.entry(obligation.as_cache_key().clone()) { Entry::Occupied(o) => { - let node = &mut self.nodes[*o.get()]; + let index = match o.get() { + CacheState::Active(index) => *index, + CacheState::Done => { + debug!( + "register_obligation_at: ignoring already done obligation: {:?}", + obligation + ); + return Ok(()); + } + }; + let node = &mut self.nodes[index]; + let state = node.state.get(); if let Some(parent_index) = parent { // If the node is already in `active_cache`, it has already // had its chance to be marked with a parent. So if it's @@ -351,9 +488,10 @@ impl ObligationForest { // dependents as a non-parent. if !node.dependents.contains(&parent_index) { node.dependents.push(parent_index); + self.nodes[parent_index].reverse_dependents.push(index); } } - if let NodeState::Error = node.state.get() { Err(()) } else { Ok(()) } + if let NodeState::Error = state { Err(()) } else { Ok(()) } } Entry::Vacant(v) => { let obligation_tree_id = match parent { @@ -371,9 +509,34 @@ impl ObligationForest { if already_failed { Err(()) } else { - let new_index = self.nodes.len(); - v.insert(new_index); - self.nodes.push(Node::new(parent, obligation, obligation_tree_id)); + // Retrieves a fresh number for the new node so that each node are processed in the + // order that they were created + let node_number = self.node_number; + self.node_number += 1; + + // If we have a dead node we can reuse it and it's associated allocations, + // otherwise allocate a new node + let new_index = if let Some(new_index) = self.dead_nodes.pop() { + let node = &mut self.nodes[new_index]; + node.reinit(parent, obligation, obligation_tree_id, node_number); + new_index + } else { + let new_index = self.nodes.len(); + self.nodes.push(Node::new( + parent, + obligation, + obligation_tree_id, + node_number, + )); + new_index + }; + if let Some(parent_index) = parent { + self.nodes[parent_index].reverse_dependents.push(new_index); + } + + self.pending_nodes.push(new_index); + self.unblocked.push(Unblocked { index: new_index, order: node_number }); + v.insert(CacheState::Active(new_index)); Ok(()) } } @@ -383,14 +546,14 @@ impl ObligationForest { /// Converts all remaining obligations to the given error. pub fn to_errors(&mut self, error: E) -> Vec> { let errors = self - .nodes + .pending_nodes .iter() - .enumerate() - .filter(|(_index, node)| node.state.get() == NodeState::Pending) - .map(|(index, _node)| Error { error: error.clone(), backtrace: self.error_at(index) }) + .filter(|&&index| self.nodes[index].state.get() == NodeState::Pending) + .map(|&index| Error { error: error.clone(), backtrace: self.error_at(index) }) .collect(); - self.compress(|_| assert!(false)); + let successful_obligations = self.compress(DoCompleted::Yes); + assert!(successful_obligations.unwrap().is_empty()); errors } @@ -399,14 +562,17 @@ impl ObligationForest { where F: Fn(&O) -> P, { - self.nodes + self.pending_nodes .iter() - .filter(|node| node.state.get() == NodeState::Pending) + .filter_map(|&index| { + let node = &self.nodes[index]; + if node.state.get() == NodeState::Pending { Some(node) } else { None } + }) .map(|node| f(&node.obligation)) .collect() } - fn insert_into_error_cache(&mut self, index: usize) { + fn insert_into_error_cache(&mut self, index: NodeIndex) { let node = &self.nodes[index]; self.error_cache .entry(node.obligation_tree_id) @@ -418,12 +584,158 @@ impl ObligationForest { /// be called in a loop until `outcome.stalled` is false. /// /// This _cannot_ be unrolled (presently, at least). - pub fn process_obligations(&mut self, processor: &mut P) -> OUT + pub fn process_obligations

( + &mut self, + processor: &mut P, + do_completed: DoCompleted, + ) -> Outcome + where + P: ObligationProcessor, + { + if self.watcher_offset.is_none() { + assert!(!self.done); + if false && self.nodes.len() > 100 { + self.watcher_offset = Some(processor.register_variable_watcher()); + } + if let Some(outcome) = self.process_obligations_simple(processor, do_completed) { + return outcome; + } + } + let mut errors = vec![]; + let mut stalled = true; + + self.unblock_nodes(processor); + + let mut made_progress_this_iteration = true; + while made_progress_this_iteration { + made_progress_this_iteration = false; + let nodes = &self.nodes; + self.unblocked.extend( + self.stalled_on_unknown + .drain(..) + .map(|index| Unblocked { index, order: nodes[index].node_number }), + ); + + while let Some(Unblocked { index, .. }) = self.unblocked.pop() { + // Skip any duplicates since we only need to processes the node once + if self.unblocked.peek().map(|u| u.index) == Some(index) { + continue; + } + + let node = &mut self.nodes[index]; + + if node.state.get() != NodeState::Pending { + continue; + } + + // One of the variables we stalled on unblocked us. If the node were blocked on other + // variables as well then remove those stalls. If the node is still stalled on one of + // those variables after `process_obligation` it will simply be added back to + // `self.stalled_on` + let stalled_on = node.obligation.stalled_on(); + if stalled_on.len() > 1 { + for var in stalled_on { + match self.stalled_on.entry(var.clone()) { + Entry::Vacant(_) => (), + Entry::Occupied(mut entry) => { + let nodes = entry.get_mut(); + if let Some(i) = nodes.iter().position(|x| *x == index) { + nodes.swap_remove(i); + } + if nodes.is_empty() { + processor.unwatch_variable(var.clone()); + entry.remove(); + } + } + } + } + } + + // `processor.process_obligation` can modify the predicate within + // `node.obligation`, and that predicate is the key used for + // `self.active_cache`. This means that `self.active_cache` can get + // out of sync with `nodes`. It's not very common, but it does + // happen, and code in `compress` has to allow for it. + let before = node.obligation.as_cache_key(); + let result = processor.process_obligation(&mut node.obligation); + let after = node.obligation.as_cache_key(); + if before != after { + node.alternative_predicates.push(before); + } + + self.unblock_nodes(processor); + let node = &mut self.nodes[index]; + match result { + ProcessResult::Unchanged => { + let stalled_on = node.obligation.stalled_on(); + if stalled_on.is_empty() { + // We stalled but the variables that caused it are unknown so we run + // `index` again at the next opportunity + self.stalled_on_unknown.push(index); + } else { + // Register every variable that we stalled on + for var in stalled_on { + self.stalled_on + .entry(var.clone()) + .or_insert_with(|| { + processor.watch_variable(var.clone()); + Vec::new() + }) + .push(index); + } + } + // No change in state. + } + ProcessResult::Changed(children) => { + made_progress_this_iteration = true; + // We are not (yet) stalled. + stalled = false; + node.state.set(NodeState::Success); + self.success_or_waiting_nodes.push(index); + + for child in children { + let st = self.register_obligation_at(child, Some(index)); + if let Err(()) = st { + // Error already reported - propagate it + // to our node. + self.error_at(index); + } + } + } + ProcessResult::Error(err) => { + made_progress_this_iteration = true; + stalled = false; + errors.push(Error { error: err, backtrace: self.error_at(index) }); + } + } + } + } + + if stalled { + // There's no need to perform marking, cycle processing and compression when nothing + // changed. + return Outcome { + completed: if do_completed == DoCompleted::Yes { Some(vec![]) } else { None }, + errors, + }; + } + + self.mark_successes(); + self.process_cycles(processor); + let completed = self.compress(do_completed); + Outcome { completed, errors } + } + + fn process_obligations_simple

( + &mut self, + processor: &mut P, + do_completed: DoCompleted, + ) -> Option> where P: ObligationProcessor, - OUT: OutcomeTrait>, { - let mut outcome = OUT::new(); + let mut errors = vec![]; + let mut stalled = true; // Note that the loop body can append new nodes, and those new nodes // will then be processed by subsequent iterations of the loop. @@ -433,63 +745,157 @@ impl ObligationForest { // `for index in 0..self.nodes.len() { ... }` because the range would // be computed with the initial length, and we would miss the appended // nodes. Therefore we use a `while` loop. - let mut index = 0; - while let Some(node) = self.nodes.get_mut(index) { - // `processor.process_obligation` can modify the predicate within - // `node.obligation`, and that predicate is the key used for - // `self.active_cache`. This means that `self.active_cache` can get - // out of sync with `nodes`. It's not very common, but it does - // happen, and code in `compress` has to allow for it. - if node.state.get() != NodeState::Pending { - index += 1; - continue; - } + let mut completed = vec![]; + loop { + let mut i = 0; + let mut made_progress_this_iteration = false; + while let Some(&index) = self.pending_nodes.get(i) { + let node = &mut self.nodes[index]; + // `processor.process_obligation` can modify the predicate within + // `node.obligation`, and that predicate is the key used for + // `self.active_cache`. This means that `self.active_cache` can get + // out of sync with `nodes`. It's not very common, but it does + // happen, and code in `compress` has to allow for it. + if node.state.get() != NodeState::Pending { + i += 1; + continue; + } - match processor.process_obligation(&mut node.obligation) { - ProcessResult::Unchanged => { - // No change in state. + // `processor.process_obligation` can modify the predicate within + // `node.obligation`, and that predicate is the key used for + // `self.active_cache`. This means that `self.active_cache` can get + // out of sync with `nodes`. It's not very common, but it does + // happen, and code in `compress` has to allow for it. + let before = node.obligation.as_cache_key(); + let result = processor.checked_process_obligation(&mut node.obligation); + let after = node.obligation.as_cache_key(); + if before != after { + node.alternative_predicates.push(before); } - ProcessResult::Changed(children) => { - // We are not (yet) stalled. - outcome.mark_not_stalled(); - node.state.set(NodeState::Success); - - for child in children { - let st = self.register_obligation_at(child, Some(index)); - if let Err(()) = st { - // Error already reported - propagate it - // to our node. - self.error_at(index); + + match result { + ProcessResult::Unchanged => { + // No change in state. + if self.watcher_offset.is_some() { + let stalled_on = node.obligation.stalled_on(); + if stalled_on.is_empty() { + // We stalled but the variables that caused it are unknown so we run + // `index` again at the next opportunity + self.stalled_on_unknown.push(index); + } else { + // Register every variable that we stalled on + for var in stalled_on { + self.stalled_on + .entry(var.clone()) + .or_insert_with(|| { + processor.watch_variable(var.clone()); + Vec::new() + }) + .push(index); + } + } } } + ProcessResult::Changed(children) => { + // We are not (yet) stalled. + stalled = false; + node.state.set(NodeState::Success); + made_progress_this_iteration = true; + self.success_or_waiting_nodes.push(index); + + for child in children { + let st = self.register_obligation_at(child, Some(index)); + if let Err(()) = st { + // Error already reported - propagate it + // to our node. + self.error_at(index); + } + } + } + ProcessResult::Error(err) => { + stalled = false; + errors.push(Error { error: err, backtrace: self.error_at(index) }); + } } - ProcessResult::Error(err) => { - outcome.mark_not_stalled(); - outcome.record_error(Error { error: err, backtrace: self.error_at(index) }); - } + i += 1; + } + + if stalled { + // There's no need to perform marking, cycle processing and compression when nothing + // changed. + return Some(Outcome { + completed: if do_completed == DoCompleted::Yes { Some(vec![]) } else { None }, + errors, + }); + } + + if !made_progress_this_iteration { + break; + } + + if self.watcher_offset.is_some() { + return None; } - index += 1; - } - // There's no need to perform marking, cycle processing and compression when nothing - // changed. - if !outcome.is_stalled() { self.mark_successes(); self.process_cycles(processor); - self.compress(|obl| outcome.record_completed(obl)); + if let Some(mut c) = self.compress(do_completed) { + completed.append(&mut c); + } } - outcome + Some(Outcome { + completed: if do_completed == DoCompleted::Yes { Some(completed) } else { None }, + errors, + }) + } + + /// Checks which nodes have been unblocked since the last time this was called. All nodes that + /// were unblocked are added to the `unblocked` queue and all watches associated with the + /// variables blocking those nodes are deregistered (since they are now instantiated, they will + /// neither block a node, nor be instantiated again) + fn unblock_nodes

(&mut self, processor: &mut P) + where + P: ObligationProcessor, + { + let nodes = &mut self.nodes; + let stalled_on = &mut self.stalled_on; + let unblocked = &mut self.unblocked; + let temp_unblocked_nodes = &mut self.temp_unblocked_nodes; + temp_unblocked_nodes.clear(); + processor.notify_unblocked(self.watcher_offset.as_ref().unwrap(), |var| { + if let Some(unblocked_nodes) = stalled_on.remove(&var) { + for node_index in unblocked_nodes { + let node = &nodes[node_index]; + debug_assert!( + node.state.get() == NodeState::Pending, + "Unblocking non-pending2: {:?}", + node.obligation + ); + unblocked.push(Unblocked { index: node_index, order: node.node_number }); + } + temp_unblocked_nodes.push(var); + } + }); + for var in temp_unblocked_nodes.drain(..) { + processor.unwatch_variable(var); + } } /// Returns a vector of obligations for `p` and all of its /// ancestors, putting them into the error state in the process. - fn error_at(&self, mut index: usize) -> Vec { - let mut error_stack: Vec = vec![]; + fn error_at(&self, mut index: NodeIndex) -> Vec { + let mut error_stack: Vec = vec![]; let mut trace = vec![]; + let mut error_or_done_nodes = self.error_or_done_nodes.borrow_mut(); + loop { let node = &self.nodes[index]; + match node.state.get() { + NodeState::Error | NodeState::Done => (), // Already added to `error_or_done_nodes` + _ => error_or_done_nodes.push(index), + } node.state.set(NodeState::Error); trace.push(node.obligation.clone()); if node.has_parent { @@ -517,9 +923,10 @@ impl ObligationForest { /// Mark all `Waiting` nodes as `Success`, except those that depend on a /// pending node. - fn mark_successes(&self) { + fn mark_successes(&mut self) { // Convert all `Waiting` nodes to `Success`. - for node in &self.nodes { + for &index in &self.success_or_waiting_nodes { + let node = &self.nodes[index]; if node.state.get() == NodeState::Waiting { node.state.set(NodeState::Success); } @@ -527,12 +934,18 @@ impl ObligationForest { // Convert `Success` nodes that depend on a pending node back to // `Waiting`. - for node in &self.nodes { + let mut pending_nodes = mem::take(&mut self.pending_nodes); + pending_nodes.retain(|&index| { + let node = &self.nodes[index]; if node.state.get() == NodeState::Pending { // This call site is hot. self.inlined_mark_dependents_as_waiting(node); + true + } else { + false } - } + }); + self.pending_nodes = pending_nodes; } // This always-inlined function is for the hot call site. @@ -564,8 +977,11 @@ impl ObligationForest { where P: ObligationProcessor, { - let mut stack = std::mem::take(&mut self.reused_node_vec); - for (index, node) in self.nodes.iter().enumerate() { + let mut stack = mem::take(&mut self.reused_node_vec); + + let success_or_waiting_nodes = mem::take(&mut self.success_or_waiting_nodes); + for &index in &success_or_waiting_nodes { + let node = &self.nodes[index]; // For some benchmarks this state test is extremely hot. It's a win // to handle the no-op cases immediately to avoid the cost of the // function call. @@ -573,13 +989,18 @@ impl ObligationForest { self.find_cycles_from_node(&mut stack, processor, index); } } + self.success_or_waiting_nodes = success_or_waiting_nodes; debug_assert!(stack.is_empty()); self.reused_node_vec = stack; } - fn find_cycles_from_node

(&self, stack: &mut Vec, processor: &mut P, index: usize) - where + fn find_cycles_from_node

( + &self, + stack: &mut Vec, + processor: &mut P, + index: NodeIndex, + ) where P: ObligationProcessor, { let node = &self.nodes[index]; @@ -592,11 +1013,12 @@ impl ObligationForest { } stack.pop(); node.state.set(NodeState::Done); + self.error_or_done_nodes.borrow_mut().push(index); } Some(rpos) => { // Cycle detected. processor.process_backedge( - stack[rpos..].iter().map(GetObligation(&self.nodes)), + stack[rpos..].iter().map(|i| &self.nodes[*i].obligation), PhantomData, ); } @@ -604,120 +1026,81 @@ impl ObligationForest { } } - /// Compresses the vector, removing all popped nodes. This adjusts the - /// indices and hence invalidates any outstanding indices. `process_cycles` - /// must be run beforehand to remove any cycles on `Success` nodes. + /// Compresses the forest, moving all nodes marked `Done` or `Error` into `dead_nodes` for later reuse + /// `process_cycles` must be run beforehand to remove any cycles on `Success` nodes. #[inline(never)] - fn compress(&mut self, mut outcome_cb: impl FnMut(&O)) { - let orig_nodes_len = self.nodes.len(); - let mut node_rewrites: Vec<_> = std::mem::take(&mut self.reused_node_vec); - debug_assert!(node_rewrites.is_empty()); - node_rewrites.extend(0..orig_nodes_len); - let mut dead_nodes = 0; - - // Move removable nodes to the end, preserving the order of the - // remaining nodes. - // - // LOOP INVARIANT: - // self.nodes[0..index - dead_nodes] are the first remaining nodes - // self.nodes[index - dead_nodes..index] are all dead - // self.nodes[index..] are unchanged - for index in 0..orig_nodes_len { - let node = &self.nodes[index]; - match node.state.get() { - NodeState::Pending | NodeState::Waiting => { - if dead_nodes > 0 { - self.nodes.swap(index, index - dead_nodes); - node_rewrites[index] -= dead_nodes; + fn compress(&mut self, do_completed: DoCompleted) -> Option> { + let mut removed_done_obligations: Vec = vec![]; + + // Compress the forest by removing any nodes marked as error or done + let mut error_or_done_nodes = mem::take(self.error_or_done_nodes.get_mut()); + for &index in &error_or_done_nodes { + let node = &mut self.nodes[index]; + + // Remove this node from all the nodes that depends on it + let reverse_dependents = mem::take(&mut node.reverse_dependents); + for &reverse_index in &reverse_dependents { + let reverse_node = &mut self.nodes[reverse_index]; + + if let Some(i) = reverse_node.dependents.iter().position(|x| *x == index) { + reverse_node.dependents.swap_remove(i); + if i == 0 { + reverse_node.has_parent = false; } } + } + let node = &mut self.nodes[index]; + node.reverse_dependents = reverse_dependents; + + match node.state.get() { NodeState::Done => { - // This lookup can fail because the contents of - // `self.active_cache` are not guaranteed to match those of - // `self.nodes`. See the comment in `process_obligation` - // for more details. - if let Some((predicate, _)) = - self.active_cache.remove_entry(&node.obligation.as_cache_key()) - { - self.done_cache.insert(predicate); - } else { - self.done_cache.insert(node.obligation.as_cache_key().clone()); + // Mark as done + *self + .active_cache + .entry(node.obligation.as_cache_key()) + .or_insert(CacheState::Done) = CacheState::Done; + // If the node's predicate changed at some point we mark all its alternate + // predicates as done as well + for alt in node.alternative_predicates.drain(..) { + *self.active_cache.entry(alt).or_insert(CacheState::Done) = + CacheState::Done; + } + + if do_completed == DoCompleted::Yes { + // Extract the success stories. + removed_done_obligations.push(node.obligation.clone()); } - // Extract the success stories. - outcome_cb(&node.obligation); - node_rewrites[index] = orig_nodes_len; - dead_nodes += 1; + + // Store the node so it and its allocations can be used when another node is + // allocated + self.dead_nodes.push(index); } NodeState::Error => { // We *intentionally* remove the node from the cache at this point. Otherwise // tests must come up with a different type on every type error they // check against. self.active_cache.remove(&node.obligation.as_cache_key()); - self.insert_into_error_cache(index); - node_rewrites[index] = orig_nodes_len; - dead_nodes += 1; - } - NodeState::Success => unreachable!(), - } - } - - if dead_nodes > 0 { - // Remove the dead nodes and rewrite indices. - self.nodes.truncate(orig_nodes_len - dead_nodes); - self.apply_rewrites(&node_rewrites); - } - - node_rewrites.truncate(0); - self.reused_node_vec = node_rewrites; - } - - fn apply_rewrites(&mut self, node_rewrites: &[usize]) { - let orig_nodes_len = node_rewrites.len(); - - for node in &mut self.nodes { - let mut i = 0; - while let Some(dependent) = node.dependents.get_mut(i) { - let new_index = node_rewrites[*dependent]; - if new_index >= orig_nodes_len { - node.dependents.swap_remove(i); - if i == 0 && node.has_parent { - // We just removed the parent. - node.has_parent = false; + // If the node's predicate changed at some point we remove all its alternate + // predicates as well + for alt in &node.alternative_predicates { + self.active_cache.remove(alt); } - } else { - *dependent = new_index; - i += 1; + self.insert_into_error_cache(index); + self.dead_nodes.push(index); } + NodeState::Pending | NodeState::Waiting | NodeState::Success => unreachable!(), } } - - // This updating of `self.active_cache` is necessary because the - // removal of nodes within `compress` can fail. See above. - self.active_cache.retain(|_predicate, index| { - let new_index = node_rewrites[*index]; - if new_index >= orig_nodes_len { - false - } else { - *index = new_index; - true - } + error_or_done_nodes.clear(); + *self.error_or_done_nodes.get_mut() = error_or_done_nodes; + + let nodes = &self.nodes; + self.success_or_waiting_nodes.retain(|&index| match nodes[index].state.get() { + NodeState::Waiting | NodeState::Success => true, + NodeState::Done | NodeState::Error => false, + NodeState::Pending => unreachable!(), }); - } -} - -// I need a Clone closure. -#[derive(Clone)] -struct GetObligation<'a, O>(&'a [Node]); - -impl<'a, 'b, O> FnOnce<(&'b usize,)> for GetObligation<'a, O> { - type Output = &'a O; - extern "rust-call" fn call_once(self, args: (&'b usize,)) -> &'a O { - &self.0[*args.0].obligation - } -} -impl<'a, 'b, O> FnMut<(&'b usize,)> for GetObligation<'a, O> { - extern "rust-call" fn call_mut(&mut self, args: (&'b usize,)) -> &'a O { - &self.0[*args.0].obligation + if do_completed == DoCompleted::Yes { Some(removed_done_obligations) } else { None } } } diff --git a/compiler/rustc_data_structures/src/obligation_forest/tests.rs b/compiler/rustc_data_structures/src/obligation_forest/tests.rs index 371c62c063fa7..0a0a8a4b11296 100644 --- a/compiler/rustc_data_structures/src/obligation_forest/tests.rs +++ b/compiler/rustc_data_structures/src/obligation_forest/tests.rs @@ -5,10 +5,16 @@ use std::marker::PhantomData; impl<'a> super::ForestObligation for &'a str { type CacheKey = &'a str; + type Variable = (); + type WatcherOffset = (); fn as_cache_key(&self) -> Self::CacheKey { self } + + fn stalled_on(&self) -> &[Self::Variable] { + &[] + } } struct ClosureObligationProcessor { @@ -66,7 +72,7 @@ where impl ObligationProcessor for ClosureObligationProcessor where - O: super::ForestObligation + fmt::Debug, + O: super::ForestObligation + fmt::Debug, E: fmt::Debug, OF: FnMut(&mut O) -> ProcessResult, BF: FnMut(&[O]), @@ -74,6 +80,13 @@ where type Obligation = O; type Error = E; + fn checked_process_obligation( + &mut self, + obligation: &mut Self::Obligation, + ) -> ProcessResult { + (self.process_obligation)(obligation) + } + fn process_obligation( &mut self, obligation: &mut Self::Obligation, @@ -86,6 +99,21 @@ where I: Clone + Iterator, { } + + fn notify_unblocked( + &self, + _offset: &::WatcherOffset, + _f: impl FnMut(::Variable), + ) { + } + fn register_variable_watcher(&self) -> ::WatcherOffset {} + fn deregister_variable_watcher( + &self, + _offset: ::WatcherOffset, + ) { + } + fn watch_variable(&self, _var: ::Variable) {} + fn unwatch_variable(&self, _var: ::Variable) {} } #[test] diff --git a/compiler/rustc_data_structures/src/unify_log.rs b/compiler/rustc_data_structures/src/unify_log.rs new file mode 100644 index 0000000000000..1384dc5395213 --- /dev/null +++ b/compiler/rustc_data_structures/src/unify_log.rs @@ -0,0 +1,135 @@ +use rustc_index::vec::{Idx, IndexVec}; + +use ena::undo_log::{Rollback, UndoLogs}; + +pub enum Undo { + Move { index: T, old: usize }, + Extend { group_index: usize, len: usize }, + NewGroup { index: T }, +} + +/// Tracks which variables (represented by indices) that has been unified with eachother. +/// Since there is often only a few variables that are interesting one must call `watch_variable` +/// on any variable record unifications with. Used in conjuction with a `ModifiedSet` to accurately +/// track which variables has been instantiated. +/// +/// NOTE: The methods on this expect and only work correctly if a root variable from +/// an `UnificationTable` is provided. +pub struct UnifyLog { + unified_vars: IndexVec, + groups: Vec>, + reference_counts: IndexVec, +} + +fn pick2_mut(self_: &mut [T], a: I, b: I) -> (&mut T, &mut T) { + let (ai, bi) = (a.index(), b.index()); + assert!(ai != bi); + + if ai < bi { + let (c1, c2) = self_.split_at_mut(bi); + (&mut c1[ai], &mut c2[0]) + } else { + let (c1, c2) = self_.split_at_mut(ai); + (&mut c2[0], &mut c1[bi]) + } +} + +impl UnifyLog { + /// Returns a new `UnifyLog` + pub fn new() -> Self { + UnifyLog { + unified_vars: IndexVec::new(), + groups: Vec::new(), + reference_counts: IndexVec::new(), + } + } + + /// Logs that `root` were unified with `other`. Allowing all variables that were unified with + /// root to be returned by `get` (if those variables are watched) + pub fn unify(&mut self, undo_log: &mut impl UndoLogs>, root: T, other: T) { + if !self.needs_log(other) { + return; + } + self.unified_vars.ensure_contains_elem(root.max(other), usize::max_value); + let mut root_group = self.unified_vars[root]; + let other_group = self.unified_vars[other]; + + match (root_group, other_group) { + (usize::MAX, usize::MAX) => { + // Neither variable is part of a group, create a new one at the root and associate + // other + root_group = self.groups.len(); + self.unified_vars[root] = root_group; + self.groups.push(vec![other]); + undo_log.push(Undo::NewGroup { index: root }); + } + (usize::MAX, _) => { + // `other` has a group, point `root` to it and associate other + let group = &mut self.unified_vars[root]; + undo_log.push(Undo::Move { index: root, old: *group }); + *group = other_group; + self.groups[other_group].push(other); + } + (_, usize::MAX) => { + // `root` hasa group, just associate `other` + let root_vec = &mut self.groups[root_group]; + undo_log.push(Undo::Extend { group_index: root_group, len: root_vec.len() }); + root_vec.push(other); + } + _ => { + // Both variables has their own groups, associate all of `other` to root + let (root_vec, other_vec) = pick2_mut(&mut self.groups, root_group, other_group); + undo_log.push(Undo::Extend { group_index: root_group, len: root_vec.len() }); + root_vec.extend_from_slice(other_vec); + + // We only need to add `other` if there is a watcher for it (there might only be + // watchers for the other variables in its group) + if self.reference_counts.get(other).map_or(false, |c| *c != 0) { + root_vec.push(other); + } + } + } + } + + /// Returns the variables that `root` were unified with. The returned list may or may not + /// contain `root` itself. + pub fn get(&self, root: T) -> &[T] { + match self.unified_vars.get(root) { + Some(group) => match self.groups.get(*group) { + Some(v) => v, + None => &[], + }, + None => &[], + } + } + + /// Returns true if `vid` is something that needs to be logged to a watcher + pub fn needs_log(&self, vid: T) -> bool { + !self.get(vid).is_empty() || self.reference_counts.get(vid).map_or(false, |c| *c != 0) + } + + /// Starts a watch on `index`. Any calls to `watch_variable` should be matched by call to + /// `unwatch_variable` when the watch is no longer needed + pub fn watch_variable(&mut self, index: T) { + self.reference_counts.ensure_contains_elem(index, || 0); + self.reference_counts[index] += 1; + } + + /// Removes a watch on `index` + pub fn unwatch_variable(&mut self, index: T) { + self.reference_counts[index] -= 1; + } +} + +impl Rollback> for UnifyLog { + fn reverse(&mut self, undo: Undo) { + match undo { + Undo::Extend { group_index, len } => self.groups[group_index].truncate(len as usize), + Undo::Move { index, old } => self.unified_vars[index] = old, + Undo::NewGroup { index } => { + self.groups.pop(); + self.unified_vars[index] = usize::max_value(); + } + } + } +} diff --git a/compiler/rustc_index/src/bit_set.rs b/compiler/rustc_index/src/bit_set.rs index 8e00e54650df0..3fd3b25beb974 100644 --- a/compiler/rustc_index/src/bit_set.rs +++ b/compiler/rustc_index/src/bit_set.rs @@ -59,6 +59,13 @@ impl BitSet { result } + pub fn resize(&mut self, domain_size: usize) { + let num_words = num_words(domain_size); + self.domain_size = domain_size; + self.words.resize(num_words, 0); + self.clear_excess_bits(); + } + /// Clear all elements. #[inline] pub fn clear(&mut self) { diff --git a/compiler/rustc_infer/src/infer/canonical/query_response.rs b/compiler/rustc_infer/src/infer/canonical/query_response.rs index 93e19521893ef..67b89bbb06a2f 100644 --- a/compiler/rustc_infer/src/infer/canonical/query_response.rs +++ b/compiler/rustc_infer/src/infer/canonical/query_response.rs @@ -122,7 +122,7 @@ impl<'cx, 'tcx> InferCtxt<'cx, 'tcx> { } // Anything left unselected *now* must be an ambiguity. - let ambig_errors = fulfill_cx.select_all_or_error(self).err().unwrap_or_else(Vec::new); + let ambig_errors = fulfill_cx.select_or_error(self).err().unwrap_or_else(Vec::new); debug!("ambig_errors = {:#?}", ambig_errors); let region_obligations = self.take_registered_region_obligations(); diff --git a/compiler/rustc_infer/src/infer/combine.rs b/compiler/rustc_infer/src/infer/combine.rs index 6a1715ef81899..06aeab874efbb 100644 --- a/compiler/rustc_infer/src/infer/combine.rs +++ b/compiler/rustc_infer/src/infer/combine.rs @@ -76,6 +76,7 @@ impl<'infcx, 'tcx> InferCtxt<'infcx, 'tcx> { match (a.kind(), b.kind()) { // Relate integral variables to other types (&ty::Infer(ty::IntVar(a_id)), &ty::Infer(ty::IntVar(b_id))) => { + warn!("Unify int: {:?} {:?}", a_id, b_id); self.inner .borrow_mut() .int_unification_table() @@ -225,7 +226,7 @@ impl<'infcx, 'tcx> InferCtxt<'infcx, 'tcx> { fn unify_const_variable( &self, param_env: ty::ParamEnv<'tcx>, - target_vid: ty::ConstVid<'tcx>, + target_vid: ty::ConstVid, ct: &'tcx ty::Const<'tcx>, vid_is_expected: bool, ) -> RelateResult<'tcx, &'tcx ty::Const<'tcx>> { @@ -721,10 +722,14 @@ impl TypeRelation<'tcx> for Generalizer<'_, 'tcx> { if self.for_universe.can_name(universe) { Ok(c) } else { - let new_var_id = variable_table.new_key(ConstVarValue { - origin: var_value.origin, - val: ConstVariableValue::Unknown { universe: self.for_universe }, - }); + let new_var_id = variable_table + .new_key(ConstVarValue { + origin: var_value.origin, + val: ConstVariableValue::Unknown { + universe: self.for_universe, + }, + }) + .vid; Ok(self.tcx().mk_const_var(new_var_id, c.ty)) } } @@ -792,7 +797,7 @@ struct ConstInferUnifier<'cx, 'tcx> { /// The vid of the const variable that is in the process of being /// instantiated; if we find this within the const we are folding, /// that means we would have created a cyclic const. - target_vid: ty::ConstVid<'tcx>, + target_vid: ty::ConstVid, } // We use `TypeRelation` here to propagate `RelateResult` upwards. @@ -927,7 +932,7 @@ impl TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> { // an inference variable which is unioned with `target_vid`. // // Not doing so can easily result in stack overflows. - if variable_table.unioned(self.target_vid, vid) { + if variable_table.unioned(self.target_vid.into(), vid) { return Err(TypeError::CyclicConst(c)); } @@ -942,7 +947,7 @@ impl TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> { origin: var_value.origin, val: ConstVariableValue::Unknown { universe: self.for_universe }, }); - Ok(self.tcx().mk_const_var(new_var_id, c.ty)) + Ok(self.tcx().mk_const_var(new_var_id.vid, c.ty)) } } } diff --git a/compiler/rustc_infer/src/infer/freshen.rs b/compiler/rustc_infer/src/infer/freshen.rs index b3d7876c6e819..4082e673edc82 100644 --- a/compiler/rustc_infer/src/infer/freshen.rs +++ b/compiler/rustc_infer/src/infer/freshen.rs @@ -46,7 +46,7 @@ pub struct TypeFreshener<'a, 'tcx> { ty_freshen_count: u32, const_freshen_count: u32, ty_freshen_map: FxHashMap>, - const_freshen_map: FxHashMap, &'tcx ty::Const<'tcx>>, + const_freshen_map: FxHashMap>, } impl<'a, 'tcx> TypeFreshener<'a, 'tcx> { @@ -88,12 +88,12 @@ impl<'a, 'tcx> TypeFreshener<'a, 'tcx> { fn freshen_const( &mut self, opt_ct: Option<&'tcx ty::Const<'tcx>>, - key: ty::InferConst<'tcx>, + key: ty::InferConst, freshener: F, ty: Ty<'tcx>, ) -> &'tcx ty::Const<'tcx> where - F: FnOnce(u32) -> ty::InferConst<'tcx>, + F: FnOnce(u32) -> ty::InferConst, { if let Some(ct) = opt_ct { return ct.fold_with(self); diff --git a/compiler/rustc_infer/src/infer/fudge.rs b/compiler/rustc_infer/src/infer/fudge.rs index d7bc636db8f8f..7aa2a67dc6649 100644 --- a/compiler/rustc_infer/src/infer/fudge.rs +++ b/compiler/rustc_infer/src/infer/fudge.rs @@ -1,9 +1,11 @@ +use rustc_index::vec::Idx; +use rustc_middle::infer::unify_key::ConstVidEqKey; use rustc_middle::ty::fold::{TypeFoldable, TypeFolder}; use rustc_middle::ty::{self, ConstVid, FloatVid, IntVid, RegionVid, Ty, TyCtxt, TyVid}; use super::type_variable::TypeVariableOrigin; use super::InferCtxt; -use super::{ConstVariableOrigin, RegionVariableOrigin, UnificationTable}; +use super::{ConstVariableOrigin, LoggedUnificationTable, RegionVariableOrigin}; use rustc_data_structures::snapshot_vec as sv; use rustc_data_structures::unify as ut; @@ -11,25 +13,26 @@ use ut::UnifyKey; use std::ops::Range; -fn vars_since_snapshot<'tcx, T>( - table: &mut UnificationTable<'_, 'tcx, T>, +fn vars_since_snapshot<'tcx, T, K>( + table: &mut LoggedUnificationTable<'_, 'tcx, T, K>, snapshot_var_len: usize, ) -> Range where T: UnifyKey, + K: Idx, super::UndoLog<'tcx>: From>>, { T::from_index(snapshot_var_len as u32)..T::from_index(table.len() as u32) } fn const_vars_since_snapshot<'tcx>( - table: &mut UnificationTable<'_, 'tcx, ConstVid<'tcx>>, + table: &mut LoggedUnificationTable<'_, 'tcx, ConstVidEqKey<'tcx>, ConstVid>, snapshot_var_len: usize, -) -> (Range>, Vec) { +) -> (Range, Vec) { let range = vars_since_snapshot(table, snapshot_var_len); ( - range.start..range.end, - (range.start.index..range.end.index) + range.start.into()..range.end.into(), + (range.start.vid.index..range.end.vid.index) .map(|index| table.probe_value(ConstVid::from_index(index)).origin) .collect(), ) @@ -173,7 +176,7 @@ pub struct InferenceFudger<'a, 'tcx> { int_vars: Range, float_vars: Range, region_vars: (Range, Vec), - const_vars: (Range>, Vec), + const_vars: (Range, Vec), } impl<'a, 'tcx> TypeFolder<'tcx> for InferenceFudger<'a, 'tcx> { diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs index ff7bbf0562f60..95dda45c17285 100644 --- a/compiler/rustc_infer/src/infer/mod.rs +++ b/compiler/rustc_infer/src/infer/mod.rs @@ -11,15 +11,19 @@ pub(crate) use self::undo_log::{InferCtxtUndoLogs, Snapshot, UndoLog}; use crate::traits::{self, ObligationCause, PredicateObligations, TraitEngine}; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_data_structures::logged_unification_table as lut; +use rustc_data_structures::modified_set as ms; use rustc_data_structures::sync::Lrc; -use rustc_data_structures::undo_log::Rollback; +use rustc_data_structures::undo_log::{Rollback, UndoLogs}; use rustc_data_structures::unify as ut; use rustc_errors::DiagnosticBuilder; use rustc_hir as hir; use rustc_hir::def_id::{DefId, LocalDefId}; use rustc_middle::infer::canonical::{Canonical, CanonicalVarValues}; use rustc_middle::infer::unify_key::{ConstVarValue, ConstVariableValue}; -use rustc_middle::infer::unify_key::{ConstVariableOrigin, ConstVariableOriginKind, ToType}; +use rustc_middle::infer::unify_key::{ + ConstVariableOrigin, ConstVariableOriginKind, ConstVidEqKey, ToType, +}; use rustc_middle::mir; use rustc_middle::mir::interpret::EvalToConstValueResult; use rustc_middle::traits::select; @@ -82,12 +86,15 @@ pub type InferResult<'tcx, T> = Result, TypeError<'tcx>>; pub type Bound = Option; pub type UnitResult<'tcx> = RelateResult<'tcx, ()>; // "unify result" -pub type FixupResult<'tcx, T> = Result>; // "fixup result" +pub type FixupResult = Result; // "fixup result" pub(crate) type UnificationTable<'a, 'tcx, T> = ut::UnificationTable< ut::InPlace, &'a mut InferCtxtUndoLogs<'tcx>>, >; +pub(crate) type LoggedUnificationTable<'a, 'tcx, T, K = T> = + lut::LoggedUnificationTable<'a, T, K, &'a mut InferCtxtUndoLogs<'tcx>>; + /// How we should handle region solving. /// /// This is used so that the region values inferred by HIR region solving are @@ -145,13 +152,14 @@ pub struct InferCtxtInner<'tcx> { type_variable_storage: type_variable::TypeVariableStorage<'tcx>, /// Map from const parameter variable to the kind of const it represents. - const_unification_storage: ut::UnificationTableStorage>, + const_unification_storage: + lut::LoggedUnificationTableStorage, ty::ConstVid>, /// Map from integral variable to the kind of integer it represents. - int_unification_storage: ut::UnificationTableStorage, + int_unification_storage: lut::LoggedUnificationTableStorage, /// Map from floating variable to the kind of float it represents. - float_unification_storage: ut::UnificationTableStorage, + float_unification_storage: lut::LoggedUnificationTableStorage, /// Tracks the set of region variables and the constraints between them. /// This is initially `Some(_)` but when @@ -202,9 +210,9 @@ impl<'tcx> InferCtxtInner<'tcx> { projection_cache: Default::default(), type_variable_storage: type_variable::TypeVariableStorage::new(), undo_log: InferCtxtUndoLogs::default(), - const_unification_storage: ut::UnificationTableStorage::new(), - int_unification_storage: ut::UnificationTableStorage::new(), - float_unification_storage: ut::UnificationTableStorage::new(), + const_unification_storage: lut::LoggedUnificationTableStorage::new(), + int_unification_storage: lut::LoggedUnificationTableStorage::new(), + float_unification_storage: lut::LoggedUnificationTableStorage::new(), region_constraint_storage: Some(RegionConstraintStorage::new()), region_obligations: vec![], } @@ -226,41 +234,19 @@ impl<'tcx> InferCtxtInner<'tcx> { } #[inline] - fn int_unification_table( - &mut self, - ) -> ut::UnificationTable< - ut::InPlace< - ty::IntVid, - &mut ut::UnificationStorage, - &mut InferCtxtUndoLogs<'tcx>, - >, - > { + fn int_unification_table(&mut self) -> LoggedUnificationTable<'_, 'tcx, ty::IntVid> { self.int_unification_storage.with_log(&mut self.undo_log) } #[inline] - fn float_unification_table( - &mut self, - ) -> ut::UnificationTable< - ut::InPlace< - ty::FloatVid, - &mut ut::UnificationStorage, - &mut InferCtxtUndoLogs<'tcx>, - >, - > { + fn float_unification_table(&mut self) -> LoggedUnificationTable<'_, 'tcx, ty::FloatVid> { self.float_unification_storage.with_log(&mut self.undo_log) } #[inline] fn const_unification_table( &mut self, - ) -> ut::UnificationTable< - ut::InPlace< - ty::ConstVid<'tcx>, - &mut ut::UnificationStorage>, - &mut InferCtxtUndoLogs<'tcx>, - >, - > { + ) -> LoggedUnificationTable<'_, 'tcx, ConstVidEqKey<'tcx>, ty::ConstVid> { self.const_unification_storage.with_log(&mut self.undo_log) } @@ -494,11 +480,11 @@ pub enum NLLRegionVariableOrigin { // FIXME(eddyb) investigate overlap between this and `TyOrConstInferVar`. #[derive(Copy, Clone, Debug)] -pub enum FixupError<'tcx> { +pub enum FixupError { UnresolvedIntTy(IntVid), UnresolvedFloatTy(FloatVid), UnresolvedTy(TyVid), - UnresolvedConst(ConstVid<'tcx>), + UnresolvedConst(ConstVid), } /// See the `region_obligations` field for more information. @@ -509,7 +495,7 @@ pub struct RegionObligation<'tcx> { pub origin: SubregionOrigin<'tcx>, } -impl<'tcx> fmt::Display for FixupError<'tcx> { +impl fmt::Display for FixupError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use self::FixupError::*; @@ -777,6 +763,13 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { let mut inner = self.inner.borrow_mut(); inner.rollback_to(undo_snapshot); inner.unwrap_region_constraints().rollback_to(region_constraints_snapshot); + + if UndoLogs::>::num_open_snapshots(&inner.undo_log) == 0 { + inner.type_variables().clear_modified_set(); + inner.int_unification_table().clear_modified_set(); + inner.float_unification_table().clear_modified_set(); + inner.const_unification_table().clear_modified_set(); + } } fn commit_from(&self, snapshot: CombinedSnapshot<'a, 'tcx>) { @@ -1020,15 +1013,20 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { .inner .borrow_mut() .const_unification_table() - .new_key(ConstVarValue { origin, val: ConstVariableValue::Unknown { universe } }); + .new_key(ConstVarValue { origin, val: ConstVariableValue::Unknown { universe } }) + .vid; self.tcx.mk_const_var(vid, ty) } - pub fn next_const_var_id(&self, origin: ConstVariableOrigin) -> ConstVid<'tcx> { - self.inner.borrow_mut().const_unification_table().new_key(ConstVarValue { - origin, - val: ConstVariableValue::Unknown { universe: self.universe() }, - }) + pub fn next_const_var_id(&self, origin: ConstVariableOrigin) -> ConstVid { + self.inner + .borrow_mut() + .const_unification_table() + .new_key(ConstVarValue { + origin, + val: ConstVariableValue::Unknown { universe: self.universe() }, + }) + .vid } fn next_int_var_id(&self) -> IntVid { @@ -1133,11 +1131,15 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { ), span, }; - let const_var_id = - self.inner.borrow_mut().const_unification_table().new_key(ConstVarValue { + let const_var_id = self + .inner + .borrow_mut() + .const_unification_table() + .new_key(ConstVarValue { origin, val: ConstVariableValue::Unknown { universe: self.universe() }, - }); + }) + .vid; self.tcx.mk_const_var(const_var_id, self.tcx.type_of(param.def_id)).into() } } @@ -1343,7 +1345,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { pub fn probe_const_var( &self, - vid: ty::ConstVid<'tcx>, + vid: ty::ConstVid, ) -> Result<&'tcx ty::Const<'tcx>, ty::UniverseIndex> { match self.inner.borrow_mut().const_unification_table().probe_value(vid).val { ConstVariableValue::Known { value } => Ok(value), @@ -1351,7 +1353,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { } } - pub fn fully_resolve>(&self, value: &T) -> FixupResult<'tcx, T> { + pub fn fully_resolve>(&self, value: &T) -> FixupResult { /*! * Attempts to resolve all type/region/const variables in * `value`. Region inference must have been run already (e.g., @@ -1571,7 +1573,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { /// inference variables), and it handles both `Ty` and `ty::Const` without /// having to resort to storing full `GenericArg`s in `stalled_on`. #[inline(always)] - pub fn ty_or_const_infer_var_changed(&self, infer_var: TyOrConstInferVar<'tcx>) -> bool { + pub fn ty_or_const_infer_var_changed(&self, infer_var: TyOrConstInferVar) -> bool { match infer_var { TyOrConstInferVar::Ty(v) => { use self::type_variable::TypeVariableValue; @@ -1611,12 +1613,107 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { } } } + + pub fn notify_watcher(&self, offset: &WatcherOffset, mut f: impl FnMut(TyOrConstInferVar)) { + let mut inner = self.inner.borrow_mut(); + + inner.type_variables().notify_watcher(&offset.ty_offset, |v| f(TyOrConstInferVar::Ty(v))); + + inner + .int_unification_table() + .notify_watcher(&offset.int_offset, |vid| f(TyOrConstInferVar::TyInt(vid))); + + inner + .float_unification_table() + .notify_watcher(&offset.float_offset, |vid| f(TyOrConstInferVar::TyFloat(vid))); + + inner + .const_unification_table() + .notify_watcher(&offset.const_offset, |vid| f(TyOrConstInferVar::Const(vid))); + } + + pub fn register_unify_watcher(&self) -> WatcherOffset { + let mut inner = self.inner.borrow_mut(); + WatcherOffset { + ty_offset: inner.type_variables().register_unify_watcher(), + + int_offset: inner.int_unification_table().register_watcher(), + + float_offset: inner.float_unification_table().register_watcher(), + + const_offset: inner.const_unification_table().register_watcher(), + } + } + + pub fn deregister_unify_watcher(&self, offset: WatcherOffset) { + let mut inner = self.inner.borrow_mut(); + + inner.type_variables().deregister_unify_watcher(offset.ty_offset); + + inner.int_unification_table().deregister_watcher(offset.int_offset); + + inner.float_unification_table().deregister_watcher(offset.float_offset); + + inner.const_unification_table().deregister_watcher(offset.const_offset); + } + + pub fn watch_variable(&self, infer: TyOrConstInferVar) { + let mut inner = self.inner.borrow_mut(); + match infer { + TyOrConstInferVar::Ty(v) => inner.type_variables().watch_variable(v), + + TyOrConstInferVar::TyInt(v) => inner.int_unification_table().watch_variable(v), + + TyOrConstInferVar::TyFloat(v) => inner.float_unification_table().watch_variable(v), + + TyOrConstInferVar::Const(v) => inner.const_unification_table().watch_variable(v), + } + } + + pub fn unwatch_variable(&self, infer: TyOrConstInferVar) { + let mut inner = self.inner.borrow_mut(); + match infer { + TyOrConstInferVar::Ty(v) => inner.type_variables().unwatch_variable(v), + + TyOrConstInferVar::TyInt(v) => inner.int_unification_table().unwatch_variable(v), + + TyOrConstInferVar::TyFloat(v) => inner.float_unification_table().unwatch_variable(v), + + TyOrConstInferVar::Const(v) => inner.const_unification_table().unwatch_variable(v), + } + } + + pub fn root_ty_or_const(&self, infer: TyOrConstInferVar) -> TyOrConstInferVar { + let mut inner = self.inner.borrow_mut(); + match infer { + TyOrConstInferVar::Ty(v) => TyOrConstInferVar::Ty(inner.type_variables().root_var(v)), + + TyOrConstInferVar::TyInt(v) => { + TyOrConstInferVar::TyInt(inner.int_unification_table().find(v)) + } + + TyOrConstInferVar::TyFloat(v) => { + TyOrConstInferVar::TyFloat(inner.float_unification_table().find(v)) + } + + TyOrConstInferVar::Const(v) => { + TyOrConstInferVar::Const(inner.const_unification_table().find(v).vid) + } + } + } +} + +pub struct WatcherOffset { + ty_offset: ms::Offset, + int_offset: ms::Offset, + float_offset: ms::Offset, + const_offset: ms::Offset, } /// Helper for `ty_or_const_infer_var_changed` (see comment on that), currently /// used only for `traits::fulfill`'s list of `stalled_on` inference variables. -#[derive(Copy, Clone, Debug)] -pub enum TyOrConstInferVar<'tcx> { +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub enum TyOrConstInferVar { /// Equivalent to `ty::Infer(ty::TyVar(_))`. Ty(TyVid), /// Equivalent to `ty::Infer(ty::IntVar(_))`. @@ -1625,10 +1722,10 @@ pub enum TyOrConstInferVar<'tcx> { TyFloat(FloatVid), /// Equivalent to `ty::ConstKind::Infer(ty::InferConst::Var(_))`. - Const(ConstVid<'tcx>), + Const(ConstVid), } -impl TyOrConstInferVar<'tcx> { +impl TyOrConstInferVar { /// Tries to extract an inference variable from a type or a constant, returns `None` /// for types other than `ty::Infer(_)` (or `InferTy::Fresh*`) and /// for constants other than `ty::ConstKind::Infer(_)` (or `InferConst::Fresh`). diff --git a/compiler/rustc_infer/src/infer/nll_relate/mod.rs b/compiler/rustc_infer/src/infer/nll_relate/mod.rs index abdd6edea9024..9af9fb286ae8c 100644 --- a/compiler/rustc_infer/src/infer/nll_relate/mod.rs +++ b/compiler/rustc_infer/src/infer/nll_relate/mod.rs @@ -978,10 +978,12 @@ where match var_value.val.known() { Some(u) => self.relate(u, u), None => { - let new_var_id = variable_table.new_key(ConstVarValue { - origin: var_value.origin, - val: ConstVariableValue::Unknown { universe: self.universe }, - }); + let new_var_id = variable_table + .new_key(ConstVarValue { + origin: var_value.origin, + val: ConstVariableValue::Unknown { universe: self.universe }, + }) + .vid; Ok(self.tcx().mk_const_var(new_var_id, a.ty)) } } diff --git a/compiler/rustc_infer/src/infer/resolve.rs b/compiler/rustc_infer/src/infer/resolve.rs index 337772d70b823..fff2d11dd4db3 100644 --- a/compiler/rustc_infer/src/infer/resolve.rs +++ b/compiler/rustc_infer/src/infer/resolve.rs @@ -162,7 +162,7 @@ impl<'a, 'tcx> TypeVisitor<'tcx> for UnresolvedTypeFinder<'a, 'tcx> { /// Full type resolution replaces all type and region variables with /// their concrete results. If any variable cannot be replaced (never unified, etc) /// then an `Err` result is returned. -pub fn fully_resolve<'a, 'tcx, T>(infcx: &InferCtxt<'a, 'tcx>, value: &T) -> FixupResult<'tcx, T> +pub fn fully_resolve<'a, 'tcx, T>(infcx: &InferCtxt<'a, 'tcx>, value: &T) -> FixupResult where T: TypeFoldable<'tcx>, { @@ -178,7 +178,7 @@ where // `err` field is not enforceable otherwise. struct FullTypeResolver<'a, 'tcx> { infcx: &'a InferCtxt<'a, 'tcx>, - err: Option>, + err: Option, } impl<'a, 'tcx> TypeFolder<'tcx> for FullTypeResolver<'a, 'tcx> { diff --git a/compiler/rustc_infer/src/infer/type_variable.rs b/compiler/rustc_infer/src/infer/type_variable.rs index 35b97fff3da1f..0eef669fcf396 100644 --- a/compiler/rustc_infer/src/infer/type_variable.rs +++ b/compiler/rustc_infer/src/infer/type_variable.rs @@ -5,8 +5,11 @@ use rustc_span::Span; use crate::infer::InferCtxtUndoLogs; +use rustc_data_structures::logged_unification_table as lut; +use rustc_data_structures::modified_set as ms; use rustc_data_structures::snapshot_vec as sv; use rustc_data_structures::unify as ut; +use rustc_data_structures::unify_log as ul; use std::cmp; use std::marker::PhantomData; use std::ops::Range; @@ -15,14 +18,14 @@ use rustc_data_structures::undo_log::{Rollback, UndoLogs}; /// Represents a single undo-able action that affects a type inference variable. pub(crate) enum UndoLog<'tcx> { - EqRelation(sv::UndoLog>>), + EqRelation(lut::UndoLog, ty::TyVid>), SubRelation(sv::UndoLog>), Values(sv::UndoLog), } /// Convert from a specific kind of undo to the more general UndoLog -impl<'tcx> From>>> for UndoLog<'tcx> { - fn from(l: sv::UndoLog>>) -> Self { +impl<'tcx> From, ty::TyVid>> for UndoLog<'tcx> { + fn from(l: lut::UndoLog, ty::TyVid>) -> Self { UndoLog::EqRelation(l) } } @@ -48,6 +51,24 @@ impl<'tcx> From for UndoLog<'tcx> { } } +impl<'tcx> From>>> for UndoLog<'tcx> { + fn from(l: sv::UndoLog>>) -> Self { + UndoLog::EqRelation(l.into()) + } +} + +impl From> for UndoLog<'_> { + fn from(l: ul::Undo) -> Self { + UndoLog::EqRelation(l.into()) + } +} + +impl From> for UndoLog<'_> { + fn from(l: ms::Undo) -> Self { + UndoLog::EqRelation(l.into()) + } +} + impl<'tcx> Rollback> for TypeVariableStorage<'tcx> { fn reverse(&mut self, undo: UndoLog<'tcx>) { match undo { @@ -64,7 +85,7 @@ pub struct TypeVariableStorage<'tcx> { /// Two variables are unified in `eq_relations` when we have a /// constraint `?X == ?Y`. This table also stores, for each key, /// the known value. - eq_relations: ut::UnificationTableStorage>, + eq_relations: lut::LoggedUnificationTableStorage, TyVid>, /// Two variables are unified in `sub_relations` when we have a /// constraint `?X <: ?Y` *or* a constraint `?Y <: ?X`. This second @@ -151,12 +172,13 @@ pub(crate) struct Instantiate { } pub(crate) struct Delegate; +pub(crate) struct UnifiedVarsDelegate; impl<'tcx> TypeVariableStorage<'tcx> { pub fn new() -> TypeVariableStorage<'tcx> { TypeVariableStorage { values: sv::SnapshotVecStorage::new(), - eq_relations: ut::UnificationTableStorage::new(), + eq_relations: lut::LoggedUnificationTableStorage::new(), sub_relations: ut::UnificationTableStorage::new(), } } @@ -193,7 +215,7 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { pub fn equate(&mut self, a: ty::TyVid, b: ty::TyVid) { debug_assert!(self.probe(a).is_unknown()); debug_assert!(self.probe(b).is_unknown()); - self.eq_relations().union(a, b); + self.eq_relations().unify(a, b); self.sub_relations().union(a, b); } @@ -203,6 +225,7 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { pub fn sub(&mut self, a: ty::TyVid, b: ty::TyVid) { debug_assert!(self.probe(a).is_unknown()); debug_assert!(self.probe(b).is_unknown()); + self.sub_relations().union(a, b); } @@ -323,7 +346,9 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { } #[inline] - fn eq_relations(&mut self) -> super::UnificationTable<'_, 'tcx, TyVidEqKey<'tcx>> { + fn eq_relations( + &mut self, + ) -> super::LoggedUnificationTable<'_, 'tcx, TyVidEqKey<'tcx>, ty::TyVid> { self.storage.eq_relations.with_log(self.undo_log) } @@ -376,10 +401,8 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { if vid.index < new_elem_threshold { // quick check to see if this variable was // created since the snapshot started or not. - let mut eq_relations = ut::UnificationTable::with_log( - &mut self.storage.eq_relations, - &mut *self.undo_log, - ); + let mut eq_relations = + self.storage.eq_relations.with_log(&mut *self.undo_log); let escaping_type = match eq_relations.probe_value(vid) { TypeVariableValue::Unknown { .. } => bug!(), TypeVariableValue::Known { value } => value, @@ -409,6 +432,30 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { }) .collect() } + + pub fn clear_modified_set(&mut self) { + self.eq_relations().clear_modified_set(); + } + + pub fn notify_watcher(&mut self, offset: &ms::Offset, f: impl FnMut(ty::TyVid)) { + self.eq_relations().notify_watcher(offset, f) + } + + pub fn register_unify_watcher(&mut self) -> ms::Offset { + self.eq_relations().register_watcher() + } + + pub fn deregister_unify_watcher(&mut self, offset: ms::Offset) { + self.eq_relations().deregister_watcher(offset); + } + + pub fn watch_variable(&mut self, vid: ty::TyVid) { + self.eq_relations().watch_variable(vid); + } + + pub fn unwatch_variable(&mut self, vid: ty::TyVid) { + self.eq_relations().unwatch_variable(vid); + } } impl sv::SnapshotVecDelegate for Delegate { @@ -431,6 +478,13 @@ impl sv::SnapshotVecDelegate for Delegate { } } +impl sv::SnapshotVecDelegate for UnifiedVarsDelegate { + type Value = Vec; + type Undo = (); + + fn reverse(_values: &mut Vec, _action: ()) {} +} + /////////////////////////////////////////////////////////////////////////// /// These structs (a newtyped TyVid) are used as the unification key @@ -450,6 +504,12 @@ impl<'tcx> From for TyVidEqKey<'tcx> { } } +impl<'tcx> From> for ty::TyVid { + fn from(vid: TyVidEqKey<'tcx>) -> Self { + vid.vid + } +} + impl<'tcx> ut::UnifyKey for TyVidEqKey<'tcx> { type Value = TypeVariableValue<'tcx>; fn index(&self) -> u32 { diff --git a/compiler/rustc_infer/src/infer/undo_log.rs b/compiler/rustc_infer/src/infer/undo_log.rs index 2cfd6bb904c41..b344816694746 100644 --- a/compiler/rustc_infer/src/infer/undo_log.rs +++ b/compiler/rustc_infer/src/infer/undo_log.rs @@ -1,8 +1,12 @@ use std::marker::PhantomData; +use rustc_data_structures::logged_unification_table as lut; +use rustc_data_structures::modified_set as ms; use rustc_data_structures::snapshot_vec as sv; use rustc_data_structures::undo_log::{Rollback, UndoLogs}; use rustc_data_structures::unify as ut; +use rustc_data_structures::unify_log as ul; +use rustc_middle::infer::unify_key; use rustc_middle::ty; use crate::{ @@ -18,9 +22,9 @@ pub struct Snapshot<'tcx> { /// Records the 'undo' data fora single operation that affects some form of inference variable. pub(crate) enum UndoLog<'tcx> { TypeVariables(type_variable::UndoLog<'tcx>), - ConstUnificationTable(sv::UndoLog>>), - IntUnificationTable(sv::UndoLog>), - FloatUnificationTable(sv::UndoLog>), + ConstUnificationTable(lut::UndoLog, ty::ConstVid>), + IntUnificationTable(lut::UndoLog), + FloatUnificationTable(lut::UndoLog), RegionConstraintCollector(region_constraints::UndoLog<'tcx>), RegionUnificationTable(sv::UndoLog>), ProjectionCache(traits::UndoLog<'tcx>), @@ -40,26 +44,35 @@ macro_rules! impl_from { } // Upcast from a single kind of "undoable action" to the general enum + impl_from! { RegionConstraintCollector(region_constraints::UndoLog<'tcx>), TypeVariables(type_variable::UndoLog<'tcx>), + TypeVariables(lut::UndoLog, ty::TyVid>), TypeVariables(sv::UndoLog>>), TypeVariables(sv::UndoLog>), TypeVariables(sv::UndoLog), TypeVariables(type_variable::Instantiate), + TypeVariables(ms::Undo), + TypeVariables(ul::Undo), IntUnificationTable(sv::UndoLog>), + IntUnificationTable(ms::Undo), + IntUnificationTable(ul::Undo), FloatUnificationTable(sv::UndoLog>), + FloatUnificationTable(ms::Undo), + FloatUnificationTable(ul::Undo), - ConstUnificationTable(sv::UndoLog>>), + ConstUnificationTable(sv::UndoLog>>), + ConstUnificationTable(ms::Undo), + ConstUnificationTable(ul::Undo), RegionUnificationTable(sv::UndoLog>), ProjectionCache(traits::UndoLog<'tcx>), } -/// The Rollback trait defines how to rollback a particular action. impl<'tcx> Rollback> for InferCtxtInner<'tcx> { fn reverse(&mut self, undo: UndoLog<'tcx>) { match undo { diff --git a/compiler/rustc_infer/src/traits/engine.rs b/compiler/rustc_infer/src/traits/engine.rs index 2710debea9478..2ce64335e7920 100644 --- a/compiler/rustc_infer/src/traits/engine.rs +++ b/compiler/rustc_infer/src/traits/engine.rs @@ -6,7 +6,7 @@ use rustc_middle::ty::{self, ToPredicate, Ty, WithConstness}; use super::FulfillmentError; use super::{ObligationCause, PredicateObligation}; -pub trait TraitEngine<'tcx>: 'tcx { +pub trait TraitEngine<'tcx> { fn normalize_projection_type( &mut self, infcx: &InferCtxt<'_, 'tcx>, @@ -44,19 +44,125 @@ pub trait TraitEngine<'tcx>: 'tcx { obligation: PredicateObligation<'tcx>, ); - fn select_all_or_error( + fn select_or_error( &mut self, infcx: &InferCtxt<'_, 'tcx>, ) -> Result<(), Vec>>; + fn select_all_or_error( + mut self, + infcx: &InferCtxt<'_, 'tcx>, + ) -> Result<(), Vec>> + where + Self: Sized, + { + let result = self.select_or_error(infcx); + self.deregister(infcx); + result + } + fn select_where_possible( &mut self, infcx: &InferCtxt<'_, 'tcx>, ) -> Result<(), Vec>>; + fn select_all_where_possible( + mut self, + infcx: &InferCtxt<'_, 'tcx>, + ) -> Result<(), Vec>> + where + Self: Sized, + { + let result = self.select_where_possible(infcx); + self.deregister(infcx); + result + } + + fn deregister(&mut self, _infcx: &InferCtxt<'_, 'tcx>) {} + fn pending_obligations(&self) -> Vec>; } +impl TraitEngine<'tcx> for Box +where + T: ?Sized + TraitEngine<'tcx>, +{ + fn normalize_projection_type( + &mut self, + infcx: &InferCtxt<'_, 'tcx>, + param_env: ty::ParamEnv<'tcx>, + projection_ty: ty::ProjectionTy<'tcx>, + cause: ObligationCause<'tcx>, + ) -> Ty<'tcx> { + T::normalize_projection_type(self, infcx, param_env, projection_ty, cause) + } + + fn register_bound( + &mut self, + infcx: &InferCtxt<'_, 'tcx>, + param_env: ty::ParamEnv<'tcx>, + ty: Ty<'tcx>, + def_id: DefId, + cause: ObligationCause<'tcx>, + ) { + T::register_bound(self, infcx, param_env, ty, def_id, cause) + } + + fn register_predicate_obligation( + &mut self, + infcx: &InferCtxt<'_, 'tcx>, + obligation: PredicateObligation<'tcx>, + ) { + T::register_predicate_obligation(self, infcx, obligation) + } + + fn select_or_error( + &mut self, + infcx: &InferCtxt<'_, 'tcx>, + ) -> Result<(), Vec>> { + T::select_or_error(self, infcx) + } + + fn select_all_or_error( + mut self, + infcx: &InferCtxt<'_, 'tcx>, + ) -> Result<(), Vec>> + where + Self: Sized, + { + let result = self.select_or_error(infcx); + self.deregister(infcx); + result + } + + fn select_where_possible( + &mut self, + infcx: &InferCtxt<'_, 'tcx>, + ) -> Result<(), Vec>> { + T::select_where_possible(self, infcx) + } + + fn select_all_where_possible( + mut self, + infcx: &InferCtxt<'_, 'tcx>, + ) -> Result<(), Vec>> + where + Self: Sized, + { + let result = self.select_where_possible(infcx); + self.deregister(infcx); + result + } + + fn deregister(&mut self, infcx: &InferCtxt<'_, 'tcx>) { + T::deregister(self, infcx) + } + + fn pending_obligations(&self) -> Vec> { + T::pending_obligations(self) + } +} + pub trait TraitEngineExt<'tcx> { fn register_predicate_obligations( &mut self, diff --git a/compiler/rustc_middle/src/ich/impls_ty.rs b/compiler/rustc_middle/src/ich/impls_ty.rs index 8f15c99f951fe..bebb394939466 100644 --- a/compiler/rustc_middle/src/ich/impls_ty.rs +++ b/compiler/rustc_middle/src/ich/impls_ty.rs @@ -104,7 +104,7 @@ impl<'a> HashStable> for ty::RegionVid { } } -impl<'a, 'tcx> HashStable> for ty::ConstVid<'tcx> { +impl<'a> HashStable> for ty::ConstVid { #[inline] fn hash_stable(&self, hcx: &mut StableHashingContext<'a>, hasher: &mut StableHasher) { self.index.hash_stable(hcx, hasher); diff --git a/compiler/rustc_middle/src/infer/unify_key.rs b/compiler/rustc_middle/src/infer/unify_key.rs index 16e9aafb25a54..431651dc0ea80 100644 --- a/compiler/rustc_middle/src/infer/unify_key.rs +++ b/compiler/rustc_middle/src/infer/unify_key.rs @@ -1,9 +1,10 @@ use crate::ty::{self, FloatVarValue, InferConst, IntVarValue, Ty, TyCtxt}; +use rustc_data_structures::logged_unification_table::LoggedUnificationTable; +use rustc_data_structures::modified_set as ms; use rustc_data_structures::snapshot_vec; use rustc_data_structures::undo_log::UndoLogs; -use rustc_data_structures::unify::{ - self, EqUnifyValue, InPlace, NoError, UnificationTable, UnifyKey, UnifyValue, -}; +use rustc_data_structures::unify::{self, EqUnifyValue, NoError, UnifyKey, UnifyValue}; +use rustc_data_structures::unify_log as ul; use rustc_span::def_id::DefId; use rustc_span::symbol::Symbol; use rustc_span::Span; @@ -159,13 +160,44 @@ pub struct ConstVarValue<'tcx> { pub val: ConstVariableValue<'tcx>, } -impl<'tcx> UnifyKey for ty::ConstVid<'tcx> { - type Value = ConstVarValue<'tcx>; +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, TyEncodable, TyDecodable)] +pub struct ConstVidEqKey<'tcx> { + pub vid: ty::ConstVid, + pub phantom: PhantomData<&'tcx ()>, +} + +impl From for ConstVidEqKey<'_> { + fn from(vid: ty::ConstVid) -> Self { + Self { vid, phantom: PhantomData } + } +} + +impl<'tcx> From> for ty::ConstVid { + fn from(vid: ConstVidEqKey<'tcx>) -> Self { + vid.vid + } +} + +impl UnifyKey for ty::ConstVid { + type Value = (); fn index(&self) -> u32 { self.index } fn from_index(i: u32) -> Self { - ty::ConstVid { index: i, phantom: PhantomData } + ty::ConstVid { index: i } + } + fn tag() -> &'static str { + "ConstVid" + } +} + +impl<'tcx> UnifyKey for ConstVidEqKey<'tcx> { + type Value = ConstVarValue<'tcx>; + fn index(&self) -> u32 { + self.vid.index + } + fn from_index(i: u32) -> Self { + ConstVidEqKey { vid: ty::ConstVid { index: i }, phantom: PhantomData } } fn tag() -> &'static str { "ConstVid" @@ -207,13 +239,14 @@ impl<'tcx> UnifyValue for ConstVarValue<'tcx> { impl<'tcx> EqUnifyValue for &'tcx ty::Const<'tcx> {} -pub fn replace_if_possible( - table: &mut UnificationTable, V, L>>, +pub fn replace_if_possible( + table: &mut LoggedUnificationTable<'_, ConstVidEqKey<'tcx>, ty::ConstVid, L>, c: &'tcx ty::Const<'tcx>, ) -> &'tcx ty::Const<'tcx> where - V: snapshot_vec::VecLike>>, - L: UndoLogs>>>, + L: UndoLogs>>> + + UndoLogs> + + UndoLogs>, { if let ty::Const { val: ty::ConstKind::Infer(InferConst::Var(vid)), .. } = c { match table.probe_value(*vid).val.known() { diff --git a/compiler/rustc_middle/src/ty/consts/kind.rs b/compiler/rustc_middle/src/ty/consts/kind.rs index ede28522000af..a9253b568708e 100644 --- a/compiler/rustc_middle/src/ty/consts/kind.rs +++ b/compiler/rustc_middle/src/ty/consts/kind.rs @@ -17,7 +17,7 @@ pub enum ConstKind<'tcx> { Param(ty::ParamConst), /// Infer the value of the const. - Infer(InferConst<'tcx>), + Infer(InferConst), /// Bound const variable, used only when preparing a trait query. Bound(ty::DebruijnIndex, ty::BoundVar), @@ -70,9 +70,9 @@ impl<'tcx> ConstKind<'tcx> { /// An inference variable for a const, for use in const generics. #[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, TyEncodable, TyDecodable, Hash)] #[derive(HashStable)] -pub enum InferConst<'tcx> { +pub enum InferConst { /// Infer the value of the const. - Var(ty::ConstVid<'tcx>), + Var(ty::ConstVid), /// A fresh const variable. See `infer::freshen` for more details. Fresh(u32), } diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index f6ea6743a0e04..44ce8ff77216c 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -2318,7 +2318,7 @@ impl<'tcx> TyCtxt<'tcx> { } #[inline] - pub fn mk_const_var(self, v: ConstVid<'tcx>, ty: Ty<'tcx>) -> &'tcx Const<'tcx> { + pub fn mk_const_var(self, v: ConstVid, ty: Ty<'tcx>) -> &'tcx Const<'tcx> { self.mk_const(ty::Const { val: ty::ConstKind::Infer(InferConst::Var(v)), ty }) } @@ -2338,7 +2338,7 @@ impl<'tcx> TyCtxt<'tcx> { } #[inline] - pub fn mk_const_infer(self, ic: InferConst<'tcx>, ty: Ty<'tcx>) -> &'tcx ty::Const<'tcx> { + pub fn mk_const_infer(self, ic: InferConst, ty: Ty<'tcx>) -> &'tcx ty::Const<'tcx> { self.mk_const(ty::Const { val: ty::ConstKind::Infer(ic), ty }) } diff --git a/compiler/rustc_middle/src/ty/structural_impls.rs b/compiler/rustc_middle/src/ty/structural_impls.rs index 53521d0e9f332..c81533fbfa5b0 100644 --- a/compiler/rustc_middle/src/ty/structural_impls.rs +++ b/compiler/rustc_middle/src/ty/structural_impls.rs @@ -133,7 +133,7 @@ impl fmt::Debug for ty::TyVid { } } -impl<'tcx> fmt::Debug for ty::ConstVid<'tcx> { +impl fmt::Debug for ty::ConstVid { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "_#{}c", self.index) } @@ -1112,7 +1112,7 @@ impl<'tcx> TypeFoldable<'tcx> for ty::ConstKind<'tcx> { } } -impl<'tcx> TypeFoldable<'tcx> for InferConst<'tcx> { +impl<'tcx> TypeFoldable<'tcx> for InferConst { fn super_fold_with>(&self, _folder: &mut F) -> Self { *self } diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs index 0fd48d0928257..f791696bf16d9 100644 --- a/compiler/rustc_middle/src/ty/sty.rs +++ b/compiler/rustc_middle/src/ty/sty.rs @@ -23,7 +23,6 @@ use rustc_target::abi::VariantIdx; use rustc_target::spec::abi; use std::borrow::Cow; use std::cmp::Ordering; -use std::marker::PhantomData; use std::ops::Range; use ty::util::IntTypeExt; @@ -1463,10 +1462,33 @@ pub struct TyVid { pub index: u32, } +impl Idx for TyVid { + #[inline] + fn new(idx: usize) -> Self { + assert!(idx <= u32::max_value() as usize); + TyVid { index: idx as u32 } + } + #[inline] + fn index(self) -> usize { + self.index as usize + } +} + #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, TyEncodable, TyDecodable)] -pub struct ConstVid<'tcx> { +pub struct ConstVid { pub index: u32, - pub phantom: PhantomData<&'tcx ()>, +} + +impl Idx for ConstVid { + #[inline] + fn new(idx: usize) -> Self { + assert!(idx <= u32::max_value() as usize); + Self { index: idx as u32 } + } + #[inline] + fn index(self) -> usize { + self.index as usize + } } #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, TyEncodable, TyDecodable)] @@ -1474,11 +1496,35 @@ pub struct IntVid { pub index: u32, } +impl Idx for IntVid { + #[inline] + fn new(idx: usize) -> Self { + assert!(idx <= u32::max_value() as usize); + IntVid { index: idx as u32 } + } + #[inline] + fn index(self) -> usize { + self.index as usize + } +} + #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, TyEncodable, TyDecodable)] pub struct FloatVid { pub index: u32, } +impl Idx for FloatVid { + #[inline] + fn new(idx: usize) -> Self { + assert!(idx <= u32::max_value() as usize); + FloatVid { index: idx as u32 } + } + #[inline] + fn index(self) -> usize { + self.index as usize + } +} + rustc_index::newtype_index! { pub struct RegionVid { DEBUG_FORMAT = custom, diff --git a/compiler/rustc_trait_selection/src/autoderef.rs b/compiler/rustc_trait_selection/src/autoderef.rs index b9c5123e49a0e..4f6fadcdd78f3 100644 --- a/compiler/rustc_trait_selection/src/autoderef.rs +++ b/compiler/rustc_trait_selection/src/autoderef.rs @@ -158,9 +158,11 @@ impl<'a, 'tcx> Autoderef<'a, 'tcx> { // but that's not a reason for an ICE (`predicate_may_hold` is conservative // by design). debug!("overloaded_deref_ty: encountered errors {:?} while fulfilling", e); + fulfillcx.deregister(&self.infcx); return None; } let obligations = fulfillcx.pending_obligations(); + fulfillcx.deregister(&self.infcx); debug!("overloaded_deref_ty({:?}) = ({:?}, {:?})", ty, normalized_ty, obligations); self.state.obligations.extend(obligations); diff --git a/compiler/rustc_trait_selection/src/infer.rs b/compiler/rustc_trait_selection/src/infer.rs index 4ec1b29bca4f1..6252e026d1ca7 100644 --- a/compiler/rustc_trait_selection/src/infer.rs +++ b/compiler/rustc_trait_selection/src/infer.rs @@ -126,11 +126,13 @@ impl<'tcx> InferCtxtBuilderExt<'tcx> for InferCtxtBuilder<'tcx> { |ref infcx, key, canonical_inference_vars| { let mut fulfill_cx = TraitEngine::new(infcx.tcx); let value = operation(infcx, &mut *fulfill_cx, key)?; - infcx.make_canonicalized_query_response( + let x = infcx.make_canonicalized_query_response( canonical_inference_vars, value, &mut *fulfill_cx, - ) + ); + fulfill_cx.deregister(infcx); + x }, ) } diff --git a/compiler/rustc_trait_selection/src/traits/chalk_fulfill.rs b/compiler/rustc_trait_selection/src/traits/chalk_fulfill.rs index adc8ae5908656..992aa33210337 100644 --- a/compiler/rustc_trait_selection/src/traits/chalk_fulfill.rs +++ b/compiler/rustc_trait_selection/src/traits/chalk_fulfill.rs @@ -42,7 +42,7 @@ impl TraitEngine<'tcx> for FulfillmentContext<'tcx> { self.obligations.insert(obligation); } - fn select_all_or_error( + fn select_or_error( &mut self, infcx: &InferCtxt<'_, 'tcx>, ) -> Result<(), Vec>> { diff --git a/compiler/rustc_trait_selection/src/traits/codegen.rs b/compiler/rustc_trait_selection/src/traits/codegen.rs index 3cb6ec8626186..8af576d4ea3fd 100644 --- a/compiler/rustc_trait_selection/src/traits/codegen.rs +++ b/compiler/rustc_trait_selection/src/traits/codegen.rs @@ -89,7 +89,7 @@ pub fn codegen_fulfill_obligation<'tcx>( debug!("fulfill_obligation: register_predicate_obligation {:?}", predicate); fulfill_cx.register_predicate_obligation(&infcx, predicate); }); - let impl_source = drain_fulfillment_cx_or_panic(&infcx, &mut fulfill_cx, &impl_source); + let impl_source = drain_fulfillment_cx_or_panic(&infcx, fulfill_cx, &impl_source); info!("Cache miss: {:?} => {:?}", trait_ref, impl_source); Ok(impl_source) @@ -109,7 +109,7 @@ pub fn codegen_fulfill_obligation<'tcx>( /// the complete picture of the type. fn drain_fulfillment_cx_or_panic( infcx: &InferCtxt<'_, 'tcx>, - fulfill_cx: &mut FulfillmentContext<'tcx>, + fulfill_cx: FulfillmentContext<'tcx>, result: &T, ) -> T where diff --git a/compiler/rustc_trait_selection/src/traits/engine.rs b/compiler/rustc_trait_selection/src/traits/engine.rs index 4d4778869794b..18fe0cd4ebc88 100644 --- a/compiler/rustc_trait_selection/src/traits/engine.rs +++ b/compiler/rustc_trait_selection/src/traits/engine.rs @@ -1,13 +1,17 @@ +use rustc_infer::infer::InferCtxt; use rustc_middle::ty::TyCtxt; use super::TraitEngine; use super::{ChalkFulfillmentContext, FulfillmentContext}; -pub trait TraitEngineExt<'tcx> { +pub trait TraitEngineExt<'tcx>: TraitEngine<'tcx> { fn new(tcx: TyCtxt<'tcx>) -> Box; + fn new_with_deregister<'cx>( + infcx: &'cx InferCtxt<'cx, 'tcx>, + ) -> DeregisterOnDropEngine<'cx, 'tcx, Box>; } -impl<'tcx> TraitEngineExt<'tcx> for dyn TraitEngine<'tcx> { +impl<'tcx> TraitEngineExt<'tcx> for dyn TraitEngine<'tcx> + 'tcx { fn new(tcx: TyCtxt<'tcx>) -> Box { if tcx.sess.opts.debugging_opts.chalk { Box::new(ChalkFulfillmentContext::new()) @@ -15,4 +19,46 @@ impl<'tcx> TraitEngineExt<'tcx> for dyn TraitEngine<'tcx> { Box::new(FulfillmentContext::new()) } } + fn new_with_deregister<'cx>( + infcx: &'cx InferCtxt<'cx, 'tcx>, + ) -> DeregisterOnDropEngine<'cx, 'tcx, Box> { + DeregisterOnDropEngine { engine: Self::new(infcx.tcx), infcx } + } +} + +/// Deregisters any variable watches on drop automatically +pub struct DeregisterOnDropEngine<'cx, 'tcx, T> +where + T: TraitEngine<'tcx>, +{ + infcx: &'cx InferCtxt<'cx, 'tcx>, + engine: T, +} + +impl<'tcx, T> Drop for DeregisterOnDropEngine<'_, 'tcx, T> +where + T: TraitEngine<'tcx>, +{ + fn drop(&mut self) { + self.engine.deregister(self.infcx) + } +} + +impl<'tcx, T> std::ops::Deref for DeregisterOnDropEngine<'_, 'tcx, T> +where + T: TraitEngine<'tcx>, +{ + type Target = T; + fn deref(&self) -> &Self::Target { + &self.engine + } +} + +impl<'tcx, T> std::ops::DerefMut for DeregisterOnDropEngine<'_, 'tcx, T> +where + T: TraitEngine<'tcx>, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.engine + } } diff --git a/compiler/rustc_trait_selection/src/traits/fulfill.rs b/compiler/rustc_trait_selection/src/traits/fulfill.rs index 9a8b5534dfe83..271f335f7f23c 100644 --- a/compiler/rustc_trait_selection/src/traits/fulfill.rs +++ b/compiler/rustc_trait_selection/src/traits/fulfill.rs @@ -1,6 +1,7 @@ -use crate::infer::{InferCtxt, TyOrConstInferVar}; +use crate::infer::{InferCtxt, TyOrConstInferVar, WatcherOffset}; +use rustc_data_structures::captures::Captures; use rustc_data_structures::obligation_forest::ProcessResult; -use rustc_data_structures::obligation_forest::{Error, ForestObligation, Outcome}; +use rustc_data_structures::obligation_forest::{DoCompleted, Error, ForestObligation}; use rustc_data_structures::obligation_forest::{ObligationForest, ObligationProcessor}; use rustc_errors::ErrorReported; use rustc_infer::traits::{TraitEngine, TraitEngineExt as _, TraitObligation}; @@ -30,10 +31,16 @@ impl<'tcx> ForestObligation for PendingPredicateObligation<'tcx> { /// as the `ParamEnv` can influence whether fulfillment succeeds /// or fails. type CacheKey = ty::ParamEnvAnd<'tcx, ty::Predicate<'tcx>>; + type Variable = TyOrConstInferVar; + type WatcherOffset = WatcherOffset; fn as_cache_key(&self) -> Self::CacheKey { self.obligation.param_env.and(self.obligation.predicate) } + + fn stalled_on(&self) -> &[Self::Variable] { + &self.stalled_on + } } /// The fulfillment context is used to drive trait resolution. It @@ -82,7 +89,7 @@ pub struct PendingPredicateObligation<'tcx> { // should mostly optimize for reading speed, while modifying is not as relevant. // // For whatever reason using a boxed slice is slower than using a `Vec` here. - pub stalled_on: Vec>, + pub stalled_on: Vec, } // `PendingPredicateObligation` is used a lot. Make sure it doesn't unintentionally get bigger. @@ -125,27 +132,22 @@ impl<'a, 'tcx> FulfillmentContext<'tcx> { let mut errors = Vec::new(); - loop { - debug!("select: starting another iteration"); + debug!("select: starting another iteration"); - // Process pending obligations. - let outcome: Outcome<_, _> = - self.predicates.process_obligations(&mut FulfillProcessor { - selcx, - register_region_obligations: self.register_region_obligations, - }); - debug!("select: outcome={:#?}", outcome); + // Process pending obligations. + let outcome = self.predicates.process_obligations( + &mut FulfillProcessor { + selcx, + register_region_obligations: self.register_region_obligations, + }, + DoCompleted::No, + ); + debug!("select: outcome={:#?}", outcome); - // FIXME: if we kept the original cache key, we could mark projection - // obligations as complete for the projection cache here. + // FIXME: if we kept the original cache key, we could mark projection + // obligations as complete for the projection cache here. - errors.extend(outcome.errors.into_iter().map(to_fulfillment_error)); - - // If nothing new was added, no need to keep looping. - if outcome.stalled { - break; - } - } + errors.extend(outcome.errors.into_iter().map(to_fulfillment_error)); debug!( "select({} predicates remaining, {} errors) done", @@ -212,7 +214,7 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentContext<'tcx> { .register_obligation(PendingPredicateObligation { obligation, stalled_on: vec![] }); } - fn select_all_or_error( + fn select_or_error( &mut self, infcx: &InferCtxt<'_, 'tcx>, ) -> Result<(), Vec>> { @@ -224,6 +226,7 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentContext<'tcx> { .into_iter() .map(to_fulfillment_error) .collect(); + if errors.is_empty() { Ok(()) } else { Err(errors) } } @@ -235,6 +238,12 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentContext<'tcx> { self.select(&mut selcx) } + fn deregister(&mut self, infcx: &InferCtxt<'_, 'tcx>) { + if let Some(offset) = self.predicates.take_watcher_offset() { + infcx.deregister_unify_watcher(offset); + } + } + fn pending_obligations(&self) -> Vec> { self.predicates.map_pending_obligations(|o| o.obligation.clone()) } @@ -263,7 +272,7 @@ impl<'a, 'b, 'tcx> ObligationProcessor for FulfillProcessor<'a, 'b, 'tcx> { /// This is always inlined, despite its size, because it has a single /// callsite and it is called *very* frequently. #[inline(always)] - fn process_obligation( + fn checked_process_obligation( &mut self, pending_obligation: &mut Self::Obligation, ) -> ProcessResult { @@ -294,7 +303,6 @@ impl<'a, 'b, 'tcx> ObligationProcessor for FulfillProcessor<'a, 'b, 'tcx> { })() } }; - if !change { debug!( "process_predicate: pending obligation {:?} still stalled on {:?}", @@ -304,35 +312,14 @@ impl<'a, 'b, 'tcx> ObligationProcessor for FulfillProcessor<'a, 'b, 'tcx> { return ProcessResult::Unchanged; } - self.progress_changed_obligations(pending_obligation) + self.process_obligation(pending_obligation) } - fn process_backedge<'c, I>( - &mut self, - cycle: I, - _marker: PhantomData<&'c PendingPredicateObligation<'tcx>>, - ) where - I: Clone + Iterator>, - { - if self.selcx.coinductive_match(cycle.clone().map(|s| s.obligation.predicate)) { - debug!("process_child_obligations: coinductive match"); - } else { - let cycle: Vec<_> = cycle.map(|c| c.obligation.clone()).collect(); - self.selcx.infcx().report_overflow_error_cycle(&cycle); - } - } -} - -impl<'a, 'b, 'tcx> FulfillProcessor<'a, 'b, 'tcx> { - // The code calling this method is extremely hot and only rarely - // actually uses this, so move this part of the code - // out of that loop. - #[inline(never)] - fn progress_changed_obligations( + fn process_obligation( &mut self, - pending_obligation: &mut PendingPredicateObligation<'tcx>, - ) -> ProcessResult, FulfillmentErrorCode<'tcx>> { - pending_obligation.stalled_on.truncate(0); + pending_obligation: &mut Self::Obligation, + ) -> ProcessResult { + pending_obligation.stalled_on.clear(); let obligation = &mut pending_obligation.obligation; @@ -344,6 +331,8 @@ impl<'a, 'b, 'tcx> FulfillProcessor<'a, 'b, 'tcx> { debug!(?obligation, ?obligation.cause, "process_obligation"); let infcx = self.selcx.infcx(); + let ty_or_const_var = + |v| infcx.root_ty_or_const(TyOrConstInferVar::maybe_from_ty(v).unwrap()); match obligation.predicate.kind() { ty::PredicateKind::ForAll(binder) => match binder.skip_binder() { @@ -453,8 +442,10 @@ impl<'a, 'b, 'tcx> FulfillProcessor<'a, 'b, 'tcx> { obligation.cause.span, ) { None => { - pending_obligation.stalled_on = - vec![TyOrConstInferVar::maybe_from_generic_arg(arg).unwrap()]; + pending_obligation.stalled_on.clear(); + pending_obligation.stalled_on.push(infcx.root_ty_or_const( + TyOrConstInferVar::maybe_from_generic_arg(arg).unwrap(), + )); ProcessResult::Unchanged } Some(os) => ProcessResult::Changed(mk_pending(os)), @@ -469,10 +460,9 @@ impl<'a, 'b, 'tcx> FulfillProcessor<'a, 'b, 'tcx> { ) { None => { // None means that both are unresolved. - pending_obligation.stalled_on = vec![ - TyOrConstInferVar::maybe_from_ty(subtype.a).unwrap(), - TyOrConstInferVar::maybe_from_ty(subtype.b).unwrap(), - ]; + pending_obligation.stalled_on.clear(); + pending_obligation.stalled_on.push(ty_or_const_var(subtype.a)); + pending_obligation.stalled_on.push(ty_or_const_var(subtype.b)); ProcessResult::Unchanged } Some(Ok(ok)) => ProcessResult::Changed(mk_pending(ok.obligations)), @@ -497,10 +487,12 @@ impl<'a, 'b, 'tcx> FulfillProcessor<'a, 'b, 'tcx> { ) { Ok(()) => ProcessResult::Changed(vec![]), Err(ErrorHandled::TooGeneric) => { - pending_obligation.stalled_on = substs - .iter() - .filter_map(|ty| TyOrConstInferVar::maybe_from_generic_arg(ty)) - .collect(); + pending_obligation.stalled_on.extend( + substs + .iter() + .filter_map(|ty| TyOrConstInferVar::maybe_from_generic_arg(ty)) + .map(|ty| infcx.root_ty_or_const(ty)), + ); ProcessResult::Unchanged } Err(e) => ProcessResult::Error(CodeSelectionError(ConstEvalFailure(e))), @@ -542,13 +534,11 @@ impl<'a, 'b, 'tcx> FulfillProcessor<'a, 'b, 'tcx> { ) { Ok(val) => Ok(Const::from_value(self.selcx.tcx(), val, c.ty)), Err(ErrorHandled::TooGeneric) => { - stalled_on.append( - &mut substs - .iter() - .filter_map(|arg| { - TyOrConstInferVar::maybe_from_generic_arg(arg) - }) - .collect(), + stalled_on.extend( + substs + .types() + .filter_map(|ty| TyOrConstInferVar::maybe_from_ty(ty)) + .map(|ty| infcx.root_ty_or_const(ty)), ); Err(ErrorHandled::TooGeneric) } @@ -600,12 +590,53 @@ impl<'a, 'b, 'tcx> FulfillProcessor<'a, 'b, 'tcx> { } } + fn process_backedge<'c, I>( + &mut self, + cycle: I, + _marker: PhantomData<&'c PendingPredicateObligation<'tcx>>, + ) where + I: Clone + Iterator>, + { + if self.selcx.coinductive_match(cycle.clone().map(|s| s.obligation.predicate)) { + debug!("process_child_obligations: coinductive match"); + } else { + let cycle: Vec<_> = cycle.map(|c| c.obligation.clone()).collect(); + self.selcx.infcx().report_overflow_error_cycle(&cycle); + } + } + + fn notify_unblocked( + &self, + offset: &WatcherOffset, + f: impl FnMut(::Variable), + ) { + let infcx = self.selcx.infcx(); + infcx.notify_watcher(offset, f); + } + + fn register_variable_watcher(&self) -> WatcherOffset { + self.selcx.infcx().register_unify_watcher() + } + + fn deregister_variable_watcher(&self, offset: WatcherOffset) { + self.selcx.infcx().deregister_unify_watcher(offset); + } + + fn watch_variable(&self, var: ::Variable) { + self.selcx.infcx().watch_variable(var); + } + fn unwatch_variable(&self, var: ::Variable) { + self.selcx.infcx().unwatch_variable(var); + } +} + +impl<'a, 'b, 'tcx> FulfillProcessor<'a, 'b, 'tcx> { #[instrument(level = "debug", skip(self, obligation, stalled_on))] fn process_trait_obligation( &mut self, obligation: &PredicateObligation<'tcx>, trait_obligation: TraitObligation<'tcx>, - stalled_on: &mut Vec>, + stalled_on: &mut Vec, ) -> ProcessResult, FulfillmentErrorCode<'tcx>> { let infcx = self.selcx.infcx(); if obligation.predicate.is_global() { @@ -632,10 +663,10 @@ impl<'a, 'b, 'tcx> FulfillProcessor<'a, 'b, 'tcx> { // only reason we can fail to make progress on // trait selection is because we don't have enough // information about the types in the trait. - *stalled_on = trait_ref_infer_vars( + stalled_on.extend(trait_ref_infer_vars( self.selcx, trait_obligation.predicate.map_bound(|pred| pred.trait_ref), - ); + )); debug!( "process_predicate: pending obligation {:?} now stalled on {:?}", @@ -656,16 +687,16 @@ impl<'a, 'b, 'tcx> FulfillProcessor<'a, 'b, 'tcx> { fn process_projection_obligation( &mut self, project_obligation: PolyProjectionObligation<'tcx>, - stalled_on: &mut Vec>, + stalled_on: &mut Vec, ) -> ProcessResult, FulfillmentErrorCode<'tcx>> { let tcx = self.selcx.tcx(); match project::poly_project_and_unify_type(self.selcx, &project_obligation) { Ok(Ok(Some(os))) => ProcessResult::Changed(mk_pending(os)), Ok(Ok(None)) => { - *stalled_on = trait_ref_infer_vars( + stalled_on.extend(trait_ref_infer_vars( self.selcx, project_obligation.predicate.to_poly_trait_ref(tcx), - ); + )); ProcessResult::Unchanged } // Let the caller handle the recursion @@ -678,14 +709,14 @@ impl<'a, 'b, 'tcx> FulfillProcessor<'a, 'b, 'tcx> { } /// Returns the set of inference variables contained in a trait ref. -fn trait_ref_infer_vars<'a, 'tcx>( - selcx: &mut SelectionContext<'a, 'tcx>, +fn trait_ref_infer_vars<'a, 'tcx, 'b>( + selcx: &'b mut SelectionContext<'a, 'tcx>, trait_ref: ty::PolyTraitRef<'tcx>, -) -> Vec> { - selcx - .infcx() +) -> impl Iterator + 'b + Captures<'a> + Captures<'tcx> { + let infcx = selcx.infcx(); + infcx .resolve_vars_if_possible(&trait_ref) - .skip_binder() + .skip_binder() // ok b/c this check doesn't care about regions .substs .iter() // FIXME(eddyb) try using `skip_current_subtree` to skip everything that @@ -693,7 +724,7 @@ fn trait_ref_infer_vars<'a, 'tcx>( .filter(|arg| arg.has_infer_types_or_consts()) .flat_map(|arg| arg.walk()) .filter_map(TyOrConstInferVar::maybe_from_generic_arg) - .collect() + .map(move |var| infcx.root_ty_or_const(var)) } fn to_fulfillment_error<'tcx>( diff --git a/compiler/rustc_traits/src/dropck_outlives.rs b/compiler/rustc_traits/src/dropck_outlives.rs index 6cffa6d02a4e3..2d988cbcf1e54 100644 --- a/compiler/rustc_traits/src/dropck_outlives.rs +++ b/compiler/rustc_traits/src/dropck_outlives.rs @@ -75,7 +75,7 @@ fn dropck_outlives<'tcx>( // Set used to detect infinite recursion. let mut ty_set = FxHashSet::default(); - let mut fulfill_cx = TraitEngine::new(infcx.tcx); + let mut fulfill_cx = TraitEngine::new_with_deregister(infcx); let cause = ObligationCause::dummy(); let mut constraints = DtorckConstraint::empty(); diff --git a/compiler/rustc_typeck/src/check/check.rs b/compiler/rustc_typeck/src/check/check.rs index 8f2537404c5cc..03a160383474d 100644 --- a/compiler/rustc_typeck/src/check/check.rs +++ b/compiler/rustc_typeck/src/check/check.rs @@ -11,6 +11,7 @@ use rustc_hir::lang_items::LangItem; use rustc_hir::{ItemKind, Node}; use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind}; use rustc_infer::infer::{RegionVariableOrigin, TyCtxtInferExt}; +use rustc_infer::traits::TraitEngine; use rustc_middle::ty::fold::TypeFoldable; use rustc_middle::ty::subst::GenericArgKind; use rustc_middle::ty::util::{Discr, IntTypeExt, Representability}; @@ -616,7 +617,7 @@ fn check_opaque_meets_bounds<'tcx>( // Check that all obligations are satisfied by the implementation's // version. - if let Err(ref errors) = inh.fulfillment_cx.borrow_mut().select_all_or_error(&infcx) { + if let Err(ref errors) = inh.fulfillment_cx.borrow_mut().select_or_error(&infcx) { infcx.report_fulfillment_errors(errors, None, false); } diff --git a/compiler/rustc_typeck/src/check/compare_method.rs b/compiler/rustc_typeck/src/check/compare_method.rs index 4acc7451a2131..7141c9a4fb183 100644 --- a/compiler/rustc_typeck/src/check/compare_method.rs +++ b/compiler/rustc_typeck/src/check/compare_method.rs @@ -320,7 +320,7 @@ fn compare_predicate_entailment<'tcx>( // Check that all obligations are satisfied by the implementation's // version. - if let Err(ref errors) = inh.fulfillment_cx.borrow_mut().select_all_or_error(&infcx) { + if let Err(ref errors) = inh.fulfillment_cx.borrow_mut().select_or_error(&infcx) { infcx.report_fulfillment_errors(errors, None, false); return Err(ErrorReported); } @@ -1028,7 +1028,7 @@ crate fn compare_const_impl<'tcx>( // Check that all obligations are satisfied by the implementation's // version. - if let Err(ref errors) = inh.fulfillment_cx.borrow_mut().select_all_or_error(&infcx) { + if let Err(ref errors) = inh.fulfillment_cx.borrow_mut().select_or_error(&infcx) { infcx.report_fulfillment_errors(errors, None, false); return; } @@ -1144,7 +1144,7 @@ fn compare_type_predicate_entailment<'tcx>( // Check that all obligations are satisfied by the implementation's // version. - if let Err(ref errors) = inh.fulfillment_cx.borrow_mut().select_all_or_error(&infcx) { + if let Err(ref errors) = inh.fulfillment_cx.borrow_mut().select_or_error(&infcx) { infcx.report_fulfillment_errors(errors, None, false); return Err(ErrorReported); } @@ -1272,7 +1272,7 @@ pub fn check_type_bounds<'tcx>( // Check that all obligations are satisfied by the implementation's // version. - if let Err(ref errors) = inh.fulfillment_cx.borrow_mut().select_all_or_error(&infcx) { + if let Err(ref errors) = inh.fulfillment_cx.borrow_mut().select_or_error(&infcx) { infcx.report_fulfillment_errors(errors, None, false); return Err(ErrorReported); } diff --git a/compiler/rustc_typeck/src/check/fn_ctxt/_impl.rs b/compiler/rustc_typeck/src/check/fn_ctxt/_impl.rs index f87e6b607d46e..2ca1ed2a9d91d 100644 --- a/compiler/rustc_typeck/src/check/fn_ctxt/_impl.rs +++ b/compiler/rustc_typeck/src/check/fn_ctxt/_impl.rs @@ -691,7 +691,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { pub(in super::super) fn select_all_obligations_or_error(&self) { debug!("select_all_obligations_or_error"); - if let Err(errors) = self.fulfillment_cx.borrow_mut().select_all_or_error(&self) { + if let Err(errors) = self.fulfillment_cx.borrow_mut().select_or_error(&self) { self.report_fulfillment_errors(&errors, self.inh.body_id, false); } } diff --git a/compiler/rustc_typeck/src/check/inherited.rs b/compiler/rustc_typeck/src/check/inherited.rs index 7e580485c3de4..a6e81be98f21a 100644 --- a/compiler/rustc_typeck/src/check/inherited.rs +++ b/compiler/rustc_typeck/src/check/inherited.rs @@ -33,7 +33,7 @@ pub struct Inherited<'a, 'tcx> { pub(super) locals: RefCell>>, - pub(super) fulfillment_cx: RefCell>>, + pub(super) fulfillment_cx: RefCell + 'tcx>>, // Some additional `Sized` obligations badly affect type inference. // These obligations are added in a later stage of typeck. @@ -70,6 +70,12 @@ pub struct Inherited<'a, 'tcx> { pub(super) body_id: Option, } +impl<'a, 'tcx> Drop for Inherited<'a, 'tcx> { + fn drop(&mut self) { + self.fulfillment_cx.get_mut().deregister(&self.infcx); + } +} + impl<'a, 'tcx> Deref for Inherited<'a, 'tcx> { type Target = InferCtxt<'a, 'tcx>; fn deref(&self) -> &Self::Target { diff --git a/src/test/ui/obligation-forest-bug-69218-2.rs b/src/test/ui/obligation-forest-bug-69218-2.rs new file mode 100644 index 0000000000000..3ae2fdce930a4 --- /dev/null +++ b/src/test/ui/obligation-forest-bug-69218-2.rs @@ -0,0 +1,27 @@ +// run-pass +#![allow(dead_code)] + +use std::borrow::Cow; + +pub type Result = std::result::Result; + +pub struct CompressedData<'data> { + pub format: CompressionFormat, + pub data: &'data [u8], +} + +pub enum CompressionFormat { + None, + Unknown, +} + +impl<'data> CompressedData<'data> { + pub fn decompress(self) -> Result> { + match self.format { + CompressionFormat::None => Ok(Cow::Borrowed(self.data)), + _ => Err("Unsupported compressed data."), + } + } +} + +fn main() {} diff --git a/src/test/ui/obligation-forest-bug-69218.rs b/src/test/ui/obligation-forest-bug-69218.rs new file mode 100644 index 0000000000000..f8603d63b0514 --- /dev/null +++ b/src/test/ui/obligation-forest-bug-69218.rs @@ -0,0 +1,66 @@ +// run-pass +#![allow(dead_code)] + +use std::marker::PhantomData; + +pub trait Consumer { + type Result; +} + +pub trait IndexedParallelIterator: ExactSizeIterator { + type Item; +} + +pub struct CollectConsumer<'c, T: Send> { + target: &'c mut [T], +} + +impl<'c, T: Send + 'c> Consumer for CollectConsumer<'c, T> { + type Result = CollectResult<'c, T>; +} + +pub struct CollectResult<'c, T> { + start: *mut T, + len: usize, + invariant_lifetime: PhantomData<&'c mut &'c mut [T]>, +} + +unsafe impl<'c, T> Send for CollectResult<'c, T> where T: Send {} + +pub fn unzip_indexed(_: I, _: CA) -> CA::Result +where + I: IndexedParallelIterator, + CA: Consumer, +{ + unimplemented!() +} + +struct Collect<'c, T: Send> { + vec: &'c mut Vec, + len: usize, +} + +pub fn unzip_into_vecs(pi: I, left: &mut Vec, _: &mut Vec) +where + I: IndexedParallelIterator, + A: Send, + B: Send, +{ + let len = pi.len(); + Collect::new(left, len).with_consumer(|left_consumer| unzip_indexed(pi, left_consumer)); +} + +impl<'c, T: Send + 'c> Collect<'c, T> { + fn new(vec: &'c mut Vec, len: usize) -> Self { + Collect { vec, len } + } + + fn with_consumer(self, _: F) + where + F: FnOnce(CollectConsumer) -> CollectResult, + { + unimplemented!() + } +} + +fn main() {}