Skip to content

Commit

Permalink
Merge pull request #6 from ChainSafe/ec2/fix-account-management
Browse files Browse the repository at this point in the history
Ec2/fix account management
  • Loading branch information
ec2 authored Aug 27, 2024
2 parents fd077c1 + 5ec0056 commit ce936d3
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 52 deletions.
3 changes: 1 addition & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,6 @@ zip32 = "0.1.1"
lto = true
panic = 'abort'
codegen-units = 1

[patch.crates-io]
zip32 = { git = "https://github.com/zcash/zip32.git", branch = "diversifier_index_ord"}
4 changes: 4 additions & 0 deletions zcash_client_memory/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ pub enum Error {
MemoDecryption(memo::Error),
#[error("Error deriving key: {0}")]
KeyDerivation(DerivationError),
#[error("Unknown ZIP32 derivation ")]
UnknownZip32Derivation,
#[error("Error generating address: {0}")]
AddressGeneration(Type),
#[error("Seed must be between 32 and 252 bytes in length.")]
Expand All @@ -32,6 +34,8 @@ pub enum Error {
IoError(std::io::Error),
#[error("Corrupted Data: {0}")]
CorruptedData(String),
#[error("An error occurred while processing an account due to a failure in deriving the account's keys: {0}")]
BadAccountData(String),
#[error("Other error: {0}")]
Other(String),
}
Expand Down
77 changes: 76 additions & 1 deletion zcash_client_memory/src/mem_wallet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,13 @@ impl MemoryWalletDb {
Ok(())
}

// fn get_account(&self, account_id: AccountId) -> Option<&Account> {
// self.accounts.get(*account_id as usize)
// }
fn get_account_mut(&mut self, account_id: AccountId) -> Option<&mut Account> {
self.accounts.get_mut(*account_id as usize)
}

#[cfg(feature = "orchard")]
fn mark_orchard_note_spent(
&mut self,
Expand Down Expand Up @@ -280,11 +287,41 @@ pub struct Account {
kind: AccountSource,
viewing_key: ViewingKey,
birthday: AccountBirthday,
purpose: AccountPurpose,
purpose: AccountPurpose, // TODO: Remove this. AccountSource should be sufficient.
addresses: BTreeMap<DiversifierIndex, UnifiedAddress>,
notes: HashSet<NoteId>,
}

impl Account {
fn new(
account_id: AccountId,
kind: AccountSource,
viewing_key: ViewingKey,
birthday: AccountBirthday,
purpose: AccountPurpose,
) -> Result<Self, Error> {
let mut acc = Self {
account_id,
kind,
viewing_key,
birthday,
purpose,
addresses: BTreeMap::new(),
notes: HashSet::new(),
};
let ua_request = acc
.viewing_key
.uivk()
.to_address_request()
.and_then(|ua_request| ua_request.intersect(&UnifiedAddressRequest::all().unwrap()))
.ok_or_else(|| {
Error::AddressGeneration(AddressGenerationError::ShieldedReceiverRequired)
})?;

let (addr, diversifier_index) = acc.default_address(ua_request)?;
acc.addresses.insert(diversifier_index, addr);
Ok(acc)
}
/// Returns the default Unified Address for the account,
/// along with the diversifier index that generated it.
///
Expand All @@ -300,6 +337,44 @@ impl Account {
fn birthday(&self) -> &AccountBirthday {
&self.birthday
}

fn addresses(&self) -> &BTreeMap<DiversifierIndex, UnifiedAddress> {
&self.addresses
}

fn current_address(&self) -> Option<(DiversifierIndex, UnifiedAddress)> {
self.addresses
.last_key_value()
.map(|(diversifier_index, address)| (*diversifier_index, address.clone()))
}
fn kind(&self) -> &AccountSource {
&self.kind
}
fn viewing_key(&self) -> &ViewingKey {
&self.viewing_key
}
fn next_available_address(
&mut self,
request: UnifiedAddressRequest,
) -> Result<Option<UnifiedAddress>, Error> {
match self.ufvk() {
Some(ufvk) => {
let search_from = match self.current_address() {
Some((mut last_diversifier_index, _)) => {
last_diversifier_index
.increment()
.map_err(|_| AddressGenerationError::DiversifierSpaceExhausted)?;
last_diversifier_index
}
None => DiversifierIndex::default(),
};
let (addr, diversifier_index) = ufvk.find_address(search_from, request)?;
self.addresses.insert(diversifier_index, addr.clone());
Ok(Some(addr))
}
None => Ok(None),
}
}
}

impl zcash_client_backend::data_api::Account<AccountId> for Account {
Expand Down
152 changes: 129 additions & 23 deletions zcash_client_memory/src/mem_wallet/wallet_read.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use incrementalmerkletree::{Address, Marking, Retention};
use nonempty::NonEmpty;
use sapling::NullifierDerivingKey;
use secrecy::{ExposeSecret, SecretVec};
use shardtree::{error::ShardTreeError, store::memory::MemoryShardStore, ShardTree};
use std::{
clone,
cmp::Ordering,
collections::{BTreeMap, HashMap, HashSet},
convert::Infallible,
Expand All @@ -28,7 +30,7 @@ use zcash_primitives::{
transaction::{Transaction, TransactionData, TxId},
};
use zcash_protocol::{
consensus::BranchId,
consensus::{self, BranchId},
memo::{self, Memo, MemoBytes},
value::Zatoshis,
ShieldedProtocol::{Orchard, Sapling},
Expand Down Expand Up @@ -56,7 +58,7 @@ impl WalletRead for MemoryWalletDb {
type Account = Account;

fn get_account_ids(&self) -> Result<Vec<Self::AccountId>, Self::Error> {
Ok(Vec::new())
Ok(self.accounts.iter().map(|a| a.id()).collect())
}

fn get_account(
Expand All @@ -68,25 +70,98 @@ impl WalletRead for MemoryWalletDb {

fn get_derived_account(
&self,
_seed: &SeedFingerprint,
_account_id: zip32::AccountId,
seed: &SeedFingerprint,
account_id: zip32::AccountId,
) -> Result<Option<Self::Account>, Self::Error> {
todo!()
Ok(self.accounts.iter().find_map(|acct| match acct.kind() {
AccountSource::Derived {
seed_fingerprint,
account_index,
} => {
if seed_fingerprint == seed && account_index == &account_id {
Some(acct.clone())
} else {
None
}
}
AccountSource::Imported { purpose } => None,
}))
}

fn validate_seed(
&self,
_account_id: Self::AccountId,
_seed: &SecretVec<u8>,
account_id: Self::AccountId,
seed: &SecretVec<u8>,
) -> Result<bool, Self::Error> {
todo!()
if let Some(account) = self.get_account(account_id)? {
if let AccountSource::Derived {
seed_fingerprint,
account_index,
} = account.source()
{
seed_matches_derived_account(
&self.network,
seed,
&seed_fingerprint,
account_index,
&account.uivk(),
)
} else {
Err(Error::UnknownZip32Derivation)
}
} else {
// Missing account is documented to return false.
Ok(false)
}
}

fn seed_relevance_to_derived_accounts(
&self,
seed: &SecretVec<u8>,
) -> Result<SeedRelevance<Self::AccountId>, Self::Error> {
todo!()
let mut has_accounts = false;
let mut has_derived = false;
let mut relevant_account_ids = vec![];

for account_id in self.get_account_ids()? {
has_accounts = true;
let account = self.get_account(account_id)?.expect("account ID exists");

// If the account is imported, the seed _might_ be relevant, but the only
// way we could determine that is by brute-forcing the ZIP 32 account
// index space, which we're not going to do. The method name indicates to
// the caller that we only check derived accounts.
if let AccountSource::Derived {
seed_fingerprint,
account_index,
} = account.source()
{
has_derived = true;

if seed_matches_derived_account(
&self.network,
seed,
&seed_fingerprint,
account_index,
&account.uivk(),
)? {
// The seed is relevant to this account.
relevant_account_ids.push(account_id);
}
}
}

Ok(
if let Some(account_ids) = NonEmpty::from_vec(relevant_account_ids) {
SeedRelevance::Relevant { account_ids }
} else if has_derived {
SeedRelevance::NotRelevant
} else if has_accounts {
SeedRelevance::NoDerivedAccounts
} else {
SeedRelevance::NoAccounts
},
)
}

fn get_account_for_ufvk(
Expand All @@ -110,20 +185,10 @@ impl WalletRead for MemoryWalletDb {
&self,
account: Self::AccountId,
) -> Result<Option<UnifiedAddress>, Self::Error> {
self.accounts
.get(*account as usize)
.map(|account| {
account
.ufvk()
.unwrap()
.default_address(
UnifiedAddressRequest::all()
.expect("At least one protocol should be enabled."),
)
.map(|(addr, _)| addr)
})
.transpose()
.map_err(|e| e.into())
Ok(self
.get_account(account)?
.and_then(|account| Account::current_address(&account))
.map(|(_, a)| a.clone()))
}

fn get_account_birthday(&self, account: Self::AccountId) -> Result<BlockHeight, Self::Error> {
Expand Down Expand Up @@ -350,3 +415,44 @@ impl WalletRead for MemoryWalletDb {
todo!()
}
}

/// Copied from zcash_client_sqlite::wallet::seed_matches_derived_account
fn seed_matches_derived_account<P: consensus::Parameters>(
params: &P,
seed: &SecretVec<u8>,
seed_fingerprint: &SeedFingerprint,
account_index: zip32::AccountId,
uivk: &UnifiedIncomingViewingKey,
) -> Result<bool, Error> {
let seed_fingerprint_match =
&SeedFingerprint::from_seed(seed.expose_secret()).ok_or_else(|| {
Error::BadAccountData("Seed must be between 32 and 252 bytes in length.".to_owned())
})? == seed_fingerprint;

// Keys are not comparable with `Eq`, but addresses are, so we derive what should
// be equivalent addresses for each key and use those to check for key equality.
let uivk_match =
match UnifiedSpendingKey::from_seed(params, &seed.expose_secret()[..], account_index) {
// If we can't derive a USK from the given seed with the account's ZIP 32
// account index, then we immediately know the UIVK won't match because wallet
// accounts are required to have a known UIVK.
Err(_) => false,
Ok(usk) => {
UnifiedAddressRequest::all().map_or(Ok::<_, Error>(false), |ua_request| {
Ok(usk
.to_unified_full_viewing_key()
.default_address(ua_request)?
== uivk.default_address(ua_request)?)
})?
}
};

if seed_fingerprint_match != uivk_match {
// If these mismatch, it suggests database corruption.
Err(Error::CorruptedData(format!(
"Seed fingerprint match: {seed_fingerprint_match}, uivk match: {uivk_match}"
)))
} else {
Ok(seed_fingerprint_match && uivk_match)
}
}
Loading

0 comments on commit ce936d3

Please sign in to comment.