diff --git a/Cargo.lock b/Cargo.lock index 50c729b871..5117c99e68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6157,8 +6157,7 @@ dependencies = [ [[package]] name = "zip32" version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4226d0aee9c9407c27064dfeec9d7b281c917de3374e1e5a2e2cfad9e09de19e" +source = "git+https://github.com/zcash/zip32.git?branch=diversifier_index_ord#38e39b7086bbd5747dc61a84faf54ec9a58fa535" dependencies = [ "blake2b_simd", "memuse", diff --git a/Cargo.toml b/Cargo.toml index 22de11d26c..5c957f738c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -157,3 +157,6 @@ zip32 = "0.1.1" lto = true panic = 'abort' codegen-units = 1 + +[patch.crates-io] +zip32 = { git = "https://github.com/zcash/zip32.git", branch = "diversifier_index_ord"} \ No newline at end of file diff --git a/zcash_client_memory/src/error.rs b/zcash_client_memory/src/error.rs index c65910b116..8548404a56 100644 --- a/zcash_client_memory/src/error.rs +++ b/zcash_client_memory/src/error.rs @@ -16,6 +16,8 @@ pub enum Error { MemoDecryption(memo::Error), #[error("Error deriving key: {0}")] KeyDerivation(DerivationError), + #[error("Unknown ZIP32 derivation ")] + UnknownZip32Derivation, #[error("Error generating address: {0}")] AddressGeneration(Type), #[error("Seed must be between 32 and 252 bytes in length.")] @@ -32,6 +34,8 @@ pub enum Error { IoError(std::io::Error), #[error("Corrupted Data: {0}")] CorruptedData(String), + #[error("An error occurred while processing an account due to a failure in deriving the account's keys: {0}")] + BadAccountData(String), #[error("Other error: {0}")] Other(String), } diff --git a/zcash_client_memory/src/mem_wallet/mod.rs b/zcash_client_memory/src/mem_wallet/mod.rs index 68f8d79e19..20b8f44dc5 100644 --- a/zcash_client_memory/src/mem_wallet/mod.rs +++ b/zcash_client_memory/src/mem_wallet/mod.rs @@ -124,6 +124,13 @@ impl MemoryWalletDb { Ok(()) } + // fn get_account(&self, account_id: AccountId) -> Option<&Account> { + // self.accounts.get(*account_id as usize) + // } + fn get_account_mut(&mut self, account_id: AccountId) -> Option<&mut Account> { + self.accounts.get_mut(*account_id as usize) + } + #[cfg(feature = "orchard")] fn mark_orchard_note_spent( &mut self, @@ -280,11 +287,41 @@ pub struct Account { kind: AccountSource, viewing_key: ViewingKey, birthday: AccountBirthday, - purpose: AccountPurpose, + purpose: AccountPurpose, // TODO: Remove this. AccountSource should be sufficient. + addresses: BTreeMap, notes: HashSet, } impl Account { + fn new( + account_id: AccountId, + kind: AccountSource, + viewing_key: ViewingKey, + birthday: AccountBirthday, + purpose: AccountPurpose, + ) -> Result { + let mut acc = Self { + account_id, + kind, + viewing_key, + birthday, + purpose, + addresses: BTreeMap::new(), + notes: HashSet::new(), + }; + let ua_request = acc + .viewing_key + .uivk() + .to_address_request() + .and_then(|ua_request| ua_request.intersect(&UnifiedAddressRequest::all().unwrap())) + .ok_or_else(|| { + Error::AddressGeneration(AddressGenerationError::ShieldedReceiverRequired) + })?; + + let (addr, diversifier_index) = acc.default_address(ua_request)?; + acc.addresses.insert(diversifier_index, addr); + Ok(acc) + } /// Returns the default Unified Address for the account, /// along with the diversifier index that generated it. /// @@ -300,6 +337,44 @@ impl Account { fn birthday(&self) -> &AccountBirthday { &self.birthday } + + fn addresses(&self) -> &BTreeMap { + &self.addresses + } + + fn current_address(&self) -> Option<(DiversifierIndex, UnifiedAddress)> { + self.addresses + .last_key_value() + .map(|(diversifier_index, address)| (*diversifier_index, address.clone())) + } + fn kind(&self) -> &AccountSource { + &self.kind + } + fn viewing_key(&self) -> &ViewingKey { + &self.viewing_key + } + fn next_available_address( + &mut self, + request: UnifiedAddressRequest, + ) -> Result, Error> { + match self.ufvk() { + Some(ufvk) => { + let search_from = match self.current_address() { + Some((mut last_diversifier_index, _)) => { + last_diversifier_index + .increment() + .map_err(|_| AddressGenerationError::DiversifierSpaceExhausted)?; + last_diversifier_index + } + None => DiversifierIndex::default(), + }; + let (addr, diversifier_index) = ufvk.find_address(search_from, request)?; + self.addresses.insert(diversifier_index, addr.clone()); + Ok(Some(addr)) + } + None => Ok(None), + } + } } impl zcash_client_backend::data_api::Account for Account { diff --git a/zcash_client_memory/src/mem_wallet/wallet_read.rs b/zcash_client_memory/src/mem_wallet/wallet_read.rs index 47e332b5dd..8b82987696 100644 --- a/zcash_client_memory/src/mem_wallet/wallet_read.rs +++ b/zcash_client_memory/src/mem_wallet/wallet_read.rs @@ -1,8 +1,10 @@ use incrementalmerkletree::{Address, Marking, Retention}; +use nonempty::NonEmpty; use sapling::NullifierDerivingKey; use secrecy::{ExposeSecret, SecretVec}; use shardtree::{error::ShardTreeError, store::memory::MemoryShardStore, ShardTree}; use std::{ + clone, cmp::Ordering, collections::{BTreeMap, HashMap, HashSet}, convert::Infallible, @@ -28,7 +30,7 @@ use zcash_primitives::{ transaction::{Transaction, TransactionData, TxId}, }; use zcash_protocol::{ - consensus::BranchId, + consensus::{self, BranchId}, memo::{self, Memo, MemoBytes}, value::Zatoshis, ShieldedProtocol::{Orchard, Sapling}, @@ -56,7 +58,7 @@ impl WalletRead for MemoryWalletDb { type Account = Account; fn get_account_ids(&self) -> Result, Self::Error> { - Ok(Vec::new()) + Ok(self.accounts.iter().map(|a| a.id()).collect()) } fn get_account( @@ -68,25 +70,98 @@ impl WalletRead for MemoryWalletDb { fn get_derived_account( &self, - _seed: &SeedFingerprint, - _account_id: zip32::AccountId, + seed: &SeedFingerprint, + account_id: zip32::AccountId, ) -> Result, Self::Error> { - todo!() + Ok(self.accounts.iter().find_map(|acct| match acct.kind() { + AccountSource::Derived { + seed_fingerprint, + account_index, + } => { + if seed_fingerprint == seed && account_index == &account_id { + Some(acct.clone()) + } else { + None + } + } + AccountSource::Imported { purpose } => None, + })) } fn validate_seed( &self, - _account_id: Self::AccountId, - _seed: &SecretVec, + account_id: Self::AccountId, + seed: &SecretVec, ) -> Result { - todo!() + if let Some(account) = self.get_account(account_id)? { + if let AccountSource::Derived { + seed_fingerprint, + account_index, + } = account.source() + { + seed_matches_derived_account( + &self.network, + seed, + &seed_fingerprint, + account_index, + &account.uivk(), + ) + } else { + Err(Error::UnknownZip32Derivation) + } + } else { + // Missing account is documented to return false. + Ok(false) + } } fn seed_relevance_to_derived_accounts( &self, seed: &SecretVec, ) -> Result, Self::Error> { - todo!() + let mut has_accounts = false; + let mut has_derived = false; + let mut relevant_account_ids = vec![]; + + for account_id in self.get_account_ids()? { + has_accounts = true; + let account = self.get_account(account_id)?.expect("account ID exists"); + + // If the account is imported, the seed _might_ be relevant, but the only + // way we could determine that is by brute-forcing the ZIP 32 account + // index space, which we're not going to do. The method name indicates to + // the caller that we only check derived accounts. + if let AccountSource::Derived { + seed_fingerprint, + account_index, + } = account.source() + { + has_derived = true; + + if seed_matches_derived_account( + &self.network, + seed, + &seed_fingerprint, + account_index, + &account.uivk(), + )? { + // The seed is relevant to this account. + relevant_account_ids.push(account_id); + } + } + } + + Ok( + if let Some(account_ids) = NonEmpty::from_vec(relevant_account_ids) { + SeedRelevance::Relevant { account_ids } + } else if has_derived { + SeedRelevance::NotRelevant + } else if has_accounts { + SeedRelevance::NoDerivedAccounts + } else { + SeedRelevance::NoAccounts + }, + ) } fn get_account_for_ufvk( @@ -110,20 +185,10 @@ impl WalletRead for MemoryWalletDb { &self, account: Self::AccountId, ) -> Result, Self::Error> { - self.accounts - .get(*account as usize) - .map(|account| { - account - .ufvk() - .unwrap() - .default_address( - UnifiedAddressRequest::all() - .expect("At least one protocol should be enabled."), - ) - .map(|(addr, _)| addr) - }) - .transpose() - .map_err(|e| e.into()) + Ok(self + .get_account(account)? + .and_then(|account| Account::current_address(&account)) + .map(|(_, a)| a.clone())) } fn get_account_birthday(&self, account: Self::AccountId) -> Result { @@ -350,3 +415,44 @@ impl WalletRead for MemoryWalletDb { todo!() } } + +/// Copied from zcash_client_sqlite::wallet::seed_matches_derived_account +fn seed_matches_derived_account( + params: &P, + seed: &SecretVec, + seed_fingerprint: &SeedFingerprint, + account_index: zip32::AccountId, + uivk: &UnifiedIncomingViewingKey, +) -> Result { + let seed_fingerprint_match = + &SeedFingerprint::from_seed(seed.expose_secret()).ok_or_else(|| { + Error::BadAccountData("Seed must be between 32 and 252 bytes in length.".to_owned()) + })? == seed_fingerprint; + + // Keys are not comparable with `Eq`, but addresses are, so we derive what should + // be equivalent addresses for each key and use those to check for key equality. + let uivk_match = + match UnifiedSpendingKey::from_seed(params, &seed.expose_secret()[..], account_index) { + // If we can't derive a USK from the given seed with the account's ZIP 32 + // account index, then we immediately know the UIVK won't match because wallet + // accounts are required to have a known UIVK. + Err(_) => false, + Ok(usk) => { + UnifiedAddressRequest::all().map_or(Ok::<_, Error>(false), |ua_request| { + Ok(usk + .to_unified_full_viewing_key() + .default_address(ua_request)? + == uivk.default_address(ua_request)?) + })? + } + }; + + if seed_fingerprint_match != uivk_match { + // If these mismatch, it suggests database corruption. + Err(Error::CorruptedData(format!( + "Seed fingerprint match: {seed_fingerprint_match}, uivk match: {uivk_match}" + ))) + } else { + Ok(seed_fingerprint_match && uivk_match) + } +} diff --git a/zcash_client_memory/src/mem_wallet/wallet_write.rs b/zcash_client_memory/src/mem_wallet/wallet_write.rs index 1a01ccfc00..ae195904d7 100644 --- a/zcash_client_memory/src/mem_wallet/wallet_write.rs +++ b/zcash_client_memory/src/mem_wallet/wallet_write.rs @@ -72,17 +72,18 @@ impl WalletWrite for MemoryWalletDb { let usk = UnifiedSpendingKey::from_seed(&self.network, seed.expose_secret(), account_index)?; let ufvk = usk.to_unified_full_viewing_key(); - let account = Account { - account_id: AccountId(self.accounts.len() as u32), - kind: AccountSource::Derived { + + let account = Account::new( + AccountId(self.accounts.len() as u32), + AccountSource::Derived { seed_fingerprint, account_index, }, - viewing_key: ViewingKey::Full(Box::new(ufvk)), - birthday: birthday.clone(), - purpose: AccountPurpose::Spending, - notes: HashSet::new(), - }; + ViewingKey::Full(Box::new(ufvk)), + birthday.clone(), + AccountPurpose::Spending, + )?; + let id = account.id(); self.accounts.push(account); @@ -91,10 +92,13 @@ impl WalletWrite for MemoryWalletDb { fn get_next_available_address( &mut self, - _account: Self::AccountId, - _request: UnifiedAddressRequest, + account: Self::AccountId, + request: UnifiedAddressRequest, ) -> Result, Self::Error> { - todo!() + self.get_account_mut(account) + .map(|account| account.next_available_address(request)) + .transpose() + .map(|a| a.flatten()) } fn update_chain_tip(&mut self, _tip_height: BlockHeight) -> Result<(), Self::Error> { @@ -277,17 +281,17 @@ impl WalletWrite for MemoryWalletDb { .map_err(|_| "key derivation error".to_string()) .unwrap(); let ufvk = usk.to_unified_full_viewing_key(); - let account = Account { - account_id: AccountId(self.accounts.len() as u32), - kind: AccountSource::Derived { + + let account = Account::new( + AccountId(self.accounts.len() as u32), + AccountSource::Derived { seed_fingerprint, account_index, }, - viewing_key: ViewingKey::Full(Box::new(ufvk)), - birthday: birthday.clone(), - purpose: AccountPurpose::Spending, - notes: HashSet::new(), - }; + ViewingKey::Full(Box::new(ufvk)), + birthday.clone(), + AccountPurpose::Spending, + )?; // TODO: Do we need to check if duplicate? self.accounts.push(account.clone()); Ok((account, usk)) @@ -299,14 +303,14 @@ impl WalletWrite for MemoryWalletDb { birthday: &AccountBirthday, purpose: AccountPurpose, ) -> Result { - let account = Account { - account_id: AccountId(self.accounts.len() as u32), - kind: AccountSource::Imported { purpose }, - viewing_key: ViewingKey::Full(Box::new(unified_key.to_owned())), - birthday: birthday.clone(), + let account = Account::new( + AccountId(self.accounts.len() as u32), + AccountSource::Imported { purpose }, + ViewingKey::Full(Box::new(unified_key.to_owned())), + birthday.clone(), purpose, - notes: HashSet::new(), - }; + )?; + self.accounts.push(account.clone()); Ok(account) }