Skip to content

Commit

Permalink
fix(backend): use custom Deserialize impl for postcard to handle `#[s…
Browse files Browse the repository at this point in the history
…erde(skip)]`
  • Loading branch information
Eason0729 committed Jul 23, 2024
1 parent 6846505 commit 4d98294
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 17 deletions.
3 changes: 3 additions & 0 deletions backend/src/controller/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use serde::{de::DeserializeOwned, Serialize};
use crate::config::CONFIG;
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use blake2::{Blake2b512, Digest};
use tracing::instrument;

type Result<T> = std::result::Result<T, Error>;

Expand Down Expand Up @@ -60,6 +61,7 @@ impl CryptoController {
/// Serialize and calculate checksum and return
///
/// Note that it shouldn't be an security measurement
#[instrument(skip_all, level = "debug", ret(level = "debug"))]
pub fn encode<M: Serialize>(&self, obj: M) -> Result<String> {
let mut raw = postcard::to_allocvec(&obj)?;

Expand All @@ -75,6 +77,7 @@ impl CryptoController {
/// check signature and return the object
///
/// Error if signature invaild
#[instrument(skip_all, level = "debug", err(level = "debug"))]
pub fn decode<M: DeserializeOwned>(&self, raw: String) -> Result<M> {
let mut raw = URL_SAFE_NO_PAD.decode(raw)?;

Expand Down
19 changes: 8 additions & 11 deletions backend/src/controller/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,16 @@ impl TokenController {
base64::Engine::decode(&base64::engine::general_purpose::STANDARD_NO_PAD, token)?;
let rand: Rand = rand.try_into().map_err(|_| Error::InvalidTokenLength)?;

let cache_result = {
match self.cache.get(&rand) {
Some(cc) => {
if cc.expiry < now {
self.cache.remove(&rand);
None
} else {
Some(cc.clone())
}
let cache_result = match self.cache.get(&rand) {
Some(cc) => {
if cc.expiry < now {
self.cache.remove(&rand);
None
} else {
Some(cc.clone())
}
None => None,
}
None => None,
};

let token = match cache_result {
Expand All @@ -149,7 +147,6 @@ impl TokenController {
token
}
None => {
// FIXME: this is cold branch!
let token: CachedToken = (token::Entity::find()
.filter(token::Column::Rand.eq(rand.to_vec()))
.one(self.db.deref())
Expand Down
24 changes: 20 additions & 4 deletions backend/src/entity/util/paginator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
use super::helper::*;
use crate::util::auth::Auth;
use sea_orm::*;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize};
use tonic::async_trait;
use tracing::*;

Expand Down Expand Up @@ -360,17 +360,33 @@ impl<S: SortSource<R>, R: Reflect<S::Entity>> PaginateRaw for ColumnPaginator<S,
}
}

#[derive(Serialize, Deserialize, Default)]
#[derive(Serialize, Default)]
pub enum UninitPaginator<P: PaginateRaw> {
#[serde(skip_deserializing, skip_serializing)]
#[serde(skip)]
Uninit(<P::Source as PagerData>::Data, bool),
#[serde(bound(deserialize = "P: for<'a> Deserialize<'a>"))]
Init(P),
#[serde(skip_deserializing, skip_serializing)]
#[serde(skip)]
#[default]
None,
}

impl<'de, P: PaginateRaw> Deserialize<'de> for UninitPaginator<P> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let p: Option<P> = Deserialize::deserialize(deserializer)?;

match p {
Some(p) => Ok(UninitPaginator::Init(p)),
None => Err(serde::de::Error::custom(
"Unexpected data format for UninitPaginator",
)),
}
}
}

impl<P: PaginateRaw> UninitPaginator<P> {
pub fn new(data: <P::Source as PagerData>::Data, start_from_end: bool) -> Self {
Self::Uninit(data, start_from_end)
Expand Down
8 changes: 6 additions & 2 deletions backend/src/util/rate_limit.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
use std::num::NonZeroU32;

use super::auth::Auth;
use crate::{
controller::rate_limit::{Bucket, TrafficType},
server::Server,
};
use grpc::backend::{Id, *};
use tracing::*;

use super::auth::Auth;
use tracing_futures::Instrument;

impl Server {
/// parse authentication without rate limiting
///
/// It's useful for endpoints that require resolving identity
/// before rate limiting, such as logout
#[instrument(skip_all, level = "info")]
pub async fn parse_auth<T>(
&self,
req: &tonic::Request<T>,
Expand All @@ -25,6 +26,7 @@ impl Server {
.check(req, |req| async {
if let Some(x) = req.metadata().get("token") {
let token = x.to_str().unwrap();
tracing::debug!(token = token);

match self.token.verify(token).in_current_span().await {
Ok(user) => {
Expand All @@ -42,6 +44,7 @@ impl Server {
TrafficType::Guest
}
})
.in_current_span()
.await?;
tracing::info!(auth = %auth);
Ok((auth, bucket))
Expand Down Expand Up @@ -89,6 +92,7 @@ impl Server {
let (auth, bucket) = self.parse_auth(&req).in_current_span().await?;
bucket.cost(NonZeroU32::new(3).unwrap())?;
let req = req.into_inner();
tracing::debug!(bucket = %bucket);

if let Some(cost) = NonZeroU32::new(req.get_cost()) {
bucket.cost(cost)?;
Expand Down

0 comments on commit 4d98294

Please sign in to comment.