Skip to content

Commit

Permalink
type safe tree and frontier serde
Browse files Browse the repository at this point in the history
  • Loading branch information
ec2 committed Sep 9, 2024
1 parent fa812b0 commit 9dd40df
Show file tree
Hide file tree
Showing 4 changed files with 383 additions and 393 deletions.
104 changes: 61 additions & 43 deletions zcash_client_memory/src/types/serialization/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::fmt::Display;
use std::io;

use std::marker::PhantomData;
use std::sync::Arc;

use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -53,57 +55,73 @@ pub trait ToFromBytes {
Self: Sized;
}

impl<T: ToFromBytes> ToFromBytes for Arc<T> {
fn to_bytes(&self) -> Vec<u8> {
self.as_ref().to_bytes()
pub trait ToArray<T, const N: usize> {
fn to_arr(&self) -> [T; N];
}
impl<T: ToArray<U, N>, U, const N: usize> ToArray<U, N> for Arc<T> {
fn to_arr(&self) -> [U; N] {
self.as_ref().to_arr()
}

fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
T::from_bytes(bytes).map(Arc::new)
}
impl<T: TryFromArray<U, N>, U, const N: usize> TryFromArray<U, N> for Arc<T> {
type Error = T::Error;
fn from_arr(arr: [U; N]) -> Result<Self, Self::Error> {
Ok(Arc::new(T::from_arr(arr)?))
}
}

#[serde_as]
pub struct ToFromBytesWrapper<T: ToFromBytes>(T);
pub trait TryFromArray<T, const N: usize>
where
Self: Sized,
{
type Error: Display;
fn from_arr(arr: [T; N]) -> Result<Self, Self::Error>;
}
pub use bytes::*;

impl<T: ToFromBytes> SerializeAs<T> for ToFromBytesWrapper<T> {
fn serialize_as<S>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
value.to_bytes().serialize(serializer)
mod bytes {
use super::*;
pub struct ToFromBytesWrapper<T: ToFromBytes>(T);

impl<T: ToFromBytes> SerializeAs<T> for ToFromBytesWrapper<T> {
fn serialize_as<S>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
value.to_bytes().serialize(serializer)
}
}
}
impl<T: ToFromBytes> SerializeAs<&T> for ToFromBytesWrapper<T> {
fn serialize_as<S>(value: &&T, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
value.to_bytes().serialize(serializer)
impl<T: ToFromBytes> SerializeAs<&T> for ToFromBytesWrapper<T> {
fn serialize_as<S>(value: &&T, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
value.to_bytes().serialize(serializer)
}
}
}
impl<'de, T: ToFromBytes> DeserializeAs<'de, T> for ToFromBytesWrapper<T> {
fn deserialize_as<D>(deserializer: D) -> Result<T, D::Error>
where
D: serde::Deserializer<'de>,
{
T::from_bytes(<Vec<u8>>::deserialize(deserializer)?.as_slice())
.map_err(serde::de::Error::custom)
impl<'de, T: ToFromBytes> DeserializeAs<'de, T> for ToFromBytesWrapper<T> {
fn deserialize_as<D>(deserializer: D) -> Result<T, D::Error>
where
D: serde::Deserializer<'de>,
{
T::from_bytes(<Vec<u8>>::deserialize(deserializer)?.as_slice())
.map_err(serde::de::Error::custom)
}
}
}
impl<T: ToFromBytes> Serialize for ToFromBytesWrapper<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
ToFromBytesWrapper::<T>::serialize_as(&self.0, serializer)
impl<T: ToFromBytes> Serialize for ToFromBytesWrapper<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
ToFromBytesWrapper::<T>::serialize_as(&self.0, serializer)
}
}
}
impl<'de, T: ToFromBytes> Deserialize<'de> for ToFromBytesWrapper<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
ToFromBytesWrapper::<T>::deserialize_as(deserializer).map(ToFromBytesWrapper)
impl<'de, T: ToFromBytes> Deserialize<'de> for ToFromBytesWrapper<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
ToFromBytesWrapper::<T>::deserialize_as(deserializer).map(ToFromBytesWrapper)
}
}
}
221 changes: 85 additions & 136 deletions zcash_client_memory/src/types/serialization/shardtree/frontier.rs
Original file line number Diff line number Diff line change
@@ -1,174 +1,123 @@
use std::fmt::Display;

use incrementalmerkletree::frontier::{self, Frontier, NonEmptyFrontier};
use incrementalmerkletree::Position;
use serde::ser::SerializeStruct;

use serde_with::{de::DeserializeAs, de::DeserializeAsWrap, ser::SerializeAsWrap};
use serde_with::{FromInto, SerializeAs};
use serde::{Deserialize, Serialize};
use serde_with::SerializeAs;
use serde_with::{de::DeserializeAs, serde_as};

use crate::{ToFromBytes, ToFromBytesWrapper};
use crate::{ToArray, TryFromArray};

pub struct FrontierWrapper;
impl<T: ToFromBytes + Clone, const DEPTH: u8> SerializeAs<Frontier<T, DEPTH>> for FrontierWrapper {
fn serialize_as<S>(value: &Frontier<T, DEPTH>, serializer: S) -> Result<S::Ok, S::Error>

impl<H: ToArray<u8, 32> + Clone, const DEPTH: u8> SerializeAs<Frontier<H, DEPTH>>
for FrontierWrapper
{
fn serialize_as<S>(value: &Frontier<H, DEPTH>, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut s = serializer.serialize_struct("Frontier", 1)?;
s.serialize_field(
"frontier",
&SerializeAsWrap::<_, Option<NonEmptyFrontierWrapper>>::new(&value.value().cloned()),
)?;
s.end()
#[serde_as]
#[derive(Serialize)]
struct FrontierSer<'a, H: ToArray<u8, 32>> {
#[serde_as(as = "Option<&'a NonEmptyFrontierWrapper>")]
frontier: &'a Option<&'a NonEmptyFrontier<H>>,
}

FrontierSer {
frontier: &value.value(),
}
.serialize(serializer)
}
}
impl<'de, T: ToFromBytes + Clone, const DEPTH: u8> DeserializeAs<'de, Frontier<T, DEPTH>>
for FrontierWrapper
impl<'de, H: TryFromArray<u8, 32, Error = E>, E: Display, const DEPTH: u8>
DeserializeAs<'de, Frontier<H, DEPTH>> for FrontierWrapper
{
fn deserialize_as<D>(deserializer: D) -> Result<Frontier<T, DEPTH>, D::Error>
fn deserialize_as<D>(deserializer: D) -> Result<Frontier<H, DEPTH>, D::Error>
where
D: serde::Deserializer<'de>,
{
struct Visitor<T, const DEPTH: u8>(std::marker::PhantomData<T>);
impl<T, const DEPTH: u8> Visitor<T, DEPTH> {
fn new() -> Self {
Self(std::marker::PhantomData)
}
}
impl<'de, T: ToFromBytes + Clone, const DEPTH: u8> serde::de::Visitor<'de> for Visitor<T, DEPTH> {
type Value = Frontier<T, DEPTH>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("struct Frontier")
}
fn visit_map<A>(self, mut map: A) -> Result<Frontier<T, DEPTH>, A::Error>
where
A: serde::de::MapAccess<'de>,
{
let mut frontier = None;
while let Some(key) = map.next_key()? {
match key {
"frontier" => {
frontier = map
.next_value::<Option<
DeserializeAsWrap<NonEmptyFrontier<T>, NonEmptyFrontierWrapper>,
>>()?
.map(|f| f.into_inner());
}
_ => {
return Err(serde::de::Error::unknown_field(key, &["frontier"]));
}
}
}
frontier
.map(NonEmptyFrontier::into_parts)
.map(|(p, l, o)| {
frontier::Frontier::from_parts(p, l, o).map_err(|_e| {
serde::de::Error::custom("failed to construct frontier from parts")
})
})
.transpose()?
.ok_or_else(|| serde::de::Error::missing_field("frontier"))
}
#[derive(Deserialize)]
struct FrontierDe {
frontier: Option<NonEmptyFrontierDe>,
}
deserializer.deserialize_struct("Frontier", &["frontier"], Visitor::<T, DEPTH>::new())
let frontier = FrontierDe::deserialize(deserializer)?;
frontier
.frontier
.map(|f| {
let p = Position::from(f.position);
let l = H::from_arr(f.leaf).map_err(serde::de::Error::custom)?;
let o = f
.ommers
.into_iter()
.map(|o| H::from_arr(o).map_err(serde::de::Error::custom))
.collect::<Result<Vec<_>, _>>()?;
frontier::Frontier::from_parts(p, l, o).map_err(|_e| {
serde::de::Error::custom("failed to construct frontier from parts")
})
})
.transpose()?
.ok_or_else(|| serde::de::Error::missing_field("frontier"))
}
}

pub struct NonEmptyFrontierWrapper;

impl<T: ToFromBytes> SerializeAs<NonEmptyFrontier<T>> for NonEmptyFrontierWrapper {
impl<T> SerializeAs<NonEmptyFrontier<T>> for NonEmptyFrontierWrapper
where
T: ToArray<u8, 32>,
{
fn serialize_as<S>(value: &NonEmptyFrontier<T>, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let ommers = value
#[derive(Serialize)]
struct NonEmptyFrontierSer<'a> {
pub position: u64,
pub leaf: &'a [u8; 32],
pub ommers: &'a [[u8; 32]],
}

let ommer = value
.ommers()
.iter()
.map(|o| SerializeAsWrap::<_, ToFromBytesWrapper<T>>::new(o))
.map(|o| o.to_arr())
.collect::<Vec<_>>();
let mut s = serializer.serialize_struct("NonEmptyFrontier", 3)?;
s.serialize_field(
"position",
&SerializeAsWrap::<_, FromInto<u64>>::new(&value.position()),
)?;
s.serialize_field(
"leaf",
&SerializeAsWrap::<_, ToFromBytesWrapper<T>>::new(&value.leaf()),
)?;
s.serialize_field("ommers", &ommers)?;
s.end()

let x = NonEmptyFrontierSer {
position: value.position().into(),
leaf: &value.leaf().to_arr(),
ommers: ommer.as_slice(),
};

x.serialize(serializer)
}
}
#[derive(Deserialize)]
struct NonEmptyFrontierDe {
pub position: u64,
pub leaf: [u8; 32],
pub ommers: Vec<[u8; 32]>,
}

impl<'de, T: ToFromBytes> DeserializeAs<'de, NonEmptyFrontier<T>> for NonEmptyFrontierWrapper {
impl<'de, T: TryFromArray<u8, 32, Error = E>, E: Display> DeserializeAs<'de, NonEmptyFrontier<T>>
for NonEmptyFrontierWrapper
{
fn deserialize_as<D>(deserializer: D) -> Result<NonEmptyFrontier<T>, D::Error>
where
D: serde::Deserializer<'de>,
{
struct Visitor<T>(std::marker::PhantomData<T>);
impl<T> Visitor<T> {
fn new() -> Self {
Self(std::marker::PhantomData)
}
}
impl<'de, T: ToFromBytes> serde::de::Visitor<'de> for Visitor<T> {
type Value = NonEmptyFrontier<T>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("struct OrchardNote")
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
let mut position = None;
let mut leaf = None;
let mut ommers = None;
while let Some(key) = map.next_key()? {
match key {
"position" => {
position = Some(
map.next_value::<DeserializeAsWrap<Position, FromInto<u64>>>()?,
);
}
"leaf" => {
leaf = Some(
map.next_value::<DeserializeAsWrap<T, ToFromBytesWrapper<T>>>()?,
);
}
"ommers" => {
ommers = Some(
map.next_value::<Vec<DeserializeAsWrap<T, ToFromBytesWrapper<T>>>>(
)?,
);
}
_ => {
return Err(serde::de::Error::unknown_field(
key,
&["recipient", "value", "rho", "rseed"],
));
}
}
}
let position = position
.ok_or_else(|| serde::de::Error::missing_field("position"))?
.into_inner();
let leaf = leaf
.ok_or_else(|| serde::de::Error::missing_field("leaf"))?
.into_inner();
let ommers = ommers
.ok_or_else(|| serde::de::Error::missing_field("ommers"))?
.into_iter()
.map(|o| o.into_inner())
.collect();

NonEmptyFrontier::from_parts(position, leaf, ommers).map_err(|_e| {
serde::de::Error::custom("Failed to deserialize non-empty frontier")
})
}
}
deserializer.deserialize_struct(
"NonEmptyFrontier",
&["position", "leaf", "ommers"],
Visitor::<T>::new(),
let frontier = NonEmptyFrontierDe::deserialize(deserializer)?;
NonEmptyFrontier::from_parts(
frontier.position.into(),
T::from_arr(frontier.leaf).map_err(serde::de::Error::custom)?,
frontier
.ommers
.into_iter()
.map(|o| T::from_arr(o).map_err(serde::de::Error::custom))
.collect::<Result<Vec<_>, _>>()?,
)
.map_err(|_| serde::de::Error::custom("Failed to construct frontier from parts"))
}
}
Loading

0 comments on commit 9dd40df

Please sign in to comment.