diff --git a/zcash_client_memory/src/error.rs b/zcash_client_memory/src/error.rs
index 8548404a56..08bea0e1e1 100644
--- a/zcash_client_memory/src/error.rs
+++ b/zcash_client_memory/src/error.rs
@@ -1,6 +1,6 @@
 use zcash_keys::keys::{AddressGenerationError, DerivationError};
 use zcash_primitives::transaction::TxId;
-use zcash_protocol::memo;
+use zcash_protocol::{consensus::BlockHeight, memo};
 
 use crate::mem_wallet::AccountId;
 
@@ -36,6 +36,10 @@ pub enum Error {
     CorruptedData(String),
     #[error("An error occurred while processing an account due to a failure in deriving the account's keys: {0}")]
     BadAccountData(String),
+    #[error("Blocks are non-sequential")]
+    NonSequentialBlocks,
+    #[error("Invalid scan range start {0}, end {1}: {2}")]
+    InvalidScanRange(BlockHeight, BlockHeight, String),
     #[error("Other error: {0}")]
     Other(String),
 }
diff --git a/zcash_client_memory/src/mem_wallet/mod.rs b/zcash_client_memory/src/mem_wallet/mod.rs
index 20b8f44dc5..d1e359246e 100644
--- a/zcash_client_memory/src/mem_wallet/mod.rs
+++ b/zcash_client_memory/src/mem_wallet/mod.rs
@@ -2,6 +2,7 @@ use core::time;
 
 use incrementalmerkletree::{Address, Marking, Retention};
 use sapling::NullifierDerivingKey;
+use scanning::ScanQueue;
 use secrecy::{ExposeSecret, SecretVec};
 use shardtree::{error::ShardTreeError, store::memory::MemoryShardStore, ShardTree};
 use std::{
@@ -55,6 +56,7 @@ use zcash_client_backend::{data_api::ORCHARD_SHARD_HEIGHT, wallet::WalletOrchard
 
 use crate::error::Error;
 
+mod scanning;
 mod tables;
 mod wallet_commitment_trees;
 mod wallet_read;
@@ -68,6 +70,12 @@ struct MemoryWalletBlock {
     // Just the transactions that involve an account in this wallet
     transactions: HashSet<TxId>,
     memos: HashMap<NoteId, MemoBytes>,
+    sapling_commitment_tree_size: Option<u32>,
+    sapling_output_count: Option<u32>,
+    #[cfg(feature = "orchard")]
+    orchard_commitment_tree_size: Option<u32>,
+    #[cfg(feature = "orchard")]
+    orchard_action_count: Option<u32>,
 }
 
 pub struct MemoryWalletDb {
@@ -83,6 +91,8 @@ pub struct MemoryWalletDb {
 
     tx_locator: TxLocatorMap,
 
+    scan_queue: ScanQueue,
+
     sapling_tree: ShardTree<
         MemoryShardStore<sapling::Node, BlockHeight>,
         { SAPLING_SHARD_HEIGHT * 2 },
@@ -109,6 +119,7 @@ impl MemoryWalletDb {
             nullifiers: NullifierMap::new(),
             tx_locator: TxLocatorMap::new(),
             receieved_note_spends: ReceievdNoteSpends::new(),
+            scan_queue: ScanQueue::new(),
         }
     }
     fn mark_sapling_note_spent(&mut self, nf: sapling::Nullifier, txid: TxId) -> Result<(), Error> {
@@ -124,9 +135,6 @@ impl MemoryWalletDb {
         Ok(())
     }
 
-    // fn get_account(&self, account_id: AccountId) -> Option<&Account> {
-    //     self.accounts.get(*account_id as usize)
-    // }
     fn get_account_mut(&mut self, account_id: AccountId) -> Option<&mut Account> {
         self.accounts.get_mut(*account_id as usize)
     }
diff --git a/zcash_client_memory/src/mem_wallet/scanning.rs b/zcash_client_memory/src/mem_wallet/scanning.rs
new file mode 100644
index 0000000000..5f48f544ee
--- /dev/null
+++ b/zcash_client_memory/src/mem_wallet/scanning.rs
@@ -0,0 +1,215 @@
+#![allow(unused)]
+use core::time;
+use incrementalmerkletree::{Address, Marking, Position, Retention};
+use sapling::NullifierDerivingKey;
+use secrecy::{ExposeSecret, SecretVec};
+use shardtree::{error::ShardTreeError, store::memory::MemoryShardStore, ShardTree};
+use std::{
+    cell::RefCell,
+    cmp::Ordering,
+    collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet},
+    convert::Infallible,
+    hash::Hash,
+    num::NonZeroU32,
+    ops::{Deref, DerefMut, Range},
+    path::Iter,
+    rc::Rc,
+};
+use zcash_keys::keys::{AddressGenerationError, DerivationError, UnifiedIncomingViewingKey};
+use zip32::{fingerprint::SeedFingerprint, DiversifierIndex, Scope};
+
+use zcash_primitives::{
+    block::{self, BlockHash},
+    consensus::{BlockHeight, Network},
+    transaction::{components::OutPoint, txid, Authorized, Transaction, TransactionData, TxId},
+};
+use zcash_protocol::{
+    memo::{self, Memo, MemoBytes},
+    value::{ZatBalance, Zatoshis},
+    PoolType,
+    ShieldedProtocol::{self, Orchard, Sapling},
+};
+
+use zcash_client_backend::{
+    address::UnifiedAddress,
+    data_api::{
+        chain::ChainState,
+        scanning::{spanning_tree::SpanningTree, ScanPriority},
+        Account as _, AccountPurpose, AccountSource, SeedRelevance, SentTransactionOutput,
+        TransactionDataRequest, TransactionStatus,
+    },
+    keys::{UnifiedAddressRequest, UnifiedFullViewingKey, UnifiedSpendingKey},
+    wallet::{
+        Note, NoteId, Recipient, WalletSaplingOutput, WalletSpend, WalletTransparentOutput,
+        WalletTx,
+    },
+};
+
+use zcash_client_backend::data_api::{
+    chain::CommitmentTreeRoot, scanning::ScanRange, AccountBirthday, BlockMetadata,
+    DecryptedTransaction, NullifierQuery, ScannedBlock, SentTransaction, WalletCommitmentTrees,
+    WalletRead, WalletSummary, WalletWrite, SAPLING_SHARD_HEIGHT,
+};
+
+use super::AccountId;
+
+#[cfg(feature = "transparent-inputs")]
+use {
+    zcash_client_backend::wallet::TransparentAddressMetadata,
+    zcash_primitives::legacy::TransparentAddress,
+};
+
+#[cfg(feature = "orchard")]
+use {
+    zcash_client_backend::data_api::ORCHARD_SHARD_HEIGHT,
+    zcash_client_backend::wallet::WalletOrchardOutput,
+};
+
+use crate::error::Error;
+
+/// A queue of scanning ranges. Contains the start (inclusive) and end (exclusive) heights
+/// of each range, along with the priority of scanning that range.
+pub struct ScanQueue(Vec<(BlockHeight, BlockHeight, ScanPriority)>);
+
+impl ScanQueue {
+    pub fn new() -> Self {
+        ScanQueue(Vec::new())
+    }
+
+    pub fn suggest_scan_ranges(&self, min_priority: ScanPriority) -> Vec<ScanRange> {
+        let mut priorities: Vec<_> = self
+            .0
+            .iter()
+            .filter(|(_, _, p)| *p >= min_priority)
+            .collect();
+        priorities.sort_by(|(_, _, a), (_, _, b)| b.cmp(a));
+
+        priorities
+            .into_iter()
+            .map(|(start, end, priority)| {
+                let range = Range {
+                    start: *start,
+                    end: *end,
+                };
+                ScanRange::from_parts(range, *priority)
+            })
+            .collect()
+    }
+
+    pub fn insert_queue_entries<'a>(
+        &mut self,
+        entries: impl Iterator<Item = &'a ScanRange>,
+    ) -> Result<(), Error> {
+        for entry in entries {
+            if entry.block_range().start >= entry.block_range().end {
+                return Err(Error::InvalidScanRange(
+                    entry.block_range().start,
+                    entry.block_range().end,
+                    "start must be less than end".to_string(),
+                ));
+            }
+
+            for (start, end, _) in &self.0 {
+                if *start == entry.block_range().start || *end == entry.block_range().end {
+                    return Err(Error::InvalidScanRange(
+                        entry.block_range().start,
+                        entry.block_range().end,
+                        "at least part of range is already covered by another range".to_string(),
+                    ));
+                }
+            }
+
+            self.0.push((
+                entry.block_range().start,
+                entry.block_range().end,
+                entry.priority(),
+            ));
+        }
+        Ok(())
+    }
+
+    pub fn replace_queue_entries(
+        &mut self,
+        query_range: &Range<BlockHeight>,
+        entries: impl Iterator<Item = ScanRange>,
+        force_rescans: bool,
+    ) -> Result<(), Error> {
+        let (to_create, to_delete_ends) = {
+            let mut q_ranges: Vec<_> = self
+                .0
+                .iter()
+                .filter(|(start, end, _)| {
+                    // Ignore ranges that do not overlap and are not adjacent to the query range.
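+                    // Equivalently (illustration): keep `(start, end)` whenever
+                    // `start <= query_range.end && query_range.start <= end`; because the
+                    // bounds are half-open, this admits both overlapping and
+                    // exactly-adjacent ranges.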
+                    !(start > &query_range.end || &query_range.start > end)
+                })
+                .collect();
+            q_ranges.sort_by(|(_, end_a, _), (_, end_b, _)| end_a.cmp(end_b));
+
+            // Iterate over the ranges in the scan queue that overlap the range that we have
+            // identified as needing to be fully scanned. For each such range add it to the
+            // spanning tree (these should all be nonoverlapping ranges, but we might coalesce
+            // some in the process).
+            let mut to_create: Option<SpanningTree> = None;
+            let mut to_delete_ends: Vec<BlockHeight> = vec![];
+
+            for (start, end, priority) in q_ranges {
+                let entry = ScanRange::from_parts(
+                    Range {
+                        start: *start,
+                        end: *end,
+                    },
+                    *priority,
+                );
+                to_delete_ends.push(entry.block_range().end);
+                to_create = if let Some(cur) = to_create {
+                    Some(cur.insert(entry, force_rescans))
+                } else {
+                    Some(SpanningTree::Leaf(entry))
+                };
+            }
+
+            // Add the incoming entries to the spanning tree, or if we found no overlapping
+            // or adjacent ranges above, start a new tree from those entries.
+            for entry in entries {
+                to_create = if let Some(cur) = to_create {
+                    Some(cur.insert(entry, force_rescans))
+                } else {
+                    Some(SpanningTree::Leaf(entry))
+                };
+            }
+            (to_create, to_delete_ends)
+        };
+
+        if let Some(tree) = to_create {
+            // Drop any entry whose `block_range_end` appears in `to_delete_ends`; those
+            // entries have been merged into the spanning tree.
+            self.0.retain(|(_, block_range_end, _)| {
+                !to_delete_ends.contains(block_range_end)
+            });
+            let scan_ranges = tree.into_vec();
+            self.insert_queue_entries(scan_ranges.iter())?;
+        }
+        Ok(())
+    }
+}
+
+impl IntoIterator for ScanQueue {
+    type Item = (BlockHeight, BlockHeight, ScanPriority);
+    type IntoIter = <Vec<(BlockHeight, BlockHeight, ScanPriority)> as IntoIterator>::IntoIter;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.into_iter()
+    }
+}
+
+// We deref to a slice so that we can reuse the slice impls.
+impl Deref for ScanQueue {
+    type Target = [(BlockHeight, BlockHeight, ScanPriority)];
+
+    fn deref(&self) -> &Self::Target {
+        &self.0[..]
+    }
+}
+impl DerefMut for ScanQueue {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.0[..]
+    }
+}
diff --git a/zcash_client_memory/src/mem_wallet/wallet_read.rs b/zcash_client_memory/src/mem_wallet/wallet_read.rs
index 8b82987696..c25ae8e6b9 100644
--- a/zcash_client_memory/src/mem_wallet/wallet_read.rs
+++ b/zcash_client_memory/src/mem_wallet/wallet_read.rs
@@ -18,8 +18,8 @@ use std::ops::Add;
 use zcash_client_backend::{
     address::UnifiedAddress,
     data_api::{
-        chain::ChainState, Account as _, AccountPurpose, AccountSource, SeedRelevance,
-        TransactionDataRequest, TransactionStatus,
+        chain::ChainState, scanning::ScanPriority, Account as _, AccountPurpose, AccountSource,
+        SeedRelevance, TransactionDataRequest, TransactionStatus,
     },
     keys::{UnifiedAddressRequest, UnifiedFullViewingKey, UnifiedSpendingKey},
     wallet::{NoteId, WalletSpend, WalletTransparentOutput, WalletTx},
@@ -50,7 +50,7 @@ use {
 };
 
 use super::{Account, AccountId, MemoryWalletDb};
-use crate::error::Error;
+use crate::{error::Error, mem_wallet::MemoryWalletBlock};
 
 impl WalletRead for MemoryWalletDb {
     type Error = Error;
@@ -214,7 +214,13 @@
     }
 
     fn chain_height(&self) -> Result<Option<BlockHeight>, Self::Error> {
-        todo!()
+        Ok(self
+            .scan_queue
+            .iter()
+            .max_by(|(_, end_a, _), (_, end_b, _)| end_a.cmp(end_b))
+            // Scan ranges are end-exclusive, so we subtract 1 from the maximum range end
+            // to obtain the height of the last known chain tip.
+            .map(|(_, end, _)| end.saturating_sub(1)))
     }
 
     fn get_block_hash(&self, block_height: BlockHeight) -> Result<Option<BlockHash>, Self::Error> {
@@ -227,24 +233,94 @@
         }))
     }
 
-    fn block_metadata(&self, _height: BlockHeight) -> Result<Option<BlockMetadata>, Self::Error> {
-        todo!()
+    fn block_metadata(&self, height: BlockHeight) -> Result<Option<BlockMetadata>, Self::Error> {
+        Ok(self.blocks.get(&height).map(|block| {
+            let MemoryWalletBlock {
+                height,
+                hash,
+                sapling_commitment_tree_size,
+                #[cfg(feature = "orchard")]
+                orchard_commitment_tree_size,
+                ..
+            } = block;
+            // TODO: Deal with legacy sapling trees
+            BlockMetadata::from_parts(
+                *height,
+                *hash,
+                *sapling_commitment_tree_size,
+                #[cfg(feature = "orchard")]
+                *orchard_commitment_tree_size,
+            )
+        }))
     }
 
     fn block_fully_scanned(&self) -> Result<Option<BlockMetadata>, Self::Error> {
-        todo!()
+        if let Some(birthday_height) = self.get_wallet_birthday()? {
+            // We assume that the only way we get a contiguous range of block heights in the
+            // `blocks` table starting with the birthday block is if all scanning operations
+            // have been performed on those blocks. This holds because the `blocks` table is
+            // only altered by `WalletDb::put_blocks` via `put_block`, and the effective
+            // combination of intra-range linear scanning and the nullifier map ensures that
+            // we discover all wallet-related information within the contiguous range.
+            //
+            // We also assume that every contiguous range of block heights in the `blocks`
+            // table has a single matching entry in the scan queue with priority "Scanned".
+            // This requires that there are no bugs in the scan queue update logic (which we
+            // have had before); however, a bug here would mean that we return a more
+            // conservative fully-scanned height, which likely just causes a performance
+            // regression.
+            //
+            // The fully-scanned height is therefore the last height that falls within the
+            // first range in the scan queue with priority "Scanned".
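+            //
+            // Worked example (illustration): with a wallet birthday of 100 and scan-queue
+            // entries [100, 150) at `Scanned` and [150, 200) at `ChainTip`, the earliest
+            // `Scanned` range starts at the birthday, so the fully-scanned height is 149
+            // (scan ranges are end-exclusive). If that range instead started at 120,
+            // heights 100..120 would remain unscanned and this method would return
+            // `Ok(None)`.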
+
+            let mut scanned_ranges: Vec<_> = self
+                .scan_queue
+                .iter()
+                .filter(|(_, _, p)| p == &ScanPriority::Scanned)
+                .collect();
+            scanned_ranges.sort_by(|(start_a, _, _), (start_b, _, _)| start_a.cmp(start_b));
+            if let Some(fully_scanned_height) =
+                scanned_ranges
+                    .first()
+                    .and_then(|(block_range_start, block_range_end, _priority)| {
+                        // If the start of the earliest scanned range is greater than
+                        // the birthday height, then there is an unscanned range between
+                        // the wallet birthday and that range, so there is no fully
+                        // scanned height.
+                        if *block_range_start <= birthday_height {
+                            // Scan ranges are end-exclusive.
+                            Some(*block_range_end - 1)
+                        } else {
+                            None
+                        }
+                    })
+            {
+                self.block_metadata(fully_scanned_height)
+            } else {
+                Ok(None)
+            }
+        } else {
+            Ok(None)
+        }
     }
 
     fn get_max_height_hash(&self) -> Result<Option<(BlockHeight, BlockHash)>, Self::Error> {
-        todo!()
+        Ok(self
+            .blocks
+            .last_key_value()
+            .map(|(height, block)| (*height, block.hash)))
     }
 
     fn block_max_scanned(&self) -> Result<Option<BlockMetadata>, Self::Error> {
-        todo!()
+        Ok(self
+            .blocks
+            .last_key_value()
+            .map(|(height, _)| self.block_metadata(*height))
+            .transpose()?
+            .flatten())
     }
 
     fn suggest_scan_ranges(&self) -> Result<Vec<ScanRange>, Self::Error> {
-        Ok(vec![])
+        Ok(self.scan_queue.suggest_scan_ranges(ScanPriority::Historic))
     }
 
     fn get_target_and_anchor_heights(
diff --git a/zcash_client_memory/src/mem_wallet/wallet_write.rs b/zcash_client_memory/src/mem_wallet/wallet_write.rs
index ae195904d7..54fa0bb5e7 100644
--- a/zcash_client_memory/src/mem_wallet/wallet_write.rs
+++ b/zcash_client_memory/src/mem_wallet/wallet_write.rs
@@ -117,68 +117,77 @@ impl WalletWrite for MemoryWalletDb {
         // - Make sure blocks are coming in order.
         // - Make sure the first block in the sequence is tip + 1?
         // - Add a check to make sure the blocks are not already in the data store.
+        let start_height = blocks.first().map(|b| b.height());
+        let mut last_scanned_height = None;
+
         for block in blocks.into_iter() {
             let mut transactions = HashMap::new();
             let mut memos = HashMap::new();
+            if last_scanned_height
+                .iter()
+                .any(|prev| block.height() != *prev + 1)
+            {
+                return Err(Error::NonSequentialBlocks);
+            }
+
             for transaction in block.transactions().iter() {
                 let txid = transaction.txid();
 
                 // Mark the Sapling nullifiers of the spent notes as spent in the `sapling_spends` map.
-                transaction
-                    .sapling_spends()
-                    .iter()
-                    .map(|s| self.mark_sapling_note_spent(*s.nf(), txid));
+                for spend in transaction.sapling_spends() {
+                    self.mark_sapling_note_spent(*spend.nf(), txid)?;
+                }
 
-                #[cfg(feature = "orchard")]
                 // Mark the Orchard nullifiers of the spent notes as spent in the `orchard_spends` map.
-                transaction
-                    .orchard_spends()
-                    .iter()
-                    .map(|s| self.mark_orchard_note_spent(*s.nf(), txid));
+                #[cfg(feature = "orchard")]
+                for spend in transaction.orchard_spends() {
+                    self.mark_orchard_note_spent(*spend.nf(), txid)?;
+                }
 
-                transaction.sapling_outputs().iter().map(|o| {
+                for output in transaction.sapling_outputs() {
                     // Insert the memo into the `memos` map.
                     let note_id = NoteId::new(
                         txid,
                         Sapling,
-                        u16::try_from(o.index()).expect("output indices are representable as u16"),
+                        u16::try_from(output.index())
+                            .expect("output indices are representable as u16"),
                     );
                     if let Ok(Some(memo)) = self.get_memo(note_id) {
                         memos.insert(note_id, memo.encode());
                     }
                     // Check whether this note was spent in a later block range that
                     // we previously scanned.
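+                    // (Illustration of the lookup below: nullifier -> (block height, tx
+                    // index) via the nullifier map, then (height, index) -> TxId via the
+                    // tx locator, yielding the transaction that spent the note, if we
+                    // have seen it.)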
-                    let spent_in = o
+                    let spent_in = output
                         .nf()
                         .and_then(|nf| self.nullifiers.get(&Nullifier::Sapling(*nf)))
                         .and_then(|(height, tx_idx)| self.tx_locator.get(*height, *tx_idx))
                         .map(|x| *x);
-                    self.insert_received_sapling_note(note_id, &o, spent_in);
-                });
+                    self.insert_received_sapling_note(note_id, &output, spent_in);
+                }
 
                 #[cfg(feature = "orchard")]
-                transaction.orchard_outputs().iter().map(|o| {
+                for output in transaction.orchard_outputs().iter() {
                     // Insert the memo into the `memos` map.
                     let note_id = NoteId::new(
                         txid,
                         Orchard,
-                        u16::try_from(o.index()).expect("output indices are representable as u16"),
+                        u16::try_from(output.index())
+                            .expect("output indices are representable as u16"),
                     );
                     if let Ok(Some(memo)) = self.get_memo(note_id) {
                         memos.insert(note_id, memo.encode());
                     }
                     // Check whether this note was spent in a later block range that
                     // we previously scanned.
-                    let spent_in = o
+                    let spent_in = output
                         .nf()
-                        .and_then(|nf| self.nullifiers.get(&&Nullifier::Orchard(*nf)))
+                        .and_then(|nf| self.nullifiers.get(&Nullifier::Orchard(*nf)))
                         .and_then(|(height, tx_idx)| self.tx_locator.get(*height, *tx_idx))
                         .map(|x| *x);
-                    self.insert_received_orchard_note(note_id, &o, spent_in)
-                });
-
+                    self.insert_received_orchard_note(note_id, &output, spent_in);
+                }
                 // Add frontier to the sapling tree
                 self.sapling_tree.insert_frontier(
                     from_state.final_sapling_tree().clone(),
@@ -197,7 +206,7 @@
                         marking: Marking::Reference,
                     },
                 );
-
                 transactions.insert(txid, transaction.clone());
             }
+            last_scanned_height = Some(block.height());
 
@@ -212,12 +221,20 @@
             block_time: block.block_time(),
             transactions: transactions.keys().cloned().collect(),
             memos,
+            sapling_commitment_tree_size: Some(block.sapling().final_tree_size()),
+            sapling_output_count: Some(block.sapling().commitments().len().try_into().unwrap()),
+            #[cfg(feature = "orchard")]
+            orchard_commitment_tree_size: Some(block.orchard().final_tree_size()),
+            #[cfg(feature = "orchard")]
+            orchard_action_count: Some(block.orchard().commitments().len().try_into().unwrap()),
         };
 
+        // Insert transaction metadata into the transaction table.
         transactions
             .into_iter()
             .for_each(|(_id, tx)| self.tx_table.put_tx_meta(tx, block.height()));
 
+        // Insert the block into the block map.
         self.blocks.insert(block.height(), memory_block);
 
         // Add the Sapling commitments to the sapling tree.
@@ -240,6 +257,7 @@
                 .batch_insert(start_position, block_commitments.orchard.into_iter());
             }
         }
+        // We can do some pruning of the tx_locator_map here.
 
         Ok(())
     }
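
Note (reviewer sketch, not part of the diff): the sequencing rule that `put_blocks` enforces above, restated as a standalone helper for illustration; `check_sequential` is a hypothetical name, and the real check lives inline in the block loop. Iterating over a `None` option yields nothing, so the first block is exempt; after that, each block must extend its predecessor by exactly one height.

    use zcash_protocol::consensus::BlockHeight;

    use crate::error::Error;

    fn check_sequential(heights: &[BlockHeight]) -> Result<(), Error> {
        let mut last_scanned_height: Option<BlockHeight> = None;
        for height in heights {
            // Empty iterator on the first pass; afterwards, any gap
            // (height != prev + 1) triggers the error.
            if last_scanned_height.iter().any(|prev| *height != *prev + 1) {
                return Err(Error::NonSequentialBlocks);
            }
            last_scanned_height = Some(*height);
        }
        Ok(())
    }

For example, heights 100, 101, 102 pass, while 100, 102 returns `Error::NonSequentialBlocks`.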