From 700e742f2eff4b397eae72f6a0c036eb78aaedf6 Mon Sep 17 00:00:00 2001 From: Eric Tu Date: Tue, 27 Aug 2024 14:14:28 -0400 Subject: [PATCH 1/4] implemented get_next_available_address --- zcash_client_memory/src/mem_wallet/mod.rs | 72 ++++++++++++++++++- .../src/mem_wallet/wallet_read.rs | 19 ++--- .../src/mem_wallet/wallet_write.rs | 56 ++++++++------- 3 files changed, 106 insertions(+), 41 deletions(-) diff --git a/zcash_client_memory/src/mem_wallet/mod.rs b/zcash_client_memory/src/mem_wallet/mod.rs index 68f8d79e19..918e14b764 100644 --- a/zcash_client_memory/src/mem_wallet/mod.rs +++ b/zcash_client_memory/src/mem_wallet/mod.rs @@ -124,6 +124,13 @@ impl MemoryWalletDb { Ok(()) } + fn get_account(&self, account_id: AccountId) -> Option<&Account> { + self.accounts.get(*account_id as usize) + } + fn get_account_mut(&mut self, account_id: AccountId) -> Option<&mut Account> { + self.accounts.get_mut(*account_id as usize) + } + #[cfg(feature = "orchard")] fn mark_orchard_note_spent( &mut self, @@ -280,11 +287,41 @@ pub struct Account { kind: AccountSource, viewing_key: ViewingKey, birthday: AccountBirthday, - purpose: AccountPurpose, + purpose: AccountPurpose, // TODO: Remove this. AccountSource should be sufficient. + addresses: BTreeMap, notes: HashSet, } impl Account { + fn new( + account_id: AccountId, + kind: AccountSource, + viewing_key: ViewingKey, + birthday: AccountBirthday, + purpose: AccountPurpose, + ) -> Result { + let mut acc = Self { + account_id, + kind, + viewing_key, + birthday, + purpose, + addresses: BTreeMap::new(), + notes: HashSet::new(), + }; + let ua_request = acc + .viewing_key + .uivk() + .to_address_request() + .and_then(|ua_request| ua_request.intersect(&UnifiedAddressRequest::all().unwrap())) + .ok_or_else(|| { + Error::AddressGeneration(AddressGenerationError::ShieldedReceiverRequired) + })?; + + let (addr, diversifier_index) = acc.default_address(ua_request)?; + acc.addresses.insert(diversifier_index, addr); + Ok(acc) + } /// Returns the default Unified Address for the account, /// along with the diversifier index that generated it. /// @@ -300,6 +337,39 @@ impl Account { fn birthday(&self) -> &AccountBirthday { &self.birthday } + + fn addresses(&self) -> &BTreeMap { + &self.addresses + } + + fn current_address(&self) -> Option<(DiversifierIndex, UnifiedAddress)> { + self.addresses + .last_key_value() + .map(|(diversifier_index, address)| (*diversifier_index, address.clone())) + } + + fn next_available_address( + &mut self, + request: UnifiedAddressRequest, + ) -> Result, Error> { + match self.ufvk() { + Some(ufvk) => { + let search_from = match self.current_address() { + Some((mut last_diversifier_index, _)) => { + last_diversifier_index + .increment() + .map_err(|_| AddressGenerationError::DiversifierSpaceExhausted)?; + last_diversifier_index + } + None => DiversifierIndex::default(), + }; + let (addr, diversifier_index) = ufvk.find_address(search_from, request)?; + self.addresses.insert(diversifier_index, addr.clone()); + Ok(Some(addr)) + } + None => Ok(None), + } + } } impl zcash_client_backend::data_api::Account for Account { diff --git a/zcash_client_memory/src/mem_wallet/wallet_read.rs b/zcash_client_memory/src/mem_wallet/wallet_read.rs index 47e332b5dd..80a5c2dd67 100644 --- a/zcash_client_memory/src/mem_wallet/wallet_read.rs +++ b/zcash_client_memory/src/mem_wallet/wallet_read.rs @@ -3,6 +3,7 @@ use sapling::NullifierDerivingKey; use secrecy::{ExposeSecret, SecretVec}; use shardtree::{error::ShardTreeError, store::memory::MemoryShardStore, ShardTree}; use std::{ + clone, cmp::Ordering, collections::{BTreeMap, HashMap, HashSet}, convert::Infallible, @@ -110,20 +111,10 @@ impl WalletRead for MemoryWalletDb { &self, account: Self::AccountId, ) -> Result, Self::Error> { - self.accounts - .get(*account as usize) - .map(|account| { - account - .ufvk() - .unwrap() - .default_address( - UnifiedAddressRequest::all() - .expect("At least one protocol should be enabled."), - ) - .map(|(addr, _)| addr) - }) - .transpose() - .map_err(|e| e.into()) + Ok(self + .get_account(account) + .and_then(Account::current_address) + .map(|(_, a)| a.clone())) } fn get_account_birthday(&self, account: Self::AccountId) -> Result { diff --git a/zcash_client_memory/src/mem_wallet/wallet_write.rs b/zcash_client_memory/src/mem_wallet/wallet_write.rs index 1a01ccfc00..ae195904d7 100644 --- a/zcash_client_memory/src/mem_wallet/wallet_write.rs +++ b/zcash_client_memory/src/mem_wallet/wallet_write.rs @@ -72,17 +72,18 @@ impl WalletWrite for MemoryWalletDb { let usk = UnifiedSpendingKey::from_seed(&self.network, seed.expose_secret(), account_index)?; let ufvk = usk.to_unified_full_viewing_key(); - let account = Account { - account_id: AccountId(self.accounts.len() as u32), - kind: AccountSource::Derived { + + let account = Account::new( + AccountId(self.accounts.len() as u32), + AccountSource::Derived { seed_fingerprint, account_index, }, - viewing_key: ViewingKey::Full(Box::new(ufvk)), - birthday: birthday.clone(), - purpose: AccountPurpose::Spending, - notes: HashSet::new(), - }; + ViewingKey::Full(Box::new(ufvk)), + birthday.clone(), + AccountPurpose::Spending, + )?; + let id = account.id(); self.accounts.push(account); @@ -91,10 +92,13 @@ impl WalletWrite for MemoryWalletDb { fn get_next_available_address( &mut self, - _account: Self::AccountId, - _request: UnifiedAddressRequest, + account: Self::AccountId, + request: UnifiedAddressRequest, ) -> Result, Self::Error> { - todo!() + self.get_account_mut(account) + .map(|account| account.next_available_address(request)) + .transpose() + .map(|a| a.flatten()) } fn update_chain_tip(&mut self, _tip_height: BlockHeight) -> Result<(), Self::Error> { @@ -277,17 +281,17 @@ impl WalletWrite for MemoryWalletDb { .map_err(|_| "key derivation error".to_string()) .unwrap(); let ufvk = usk.to_unified_full_viewing_key(); - let account = Account { - account_id: AccountId(self.accounts.len() as u32), - kind: AccountSource::Derived { + + let account = Account::new( + AccountId(self.accounts.len() as u32), + AccountSource::Derived { seed_fingerprint, account_index, }, - viewing_key: ViewingKey::Full(Box::new(ufvk)), - birthday: birthday.clone(), - purpose: AccountPurpose::Spending, - notes: HashSet::new(), - }; + ViewingKey::Full(Box::new(ufvk)), + birthday.clone(), + AccountPurpose::Spending, + )?; // TODO: Do we need to check if duplicate? self.accounts.push(account.clone()); Ok((account, usk)) @@ -299,14 +303,14 @@ impl WalletWrite for MemoryWalletDb { birthday: &AccountBirthday, purpose: AccountPurpose, ) -> Result { - let account = Account { - account_id: AccountId(self.accounts.len() as u32), - kind: AccountSource::Imported { purpose }, - viewing_key: ViewingKey::Full(Box::new(unified_key.to_owned())), - birthday: birthday.clone(), + let account = Account::new( + AccountId(self.accounts.len() as u32), + AccountSource::Imported { purpose }, + ViewingKey::Full(Box::new(unified_key.to_owned())), + birthday.clone(), purpose, - notes: HashSet::new(), - }; + )?; + self.accounts.push(account.clone()); Ok(account) } From 8243ca7a2d9c53d97afb84db0194dfb42d9bb0af Mon Sep 17 00:00:00 2001 From: Eric Tu Date: Tue, 27 Aug 2024 14:17:24 -0400 Subject: [PATCH 2/4] cargo --- Cargo.lock | 3 +-- Cargo.toml | 3 +++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5cd0b68ad0..50d1c5cb6d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6158,8 +6158,7 @@ dependencies = [ [[package]] name = "zip32" version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4226d0aee9c9407c27064dfeec9d7b281c917de3374e1e5a2e2cfad9e09de19e" +source = "git+https://github.com/zcash/zip32.git?branch=diversifier_index_ord#38e39b7086bbd5747dc61a84faf54ec9a58fa535" dependencies = [ "blake2b_simd", "memuse", diff --git a/Cargo.toml b/Cargo.toml index b19975e41c..6386a17463 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -157,3 +157,6 @@ zip32 = "0.1.1" lto = true panic = 'abort' codegen-units = 1 + +[patch.crates-io] +zip32 = { git = "https://github.com/zcash/zip32.git", branch = "diversifier_index_ord"} \ No newline at end of file From 9cb0eac1389495ae8a5c0be0e4261d89472ae1f8 Mon Sep 17 00:00:00 2001 From: Eric Tu Date: Tue, 27 Aug 2024 14:31:28 -0400 Subject: [PATCH 3/4] filled in more account related methods --- zcash_client_memory/src/mem_wallet/mod.rs | 7 ++++++- .../src/mem_wallet/wallet_read.rs | 20 +++++++++++++++---- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/zcash_client_memory/src/mem_wallet/mod.rs b/zcash_client_memory/src/mem_wallet/mod.rs index 918e14b764..594f869fc8 100644 --- a/zcash_client_memory/src/mem_wallet/mod.rs +++ b/zcash_client_memory/src/mem_wallet/mod.rs @@ -347,7 +347,12 @@ impl Account { .last_key_value() .map(|(diversifier_index, address)| (*diversifier_index, address.clone())) } - + fn kind(&self) -> &AccountSource { + &self.kind + } + fn viewing_key(&self) -> &ViewingKey { + &self.viewing_key + } fn next_available_address( &mut self, request: UnifiedAddressRequest, diff --git a/zcash_client_memory/src/mem_wallet/wallet_read.rs b/zcash_client_memory/src/mem_wallet/wallet_read.rs index 80a5c2dd67..8338edbf6c 100644 --- a/zcash_client_memory/src/mem_wallet/wallet_read.rs +++ b/zcash_client_memory/src/mem_wallet/wallet_read.rs @@ -57,7 +57,7 @@ impl WalletRead for MemoryWalletDb { type Account = Account; fn get_account_ids(&self) -> Result, Self::Error> { - Ok(Vec::new()) + Ok(self.accounts.iter().map(|a| a.id()).collect()) } fn get_account( @@ -69,10 +69,22 @@ impl WalletRead for MemoryWalletDb { fn get_derived_account( &self, - _seed: &SeedFingerprint, - _account_id: zip32::AccountId, + seed: &SeedFingerprint, + account_id: zip32::AccountId, ) -> Result, Self::Error> { - todo!() + Ok(self.accounts.iter().find_map(|acct| match acct.kind() { + AccountSource::Derived { + seed_fingerprint, + account_index, + } => { + if seed_fingerprint == seed && account_index == &account_id { + Some(acct.clone()) + } else { + None + } + } + AccountSource::Imported { purpose } => None, + })) } fn validate_seed( From 5cb9a1d618047c9d6d712803b417fd6b001a3750 Mon Sep 17 00:00:00 2001 From: Eric Tu Date: Tue, 27 Aug 2024 14:42:13 -0400 Subject: [PATCH 4/4] more account method impls --- zcash_client_memory/src/error.rs | 4 + zcash_client_memory/src/mem_wallet/mod.rs | 6 +- .../src/mem_wallet/wallet_read.rs | 117 ++++++++++++++++-- 3 files changed, 117 insertions(+), 10 deletions(-) diff --git a/zcash_client_memory/src/error.rs b/zcash_client_memory/src/error.rs index c65910b116..8548404a56 100644 --- a/zcash_client_memory/src/error.rs +++ b/zcash_client_memory/src/error.rs @@ -16,6 +16,8 @@ pub enum Error { MemoDecryption(memo::Error), #[error("Error deriving key: {0}")] KeyDerivation(DerivationError), + #[error("Unknown ZIP32 derivation ")] + UnknownZip32Derivation, #[error("Error generating address: {0}")] AddressGeneration(Type), #[error("Seed must be between 32 and 252 bytes in length.")] @@ -32,6 +34,8 @@ pub enum Error { IoError(std::io::Error), #[error("Corrupted Data: {0}")] CorruptedData(String), + #[error("An error occurred while processing an account due to a failure in deriving the account's keys: {0}")] + BadAccountData(String), #[error("Other error: {0}")] Other(String), } diff --git a/zcash_client_memory/src/mem_wallet/mod.rs b/zcash_client_memory/src/mem_wallet/mod.rs index 594f869fc8..20b8f44dc5 100644 --- a/zcash_client_memory/src/mem_wallet/mod.rs +++ b/zcash_client_memory/src/mem_wallet/mod.rs @@ -124,9 +124,9 @@ impl MemoryWalletDb { Ok(()) } - fn get_account(&self, account_id: AccountId) -> Option<&Account> { - self.accounts.get(*account_id as usize) - } + // fn get_account(&self, account_id: AccountId) -> Option<&Account> { + // self.accounts.get(*account_id as usize) + // } fn get_account_mut(&mut self, account_id: AccountId) -> Option<&mut Account> { self.accounts.get_mut(*account_id as usize) } diff --git a/zcash_client_memory/src/mem_wallet/wallet_read.rs b/zcash_client_memory/src/mem_wallet/wallet_read.rs index 8338edbf6c..8b82987696 100644 --- a/zcash_client_memory/src/mem_wallet/wallet_read.rs +++ b/zcash_client_memory/src/mem_wallet/wallet_read.rs @@ -1,4 +1,5 @@ use incrementalmerkletree::{Address, Marking, Retention}; +use nonempty::NonEmpty; use sapling::NullifierDerivingKey; use secrecy::{ExposeSecret, SecretVec}; use shardtree::{error::ShardTreeError, store::memory::MemoryShardStore, ShardTree}; @@ -29,7 +30,7 @@ use zcash_primitives::{ transaction::{Transaction, TransactionData, TxId}, }; use zcash_protocol::{ - consensus::BranchId, + consensus::{self, BranchId}, memo::{self, Memo, MemoBytes}, value::Zatoshis, ShieldedProtocol::{Orchard, Sapling}, @@ -89,17 +90,78 @@ impl WalletRead for MemoryWalletDb { fn validate_seed( &self, - _account_id: Self::AccountId, - _seed: &SecretVec, + account_id: Self::AccountId, + seed: &SecretVec, ) -> Result { - todo!() + if let Some(account) = self.get_account(account_id)? { + if let AccountSource::Derived { + seed_fingerprint, + account_index, + } = account.source() + { + seed_matches_derived_account( + &self.network, + seed, + &seed_fingerprint, + account_index, + &account.uivk(), + ) + } else { + Err(Error::UnknownZip32Derivation) + } + } else { + // Missing account is documented to return false. + Ok(false) + } } fn seed_relevance_to_derived_accounts( &self, seed: &SecretVec, ) -> Result, Self::Error> { - todo!() + let mut has_accounts = false; + let mut has_derived = false; + let mut relevant_account_ids = vec![]; + + for account_id in self.get_account_ids()? { + has_accounts = true; + let account = self.get_account(account_id)?.expect("account ID exists"); + + // If the account is imported, the seed _might_ be relevant, but the only + // way we could determine that is by brute-forcing the ZIP 32 account + // index space, which we're not going to do. The method name indicates to + // the caller that we only check derived accounts. + if let AccountSource::Derived { + seed_fingerprint, + account_index, + } = account.source() + { + has_derived = true; + + if seed_matches_derived_account( + &self.network, + seed, + &seed_fingerprint, + account_index, + &account.uivk(), + )? { + // The seed is relevant to this account. + relevant_account_ids.push(account_id); + } + } + } + + Ok( + if let Some(account_ids) = NonEmpty::from_vec(relevant_account_ids) { + SeedRelevance::Relevant { account_ids } + } else if has_derived { + SeedRelevance::NotRelevant + } else if has_accounts { + SeedRelevance::NoDerivedAccounts + } else { + SeedRelevance::NoAccounts + }, + ) } fn get_account_for_ufvk( @@ -124,8 +186,8 @@ impl WalletRead for MemoryWalletDb { account: Self::AccountId, ) -> Result, Self::Error> { Ok(self - .get_account(account) - .and_then(Account::current_address) + .get_account(account)? + .and_then(|account| Account::current_address(&account)) .map(|(_, a)| a.clone())) } @@ -353,3 +415,44 @@ impl WalletRead for MemoryWalletDb { todo!() } } + +/// Copied from zcash_client_sqlite::wallet::seed_matches_derived_account +fn seed_matches_derived_account( + params: &P, + seed: &SecretVec, + seed_fingerprint: &SeedFingerprint, + account_index: zip32::AccountId, + uivk: &UnifiedIncomingViewingKey, +) -> Result { + let seed_fingerprint_match = + &SeedFingerprint::from_seed(seed.expose_secret()).ok_or_else(|| { + Error::BadAccountData("Seed must be between 32 and 252 bytes in length.".to_owned()) + })? == seed_fingerprint; + + // Keys are not comparable with `Eq`, but addresses are, so we derive what should + // be equivalent addresses for each key and use those to check for key equality. + let uivk_match = + match UnifiedSpendingKey::from_seed(params, &seed.expose_secret()[..], account_index) { + // If we can't derive a USK from the given seed with the account's ZIP 32 + // account index, then we immediately know the UIVK won't match because wallet + // accounts are required to have a known UIVK. + Err(_) => false, + Ok(usk) => { + UnifiedAddressRequest::all().map_or(Ok::<_, Error>(false), |ua_request| { + Ok(usk + .to_unified_full_viewing_key() + .default_address(ua_request)? + == uivk.default_address(ua_request)?) + })? + } + }; + + if seed_fingerprint_match != uivk_match { + // If these mismatch, it suggests database corruption. + Err(Error::CorruptedData(format!( + "Seed fingerprint match: {seed_fingerprint_match}, uivk match: {uivk_match}" + ))) + } else { + Ok(seed_fingerprint_match && uivk_match) + } +}