From 2db017af37dfedd49b854fb431ef629b8157eefc Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 30 Aug 2024 00:26:23 +0000 Subject: [PATCH 001/245] simplify service trait bounds and lifetimes Signed-off-by: Jason Volk --- src/service/service.rs | 23 +++++++++-------------- src/service/services.rs | 8 ++++---- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/src/service/service.rs b/src/service/service.rs index 635f782ea..065f78a00 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -51,7 +51,7 @@ pub(crate) struct Args<'a> { /// Dep is a reference to a service used within another service. /// Circular-dependencies between services require this indirection. -pub(crate) struct Dep { +pub(crate) struct Dep { dep: OnceLock>, service: Weak, name: &'static str, @@ -62,7 +62,7 @@ pub(crate) type MapType = BTreeMap; pub(crate) type MapVal = (Weak, Weak); pub(crate) type MapKey = String; -impl Deref for Dep { +impl Deref for Dep { type Target = Arc; /// Dereference a dependency. The dependency must be ready or panics. @@ -80,7 +80,7 @@ impl Deref for Dep { impl<'a> Args<'a> { /// Create a lazy-reference to a service when constructing another Service. - pub(crate) fn depend(&'a self, name: &'static str) -> Dep { + pub(crate) fn depend(&'a self, name: &'static str) -> Dep { Dep:: { dep: OnceLock::new(), service: Arc::downgrade(self.service), @@ -90,17 +90,12 @@ impl<'a> Args<'a> { /// Create a reference immediately to a service when constructing another /// Service. The other service must be constructed. - pub(crate) fn require(&'a self, name: &'static str) -> Arc { - require::(self.service, name) - } + pub(crate) fn require(&'a self, name: &str) -> Arc { require::(self.service, name) } } /// Reference a Service by name. Panics if the Service does not exist or was /// incorrectly cast. 
-pub(crate) fn require<'a, 'b, T>(map: &'b Map, name: &'a str) -> Arc -where - T: Send + Sync + 'a + 'b + 'static, -{ +pub(crate) fn require(map: &Map, name: &str) -> Arc { try_get::(map, name) .inspect_err(inspect_log) .expect("Failure to reference service required by another service.") @@ -112,9 +107,9 @@ where /// # Panics /// Incorrect type is not a silent failure (None) as the type never has a reason /// to be incorrect. -pub(crate) fn get<'a, 'b, T>(map: &'b Map, name: &'a str) -> Option> +pub(crate) fn get(map: &Map, name: &str) -> Option> where - T: Send + Sync + 'a + 'b + 'static, + T: Any + Send + Sync + Sized, { map.read() .expect("locked for reading") @@ -129,9 +124,9 @@ where /// Reference a Service by name. Returns Err if the Service does not exist or /// was incorrectly cast. -pub(crate) fn try_get<'a, 'b, T>(map: &'b Map, name: &'a str) -> Result> +pub(crate) fn try_get(map: &Map, name: &str) -> Result> where - T: Send + Sync + 'a + 'b + 'static, + T: Any + Send + Sync + Sized, { map.read() .expect("locked for reading") diff --git a/src/service/services.rs b/src/service/services.rs index 8e69cdbb6..3aa095b85 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -193,16 +193,16 @@ impl Services { } } - pub fn try_get<'a, 'b, T>(&'b self, name: &'a str) -> Result> + pub fn try_get(&self, name: &str) -> Result> where - T: Send + Sync + 'a + 'b + 'static, + T: Any + Send + Sync + Sized, { service::try_get::(&self.service, name) } - pub fn get<'a, 'b, T>(&'b self, name: &'a str) -> Option> + pub fn get(&self, name: &str) -> Option> where - T: Send + Sync + 'a + 'b + 'static, + T: Any + Send + Sync + Sized, { service::get::(&self.service, name) } From 99ad404ea9f72b3a4d7aabb55a17127d82f39d12 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 31 Aug 2024 02:13:23 +0000 Subject: [PATCH 002/245] add str traits for split, between, unquote; consolidate tests Signed-off-by: Jason Volk --- src/core/utils/string.rs | 13 +++--- 
src/core/utils/string/between.rs | 26 ++++++++++++ src/core/utils/string/split.rs | 22 ++++++++++ src/core/utils/string/tests.rs | 70 +++++++++++++++++++++++++++++++ src/core/utils/string/unquote.rs | 33 +++++++++++++++ src/core/utils/string/unquoted.rs | 52 +++++++++++++++++++++++ src/core/utils/tests.rs | 43 ------------------- src/service/service.rs | 4 +- 8 files changed, 212 insertions(+), 51 deletions(-) create mode 100644 src/core/utils/string/between.rs create mode 100644 src/core/utils/string/split.rs create mode 100644 src/core/utils/string/tests.rs create mode 100644 src/core/utils/string/unquote.rs create mode 100644 src/core/utils/string/unquoted.rs diff --git a/src/core/utils/string.rs b/src/core/utils/string.rs index 85282b30a..e65a33698 100644 --- a/src/core/utils/string.rs +++ b/src/core/utils/string.rs @@ -1,3 +1,10 @@ +mod between; +mod split; +mod tests; +mod unquote; +mod unquoted; + +pub use self::{between::Between, split::SplitInfallible, unquote::Unquote, unquoted::Unquoted}; use crate::{utils::exchange, Result}; pub const EMPTY: &str = ""; @@ -95,12 +102,6 @@ pub fn common_prefix<'a>(choice: &'a [&str]) -> &'a str { }) } -#[inline] -#[must_use] -pub fn split_once_infallible<'a>(input: &'a str, delim: &'_ str) -> (&'a str, &'a str) { - input.split_once(delim).unwrap_or((input, EMPTY)) -} - /// Parses the bytes into a string. pub fn string_from_bytes(bytes: &[u8]) -> Result { let str: &str = str_from_bytes(bytes)?; diff --git a/src/core/utils/string/between.rs b/src/core/utils/string/between.rs new file mode 100644 index 000000000..209a9dabb --- /dev/null +++ b/src/core/utils/string/between.rs @@ -0,0 +1,26 @@ +type Delim<'a> = (&'a str, &'a str); + +/// Slice a string between a pair of delimeters. +pub trait Between<'a> { + /// Extract a string between the delimeters. If the delimeters were not + /// found None is returned, otherwise the first extraction is returned. 
+	fn between(&self, delim: Delim<'_>) -> Option<&'a str>;
+
+	/// Extract a string between the delimiters. If the delimiters were not
+	/// found the original string is returned; take note of this behavior,
+	/// if an empty slice is desired for this case use the fallible version and
+	/// unwrap to EMPTY.
+	fn between_infallible(&self, delim: Delim<'_>) -> &'a str;
+}
+
+impl<'a> Between<'a> for &'a str {
+	#[inline]
+	fn between_infallible(&self, delim: Delim<'_>) -> &'a str { self.between(delim).unwrap_or(self) }
+
+	#[inline]
+	fn between(&self, delim: Delim<'_>) -> Option<&'a str> {
+		self.split_once(delim.0)
+			.and_then(|(_, b)| b.rsplit_once(delim.1))
+			.map(|(a, _)| a)
+	}
+}
diff --git a/src/core/utils/string/split.rs b/src/core/utils/string/split.rs
new file mode 100644
index 000000000..96de28dff
--- /dev/null
+++ b/src/core/utils/string/split.rs
@@ -0,0 +1,22 @@
+use super::EMPTY;
+
+type Pair<'a> = (&'a str, &'a str);
+
+/// Split a string with default behaviors on non-match.
+pub trait SplitInfallible<'a> {
+	/// Split a string at the first occurrence of delim. If not found, the
+	/// entire string is returned in \[0\], while \[1\] is empty.
+	fn split_once_infallible(&self, delim: &str) -> Pair<'a>;
+
+	/// Split a string from the last occurrence of delim. If not found, the
+	/// entire string is returned in \[0\], while \[1\] is empty.
+ fn rsplit_once_infallible(&self, delim: &str) -> Pair<'a>; +} + +impl<'a> SplitInfallible<'a> for &'a str { + #[inline] + fn rsplit_once_infallible(&self, delim: &str) -> Pair<'a> { self.rsplit_once(delim).unwrap_or((self, EMPTY)) } + + #[inline] + fn split_once_infallible(&self, delim: &str) -> Pair<'a> { self.split_once(delim).unwrap_or((self, EMPTY)) } +} diff --git a/src/core/utils/string/tests.rs b/src/core/utils/string/tests.rs new file mode 100644 index 000000000..e8c17de6d --- /dev/null +++ b/src/core/utils/string/tests.rs @@ -0,0 +1,70 @@ +#![cfg(test)] + +#[test] +fn common_prefix() { + let input = ["conduwuit", "conduit", "construct"]; + let output = super::common_prefix(&input); + assert_eq!(output, "con"); +} + +#[test] +fn common_prefix_empty() { + let input = ["abcdefg", "hijklmn", "opqrstu"]; + let output = super::common_prefix(&input); + assert_eq!(output, ""); +} + +#[test] +fn common_prefix_none() { + let input = []; + let output = super::common_prefix(&input); + assert_eq!(output, ""); +} + +#[test] +fn camel_to_snake_case_0() { + let res = super::camel_to_snake_string("CamelToSnakeCase"); + assert_eq!(res, "camel_to_snake_case"); +} + +#[test] +fn camel_to_snake_case_1() { + let res = super::camel_to_snake_string("CAmelTOSnakeCase"); + assert_eq!(res, "camel_tosnake_case"); +} + +#[test] +fn unquote() { + use super::Unquote; + + assert_eq!("\"foo\"".unquote(), Some("foo")); + assert_eq!("\"foo".unquote(), None); + assert_eq!("foo".unquote(), None); +} + +#[test] +fn unquote_infallible() { + use super::Unquote; + + assert_eq!("\"foo\"".unquote_infallible(), "foo"); + assert_eq!("\"foo".unquote_infallible(), "\"foo"); + assert_eq!("foo".unquote_infallible(), "foo"); +} + +#[test] +fn between() { + use super::Between; + + assert_eq!("\"foo\"".between(("\"", "\"")), Some("foo")); + assert_eq!("\"foo".between(("\"", "\"")), None); + assert_eq!("foo".between(("\"", "\"")), None); +} + +#[test] +fn between_infallible() { + use super::Between; + + 
assert_eq!("\"foo\"".between_infallible(("\"", "\"")), "foo");
+	assert_eq!("\"foo".between_infallible(("\"", "\"")), "\"foo");
+	assert_eq!("foo".between_infallible(("\"", "\"")), "foo");
+}
diff --git a/src/core/utils/string/unquote.rs b/src/core/utils/string/unquote.rs
new file mode 100644
index 000000000..eeded610a
--- /dev/null
+++ b/src/core/utils/string/unquote.rs
@@ -0,0 +1,33 @@
+const QUOTE: char = '"';
+
+/// Slice a string between quotes
+pub trait Unquote<'a> {
+	/// Whether the input is quoted. If this is false the fallible methods of
+	/// this interface will fail.
+	fn is_quoted(&self) -> bool;
+
+	/// Unquotes a string. The input must be quoted on each side for Some to be
+	/// returned.
+	fn unquote(&self) -> Option<&'a str>;
+
+	/// Unquotes a string. If the input is not quoted it is simply returned
+	/// as-is. If the input is partially quoted on either end that quote is not
+	/// removed.
+	fn unquote_infallible(&self) -> &'a str;
+}
+
+impl<'a> Unquote<'a> for &'a str {
+	#[inline]
+	fn unquote_infallible(&self) -> &'a str {
+		self.strip_prefix(QUOTE)
+			.unwrap_or(self)
+			.strip_suffix(QUOTE)
+			.unwrap_or(self)
+	}
+
+	#[inline]
+	fn unquote(&self) -> Option<&'a str> { self.strip_prefix(QUOTE).and_then(|s| s.strip_suffix(QUOTE)) }
+
+	#[inline]
+	fn is_quoted(&self) -> bool { self.starts_with(QUOTE) && self.ends_with(QUOTE) }
+}
diff --git a/src/core/utils/string/unquoted.rs b/src/core/utils/string/unquoted.rs
new file mode 100644
index 000000000..5b002d99b
--- /dev/null
+++ b/src/core/utils/string/unquoted.rs
@@ -0,0 +1,52 @@
+use std::ops::Deref;
+
+use serde::{de, Deserialize, Deserializer};
+
+use super::Unquote;
+use crate::{err, Result};
+
+/// Unquoted string which is deserialized from a quoted string. Construction
+/// from a &str is infallible such that the input can already be unquoted.
+/// Construction from serde deserialization is fallible and the input must be quoted.
+#[repr(transparent)] +pub struct Unquoted(str); + +impl<'a> Unquoted { + #[inline] + #[must_use] + pub fn as_str(&'a self) -> &'a str { &self.0 } +} + +impl<'a, 'de: 'a> Deserialize<'de> for &'a Unquoted { + fn deserialize>(deserializer: D) -> Result { + let s = <&'a str>::deserialize(deserializer)?; + s.is_quoted() + .then_some(s) + .ok_or(err!(SerdeDe("expected quoted string"))) + .map_err(de::Error::custom) + .map(Into::into) + } +} + +impl<'a> From<&'a str> for &'a Unquoted { + fn from(s: &'a str) -> &'a Unquoted { + let s: &'a str = s.unquote_infallible(); + + //SAFETY: This is a pattern I lifted from ruma-identifiers for strong-type strs + // by wrapping in a tuple-struct. + #[allow(clippy::transmute_ptr_to_ptr)] + unsafe { + std::mem::transmute(s) + } + } +} + +impl Deref for Unquoted { + type Target = str; + + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl<'a> AsRef for &'a Unquoted { + fn as_ref(&self) -> &'a str { &self.0 } +} diff --git a/src/core/utils/tests.rs b/src/core/utils/tests.rs index e91accdf4..5880470a3 100644 --- a/src/core/utils/tests.rs +++ b/src/core/utils/tests.rs @@ -36,33 +36,6 @@ fn increment_wrap() { assert_eq!(res, 0); } -#[test] -fn common_prefix() { - use utils::string; - - let input = ["conduwuit", "conduit", "construct"]; - let output = string::common_prefix(&input); - assert_eq!(output, "con"); -} - -#[test] -fn common_prefix_empty() { - use utils::string; - - let input = ["abcdefg", "hijklmn", "opqrstu"]; - let output = string::common_prefix(&input); - assert_eq!(output, ""); -} - -#[test] -fn common_prefix_none() { - use utils::string; - - let input = []; - let output = string::common_prefix(&input); - assert_eq!(output, ""); -} - #[test] fn checked_add() { use crate::checked; @@ -134,19 +107,3 @@ async fn mutex_map_contend() { tokio::try_join!(join_b, join_a).expect("joined"); assert!(map.is_empty(), "Must be empty"); } - -#[test] -fn camel_to_snake_case_0() { - use utils::string::camel_to_snake_string; - - let res 
= camel_to_snake_string("CamelToSnakeCase"); - assert_eq!(res, "camel_to_snake_case"); -} - -#[test] -fn camel_to_snake_case_1() { - use utils::string::camel_to_snake_string; - - let res = camel_to_snake_string("CAmelTOSnakeCase"); - assert_eq!(res, "camel_tosnake_case"); -} diff --git a/src/service/service.rs b/src/service/service.rs index 065f78a00..031650506 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -7,7 +7,7 @@ use std::{ }; use async_trait::async_trait; -use conduit::{err, error::inspect_log, utils::string::split_once_infallible, Err, Result, Server}; +use conduit::{err, error::inspect_log, utils::string::SplitInfallible, Err, Result, Server}; use database::Database; /// Abstract interface for a Service @@ -147,4 +147,4 @@ where /// Utility for service implementations; see Service::name() in the trait. #[inline] -pub(crate) fn make_name(module_path: &str) -> &str { split_once_infallible(module_path, "::").1 } +pub(crate) fn make_name(module_path: &str) -> &str { module_path.split_once_infallible("::").1 } From 2709995f84cfa9dcaab14ac9ee856aafe06b22c0 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 1 Sep 2024 01:53:22 +0000 Subject: [PATCH 003/245] add MapExpect to Result add DebugInspect to Result move Result typedef into unit Signed-off-by: Jason Volk --- src/core/debug.rs | 6 ++-- src/core/mod.rs | 4 +-- src/core/result.rs | 6 ++++ src/core/result/debug_inspect.rs | 52 ++++++++++++++++++++++++++++++++ src/core/result/map_expect.rs | 15 +++++++++ 5 files changed, 78 insertions(+), 5 deletions(-) create mode 100644 src/core/result.rs create mode 100644 src/core/result/debug_inspect.rs create mode 100644 src/core/result/map_expect.rs diff --git a/src/core/debug.rs b/src/core/debug.rs index 844445d53..1e36ca8e2 100644 --- a/src/core/debug.rs +++ b/src/core/debug.rs @@ -1,10 +1,10 @@ use std::{any::Any, panic}; -/// Export debug proc_macros +// Export debug proc_macros pub use conduit_macros::recursion_depth; -/// Export all of the 
ancillary tools from here as well. -pub use crate::utils::debug::*; +// Export all of the ancillary tools from here as well. +pub use crate::{result::DebugInspect, utils::debug::*}; /// Log event at given level in debug-mode (when debug-assertions are enabled). /// In release-mode it becomes DEBUG level, and possibly subject to elision. diff --git a/src/core/mod.rs b/src/core/mod.rs index 9898243bf..31851f4f0 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -7,6 +7,7 @@ pub mod log; pub mod metrics; pub mod mods; pub mod pdu; +pub mod result; pub mod server; pub mod utils; @@ -15,13 +16,12 @@ pub use config::Config; pub use error::Error; pub use info::{rustc_flags_capture, version, version::version}; pub use pdu::{PduBuilder, PduCount, PduEvent}; +pub use result::Result; pub use server::Server; pub use utils::{ctor, dtor, implement}; pub use crate as conduit_core; -pub type Result = std::result::Result; - rustc_flags_capture! {} #[cfg(not(conduit_mods))] diff --git a/src/core/result.rs b/src/core/result.rs new file mode 100644 index 000000000..d58467cf5 --- /dev/null +++ b/src/core/result.rs @@ -0,0 +1,6 @@ +mod debug_inspect; +mod map_expect; + +pub use self::{debug_inspect::DebugInspect, map_expect::MapExpect}; + +pub type Result = std::result::Result; diff --git a/src/core/result/debug_inspect.rs b/src/core/result/debug_inspect.rs new file mode 100644 index 000000000..ef80979d8 --- /dev/null +++ b/src/core/result/debug_inspect.rs @@ -0,0 +1,52 @@ +use super::Result; + +/// Inspect Result values with release-mode elision. +pub trait DebugInspect { + /// Inspects an Err contained value in debug-mode. In release-mode closure F + /// is elided. + #[must_use] + fn debug_inspect_err(self, f: F) -> Self; + + /// Inspects an Ok contained value in debug-mode. In release-mode closure F + /// is elided. 
+ #[must_use] + fn debug_inspect(self, f: F) -> Self; +} + +#[cfg(debug_assertions)] +impl DebugInspect for Result { + #[inline] + fn debug_inspect(self, f: F) -> Self + where + F: FnOnce(&T), + { + self.inspect(f) + } + + #[inline] + fn debug_inspect_err(self, f: F) -> Self + where + F: FnOnce(&E), + { + self.inspect_err(f) + } +} + +#[cfg(not(debug_assertions))] +impl DebugInspect for Result { + #[inline] + fn debug_inspect(self, _: F) -> Self + where + F: FnOnce(&T), + { + self + } + + #[inline] + fn debug_inspect_err(self, _: F) -> Self + where + F: FnOnce(&E), + { + self + } +} diff --git a/src/core/result/map_expect.rs b/src/core/result/map_expect.rs new file mode 100644 index 000000000..8ce9195fe --- /dev/null +++ b/src/core/result/map_expect.rs @@ -0,0 +1,15 @@ +use std::fmt::Debug; + +use super::Result; + +pub trait MapExpect { + /// Calls expect(msg) on the mapped Result value. This is similar to + /// map(Result::unwrap) but composes an expect call and message without + /// requiring a closure. + fn map_expect(self, msg: &str) -> Option; +} + +impl MapExpect for Option> { + #[inline] + fn map_expect(self, msg: &str) -> Option { self.map(|result| result.expect(msg)) } +} From 3d4b0f10a59008a5cc785fd299937cdbee0dacd1 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 7 Sep 2024 22:04:28 +0000 Subject: [PATCH 004/245] add expected! macro to checked math expression suite Signed-off-by: Jason Volk --- src/core/utils/math.rs | 36 +++++++++++++++++++++-------- src/service/rooms/auth_chain/mod.rs | 2 +- src/service/rooms/timeline/mod.rs | 2 +- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/src/core/utils/math.rs b/src/core/utils/math.rs index f9d0de302..8c4b01bed 100644 --- a/src/core/utils/math.rs +++ b/src/core/utils/math.rs @@ -7,32 +7,50 @@ use crate::{debug::type_name, err, Err, Error, Result}; /// Checked arithmetic expression. Returns a Result #[macro_export] macro_rules! 
checked { - ($($input:tt)*) => { - $crate::utils::math::checked_ops!($($input)*) + ($($input:tt)+) => { + $crate::utils::math::checked_ops!($($input)+) .ok_or_else(|| $crate::err!(Arithmetic("operation overflowed or result invalid"))) - } + }; } -/// in release-mode. Use for performance when the expression is obviously safe. -/// The check remains in debug-mode for regression analysis. +/// Checked arithmetic expression which panics on failure. This is for +/// expressions which do not meet the threshold for validated! but the caller +/// has no realistic expectation for error and no interest in cluttering the +/// callsite with result handling from checked!. +#[macro_export] +macro_rules! expected { + ($msg:literal, $($input:tt)+) => { + $crate::checked!($($input)+).expect($msg) + }; + + ($($input:tt)+) => { + $crate::expected!("arithmetic expression expectation failure", $($input)+) + }; +} + +/// Unchecked arithmetic expression in release-mode. Use for performance when +/// the expression is obviously safe. The check remains in debug-mode for +/// regression analysis. #[cfg(not(debug_assertions))] #[macro_export] macro_rules! validated { - ($($input:tt)*) => { + ($($input:tt)+) => { //#[allow(clippy::arithmetic_side_effects)] { //Some($($input)*) // .ok_or_else(|| $crate::err!(Arithmetic("this error should never been seen"))) //} //NOTE: remove me when stmt_expr_attributes is stable - $crate::checked!($($input)*) - } + $crate::expected!("validated arithmetic expression failed", $($input)+) + }; } +/// Checked arithmetic expression in debug-mode. Use for performance when +/// the expression is obviously safe. The check is elided in release-mode. #[cfg(debug_assertions)] #[macro_export] macro_rules! 
validated { - ($($input:tt)*) => { $crate::checked!($($input)*) } + ($($input:tt)+) => { $crate::expected!($($input)+) } } /// Returns false if the exponential backoff has expired based on the inputs diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 9a1e7e67a..d0bc425fc 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -66,7 +66,7 @@ impl Service { .enumerate() { let bucket: usize = short.try_into()?; - let bucket: usize = validated!(bucket % NUM_BUCKETS)?; + let bucket: usize = validated!(bucket % NUM_BUCKETS); buckets[bucket].insert((short, starting_events[i])); } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 4f2352f81..04d9559da 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1205,7 +1205,7 @@ impl Service { let count = self.services.globals.next_count()?; let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - pdu_id.extend_from_slice(&(validated!(max - count)?).to_be_bytes()); + pdu_id.extend_from_slice(&(validated!(max - count)).to_be_bytes()); // Insert pdu self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value)?; From aa265f7ca4ee5f6b15cce83a235cf5f9c4317cfc Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 8 Sep 2024 04:39:27 +0000 Subject: [PATCH 005/245] add err log trait to Result Signed-off-by: Jason Volk --- src/core/error/log.rs | 33 +++++++++++++++--- src/core/result.rs | 4 ++- src/core/result/inspect_log.rs | 60 ++++++++++++++++++++++++++++++++ src/core/result/log_debug_err.rs | 36 +++++++++++++++++++ src/core/result/log_err.rs | 36 +++++++++++++++++++ 5 files changed, 163 insertions(+), 6 deletions(-) create mode 100644 src/core/result/inspect_log.rs create mode 100644 src/core/result/log_debug_err.rs create mode 100644 src/core/result/log_err.rs diff --git a/src/core/error/log.rs b/src/core/error/log.rs index 
c272bf730..60bd70140 100644 --- a/src/core/error/log.rs +++ b/src/core/error/log.rs @@ -1,7 +1,8 @@ use std::{convert::Infallible, fmt}; +use tracing::Level; + use super::Error; -use crate::{debug_error, error}; #[inline] pub fn else_log(error: E) -> Result @@ -64,11 +65,33 @@ where } #[inline] -pub fn inspect_log(error: &E) { - error!("{error}"); +pub fn inspect_log(error: &E) { inspect_log_level(error, Level::ERROR); } + +#[inline] +pub fn inspect_debug_log(error: &E) { inspect_debug_log_level(error, Level::ERROR); } + +#[inline] +pub fn inspect_log_level(error: &E, level: Level) { + use crate::{debug, error, info, trace, warn}; + + match level { + Level::ERROR => error!("{error}"), + Level::WARN => warn!("{error}"), + Level::INFO => info!("{error}"), + Level::DEBUG => debug!("{error}"), + Level::TRACE => trace!("{error}"), + } } #[inline] -pub fn inspect_debug_log(error: &E) { - debug_error!("{error:?}"); +pub fn inspect_debug_log_level(error: &E, level: Level) { + use crate::{debug, debug_error, debug_info, debug_warn, trace}; + + match level { + Level::ERROR => debug_error!("{error:?}"), + Level::WARN => debug_warn!("{error:?}"), + Level::INFO => debug_info!("{error:?}"), + Level::DEBUG => debug!("{error:?}"), + Level::TRACE => trace!("{error:?}"), + } } diff --git a/src/core/result.rs b/src/core/result.rs index d58467cf5..c3eaf95b2 100644 --- a/src/core/result.rs +++ b/src/core/result.rs @@ -1,6 +1,8 @@ mod debug_inspect; +mod log_debug_err; +mod log_err; mod map_expect; -pub use self::{debug_inspect::DebugInspect, map_expect::MapExpect}; +pub use self::{debug_inspect::DebugInspect, log_debug_err::LogDebugErr, log_err::LogErr, map_expect::MapExpect}; pub type Result = std::result::Result; diff --git a/src/core/result/inspect_log.rs b/src/core/result/inspect_log.rs new file mode 100644 index 000000000..577761c5c --- /dev/null +++ b/src/core/result/inspect_log.rs @@ -0,0 +1,60 @@ +use std::fmt; + +use tracing::Level; + +use super::Result; +use crate::error; + 
+pub trait ErrLog +where + E: fmt::Display, +{ + fn log_err(self, level: Level) -> Self; + + fn err_log(self) -> Self + where + Self: Sized, + { + self.log_err(Level::ERROR) + } +} + +pub trait ErrDebugLog +where + E: fmt::Debug, +{ + fn log_err_debug(self, level: Level) -> Self; + + fn err_debug_log(self) -> Self + where + Self: Sized, + { + self.log_err_debug(Level::ERROR) + } +} + +impl ErrLog for Result +where + E: fmt::Display, +{ + #[inline] + fn log_err(self, level: Level) -> Self + where + Self: Sized, + { + self.inspect_err(|error| error::inspect_log_level(&error, level)) + } +} + +impl ErrDebugLog for Result +where + E: fmt::Debug, +{ + #[inline] + fn log_err_debug(self, level: Level) -> Self + where + Self: Sized, + { + self.inspect_err(|error| error::inspect_debug_log_level(&error, level)) + } +} diff --git a/src/core/result/log_debug_err.rs b/src/core/result/log_debug_err.rs new file mode 100644 index 000000000..be2000aed --- /dev/null +++ b/src/core/result/log_debug_err.rs @@ -0,0 +1,36 @@ +use std::fmt; + +use tracing::Level; + +use super::{DebugInspect, Result}; +use crate::error; + +pub trait LogDebugErr +where + E: fmt::Debug, +{ + #[must_use] + fn err_debug_log(self, level: Level) -> Self; + + #[inline] + #[must_use] + fn log_debug_err(self) -> Self + where + Self: Sized, + { + self.err_debug_log(Level::ERROR) + } +} + +impl LogDebugErr for Result +where + E: fmt::Debug, +{ + #[inline] + fn err_debug_log(self, level: Level) -> Self + where + Self: Sized, + { + self.debug_inspect_err(|error| error::inspect_debug_log_level(&error, level)) + } +} diff --git a/src/core/result/log_err.rs b/src/core/result/log_err.rs new file mode 100644 index 000000000..079571f56 --- /dev/null +++ b/src/core/result/log_err.rs @@ -0,0 +1,36 @@ +use std::fmt; + +use tracing::Level; + +use super::Result; +use crate::error; + +pub trait LogErr +where + E: fmt::Display, +{ + #[must_use] + fn err_log(self, level: Level) -> Self; + + #[inline] + #[must_use] + fn 
log_err(self) -> Self + where + Self: Sized, + { + self.err_log(Level::ERROR) + } +} + +impl LogErr for Result +where + E: fmt::Display, +{ + #[inline] + fn err_log(self, level: Level) -> Self + where + Self: Sized, + { + self.inspect_err(|error| error::inspect_log_level(&error, level)) + } +} From bd75ff65c96427429a0334b6d899f2b630fe5f8a Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 8 Sep 2024 06:53:15 +0000 Subject: [PATCH 006/245] move common_elements util into unit Signed-off-by: Jason Volk --- src/core/utils/algorithm.rs | 25 +++++++++++++++++++++++++ src/core/utils/mod.rs | 28 +++------------------------- 2 files changed, 28 insertions(+), 25 deletions(-) create mode 100644 src/core/utils/algorithm.rs diff --git a/src/core/utils/algorithm.rs b/src/core/utils/algorithm.rs new file mode 100644 index 000000000..9bc1bc8a7 --- /dev/null +++ b/src/core/utils/algorithm.rs @@ -0,0 +1,25 @@ +use std::cmp::Ordering; + +#[allow(clippy::impl_trait_in_params)] +pub fn common_elements( + mut iterators: impl Iterator>>, check_order: impl Fn(&[u8], &[u8]) -> Ordering, +) -> Option>> { + let first_iterator = iterators.next()?; + let mut other_iterators = iterators.map(Iterator::peekable).collect::>(); + + Some(first_iterator.filter(move |target| { + other_iterators.iter_mut().all(|it| { + while let Some(element) = it.peek() { + match check_order(element, target) { + Ordering::Greater => return false, // We went too far + Ordering::Equal => return true, // Element is in both iters + Ordering::Less => { + // Keep searching + it.next(); + }, + } + } + false + }) + })) +} diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index 1556646ec..29d0b87b2 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -1,3 +1,4 @@ +pub mod algorithm; pub mod bytes; pub mod content_disposition; pub mod debug; @@ -13,9 +14,10 @@ pub mod sys; mod tests; pub mod time; -use std::cmp::{self, Ordering}; +use std::cmp; pub use ::ctor::{ctor, dtor}; +pub use 
algorithm::common_elements; pub use bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8}; pub use conduit_macros::implement; pub use debug::slice_truncated as debug_slice_truncated; @@ -47,27 +49,3 @@ pub fn generate_keypair() -> Vec { ); value } - -#[allow(clippy::impl_trait_in_params)] -pub fn common_elements( - mut iterators: impl Iterator>>, check_order: impl Fn(&[u8], &[u8]) -> Ordering, -) -> Option>> { - let first_iterator = iterators.next()?; - let mut other_iterators = iterators.map(Iterator::peekable).collect::>(); - - Some(first_iterator.filter(move |target| { - other_iterators.iter_mut().all(|it| { - while let Some(element) = it.peek() { - match check_order(element, target) { - Ordering::Greater => return false, // We went too far - Ordering::Equal => return true, // Element is in both iters - Ordering::Less => { - // Keep searching - it.next(); - }, - } - } - false - }) - })) -} From 63053640f1e5789719b3f88f92b3006ae3caecf4 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 13 Sep 2024 18:24:33 +0000 Subject: [PATCH 007/245] add util functors for is_zero/is_equal; move clamp to math utils Signed-off-by: Jason Volk --- src/core/utils/math.rs | 35 +++++++++++++++++++++++++++++++++++ src/core/utils/mod.rs | 12 ++---------- src/macros/utils.rs | 7 ++----- 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/src/core/utils/math.rs b/src/core/utils/math.rs index 8c4b01bed..215de339c 100644 --- a/src/core/utils/math.rs +++ b/src/core/utils/math.rs @@ -53,6 +53,38 @@ macro_rules! validated { ($($input:tt)+) => { $crate::expected!($($input)+) } } +/// Functor for equality to zero +#[macro_export] +macro_rules! is_zero { + () => { + $crate::is_matching!(0) + }; +} + +/// Functor for equality i.e. .is_some_and(is_equal!(2)) +#[macro_export] +macro_rules! is_equal_to { + ($val:expr) => { + |x| (x == $val) + }; +} + +/// Functor for less i.e. .is_some_and(is_less_than!(2)) +#[macro_export] +macro_rules! 
is_less_than { + ($val:expr) => { + |x| (x < $val) + }; +} + +/// Functor for matches! i.e. .is_some_and(is_matching!('A'..='Z')) +#[macro_export] +macro_rules! is_matching { + ($val:expr) => { + |x| matches!(x, $val) + }; +} + /// Returns false if the exponential backoff has expired based on the inputs #[inline] #[must_use] @@ -118,3 +150,6 @@ fn try_into_err, Src>(e: >::Error) -> Erro type_name::() )) } + +#[inline] +pub fn clamp(val: T, min: T, max: T) -> T { cmp::min(cmp::max(val, min), max) } diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index 29d0b87b2..03b755e9e 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -14,8 +14,6 @@ pub mod sys; mod tests; pub mod time; -use std::cmp; - pub use ::ctor::{ctor, dtor}; pub use algorithm::common_elements; pub use bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8}; @@ -24,6 +22,7 @@ pub use debug::slice_truncated as debug_slice_truncated; pub use hash::calculate_hash; pub use html::Escape as HtmlEscape; pub use json::{deserialize_from_str, to_canonical_object}; +pub use math::clamp; pub use mutex_map::{Guard as MutexMapGuard, MutexMap}; pub use rand::string as random_string; pub use string::{str_from_bytes, string_from_bytes}; @@ -31,14 +30,7 @@ pub use sys::available_parallelism; pub use time::now_millis as millis_since_unix_epoch; #[inline] -pub fn clamp(val: T, min: T, max: T) -> T { cmp::min(cmp::max(val, min), max) } - -#[inline] -pub fn exchange(state: &mut T, source: T) -> T { - let ret = state.clone(); - *state = source; - ret -} +pub fn exchange(state: &mut T, source: T) -> T { std::mem::replace(state, source) } #[must_use] pub fn generate_keypair() -> Vec { diff --git a/src/macros/utils.rs b/src/macros/utils.rs index 58074e3a0..197dd90e9 100644 --- a/src/macros/utils.rs +++ b/src/macros/utils.rs @@ -41,8 +41,5 @@ pub(crate) fn camel_to_snake_string(s: &str) -> String { output } -pub(crate) fn exchange(state: &mut T, source: T) -> T { - let ret = state.clone(); - *state 
= source; - ret -} +#[inline] +pub(crate) fn exchange(state: &mut T, source: T) -> T { std::mem::replace(state, source) } From a5822ebc274ddf92ba75f4f62d4e527078447baf Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 13 Sep 2024 18:55:56 +0000 Subject: [PATCH 008/245] add missing err! case Signed-off-by: Jason Volk --- src/core/error/err.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/core/error/err.rs b/src/core/error/err.rs index b3d0240ed..4972e92a8 100644 --- a/src/core/error/err.rs +++ b/src/core/error/err.rs @@ -85,6 +85,10 @@ macro_rules! err { $crate::error::Error::$variant($crate::err_log!(buf, $level, $($args)+)) }}; + ($variant:ident($($args:ident),+)) => { + $crate::error::Error::$variant($($args),+) + }; + ($variant:ident($($args:tt)+)) => { $crate::error::Error::$variant($crate::format_maybe!($($args)+)) }; From f7ce4db0b00bc24be6c127895e16ba29a248c8a4 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 12 Sep 2024 00:01:25 +0000 Subject: [PATCH 009/245] add is_not_found functor to error; tweak status code matcher Signed-off-by: Jason Volk --- src/core/error/mod.rs | 19 +++++++++++++------ src/core/result.rs | 6 +++++- src/core/result/not_found.rs | 12 ++++++++++++ 3 files changed, 30 insertions(+), 7 deletions(-) create mode 100644 src/core/result/not_found.rs diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs index 92dbdfe3b..48b9b58ff 100644 --- a/src/core/error/mod.rs +++ b/src/core/error/mod.rs @@ -141,19 +141,22 @@ impl Error { use ruma::api::client::error::ErrorKind::Unknown; match self { - Self::Federation(_, error) => response::ruma_error_kind(error).clone(), + Self::Federation(_, error) | Self::Ruma(error) => response::ruma_error_kind(error).clone(), Self::BadRequest(kind, ..) | Self::Request(kind, ..) 
=> kind.clone(), _ => Unknown, } } pub fn status_code(&self) -> http::StatusCode { + use http::StatusCode; + match self { - Self::Federation(_, ref error) | Self::Ruma(ref error) => error.status_code, - Self::Request(ref kind, _, code) => response::status_code(kind, *code), - Self::BadRequest(ref kind, ..) => response::bad_request_code(kind), - Self::Conflict(_) => http::StatusCode::CONFLICT, - _ => http::StatusCode::INTERNAL_SERVER_ERROR, + Self::Federation(_, error) | Self::Ruma(error) => error.status_code, + Self::Request(kind, _, code) => response::status_code(kind, *code), + Self::BadRequest(kind, ..) => response::bad_request_code(kind), + Self::Reqwest(error) => error.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + Self::Conflict(_) => StatusCode::CONFLICT, + _ => StatusCode::INTERNAL_SERVER_ERROR, } } } @@ -176,3 +179,7 @@ impl From for Error { pub fn infallible(_e: &Infallible) { panic!("infallible error should never exist"); } + +#[inline] +#[must_use] +pub fn is_not_found(e: &Error) -> bool { e.status_code() == http::StatusCode::NOT_FOUND } diff --git a/src/core/result.rs b/src/core/result.rs index c3eaf95b2..41d1d66c6 100644 --- a/src/core/result.rs +++ b/src/core/result.rs @@ -2,7 +2,11 @@ mod debug_inspect; mod log_debug_err; mod log_err; mod map_expect; +mod not_found; -pub use self::{debug_inspect::DebugInspect, log_debug_err::LogDebugErr, log_err::LogErr, map_expect::MapExpect}; +pub use self::{ + debug_inspect::DebugInspect, log_debug_err::LogDebugErr, log_err::LogErr, map_expect::MapExpect, + not_found::NotFound, +}; pub type Result = std::result::Result; diff --git a/src/core/result/not_found.rs b/src/core/result/not_found.rs new file mode 100644 index 000000000..69ce821b8 --- /dev/null +++ b/src/core/result/not_found.rs @@ -0,0 +1,12 @@ +use super::Result; +use crate::{error, Error}; + +pub trait NotFound { + #[must_use] + fn is_not_found(&self) -> bool; +} + +impl NotFound for Result { + #[inline] + fn is_not_found(&self) -> bool { 
self.as_ref().is_err_and(error::is_not_found) } +} From a5de27442a0c68a3ff2b86d15e82e5065884b787 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 12 Sep 2024 00:59:08 +0000 Subject: [PATCH 010/245] re-export crates used by error macros Signed-off-by: Jason Volk --- src/core/error/err.rs | 30 +++++++++++++++--------------- src/core/mod.rs | 3 +++ 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/core/error/err.rs b/src/core/error/err.rs index 4972e92a8..82bb40b05 100644 --- a/src/core/error/err.rs +++ b/src/core/error/err.rs @@ -44,34 +44,34 @@ macro_rules! err { (Request(Forbidden($level:ident!($($args:tt)+)))) => {{ let mut buf = String::new(); $crate::error::Error::Request( - ::ruma::api::client::error::ErrorKind::forbidden(), + $crate::ruma::api::client::error::ErrorKind::forbidden(), $crate::err_log!(buf, $level, $($args)+), - ::http::StatusCode::BAD_REQUEST + $crate::http::StatusCode::BAD_REQUEST ) }}; (Request(Forbidden($($args:tt)+))) => { $crate::error::Error::Request( - ::ruma::api::client::error::ErrorKind::forbidden(), + $crate::ruma::api::client::error::ErrorKind::forbidden(), $crate::format_maybe!($($args)+), - ::http::StatusCode::BAD_REQUEST + $crate::http::StatusCode::BAD_REQUEST ) }; (Request($variant:ident($level:ident!($($args:tt)+)))) => {{ let mut buf = String::new(); $crate::error::Error::Request( - ::ruma::api::client::error::ErrorKind::$variant, + $crate::ruma::api::client::error::ErrorKind::$variant, $crate::err_log!(buf, $level, $($args)+), - ::http::StatusCode::BAD_REQUEST + $crate::http::StatusCode::BAD_REQUEST ) }}; (Request($variant:ident($($args:tt)+))) => { $crate::error::Error::Request( - ::ruma::api::client::error::ErrorKind::$variant, + $crate::ruma::api::client::error::ErrorKind::$variant, $crate::format_maybe!($($args)+), - ::http::StatusCode::BAD_REQUEST + $crate::http::StatusCode::BAD_REQUEST ) }; @@ -113,7 +113,7 @@ macro_rules! 
err_log { ($out:ident, $level:ident, $($fields:tt)+) => {{ use std::{fmt, fmt::Write}; - use ::tracing::{ + use $crate::tracing::{ callsite, callsite2, level_enabled, metadata, valueset, Callsite, Event, __macro_support, __tracing_log, field::{Field, ValueSet, Visit}, @@ -169,25 +169,25 @@ macro_rules! err_log { macro_rules! err_lev { (debug_warn) => { if $crate::debug::logging() { - ::tracing::Level::WARN + $crate::tracing::Level::WARN } else { - ::tracing::Level::DEBUG + $crate::tracing::Level::DEBUG } }; (debug_error) => { if $crate::debug::logging() { - ::tracing::Level::ERROR + $crate::tracing::Level::ERROR } else { - ::tracing::Level::DEBUG + $crate::tracing::Level::DEBUG } }; (warn) => { - ::tracing::Level::WARN + $crate::tracing::Level::WARN }; (error) => { - ::tracing::Level::ERROR + $crate::tracing::Level::ERROR }; } diff --git a/src/core/mod.rs b/src/core/mod.rs index 31851f4f0..e45531864 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -11,7 +11,10 @@ pub mod result; pub mod server; pub mod utils; +pub use ::http; +pub use ::ruma; pub use ::toml; +pub use ::tracing; pub use config::Config; pub use error::Error; pub use info::{rustc_flags_capture, version, version::version}; From 60010140784a286e53aa560cad6604caded4257e Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 13 Sep 2024 07:40:22 +0000 Subject: [PATCH 011/245] add UnwrapInfallible to Result Signed-off-by: Jason Volk --- src/core/result.rs | 3 ++- src/core/result/unwrap_infallible.rs | 17 +++++++++++++++++ src/router/serve/unix.rs | 8 ++------ 3 files changed, 21 insertions(+), 7 deletions(-) create mode 100644 src/core/result/unwrap_infallible.rs diff --git a/src/core/result.rs b/src/core/result.rs index 41d1d66c6..96a34b8a3 100644 --- a/src/core/result.rs +++ b/src/core/result.rs @@ -3,10 +3,11 @@ mod log_debug_err; mod log_err; mod map_expect; mod not_found; +mod unwrap_infallible; pub use self::{ debug_inspect::DebugInspect, log_debug_err::LogDebugErr, log_err::LogErr, 
map_expect::MapExpect, - not_found::NotFound, + not_found::NotFound, unwrap_infallible::UnwrapInfallible, }; pub type Result = std::result::Result; diff --git a/src/core/result/unwrap_infallible.rs b/src/core/result/unwrap_infallible.rs new file mode 100644 index 000000000..99309e025 --- /dev/null +++ b/src/core/result/unwrap_infallible.rs @@ -0,0 +1,17 @@ +use std::convert::Infallible; + +use super::{DebugInspect, Result}; +use crate::error; + +pub trait UnwrapInfallible { + fn unwrap_infallible(self) -> T; +} + +impl UnwrapInfallible for Result { + #[inline] + fn unwrap_infallible(self) -> T { + // SAFETY: Branchless unwrap for errors that can never happen. In debug + // mode this is asserted. + unsafe { self.debug_inspect_err(error::infallible).unwrap_unchecked() } + } +} diff --git a/src/router/serve/unix.rs b/src/router/serve/unix.rs index fb011f188..5df41b614 100644 --- a/src/router/serve/unix.rs +++ b/src/router/serve/unix.rs @@ -10,7 +10,7 @@ use axum::{ extract::{connect_info::IntoMakeServiceWithConnectInfo, Request}, Router, }; -use conduit::{debug, debug_error, error::infallible, info, trace, warn, Err, Result, Server}; +use conduit::{debug, debug_error, info, result::UnwrapInfallible, trace, warn, Err, Result, Server}; use hyper::{body::Incoming, service::service_fn}; use hyper_util::{ rt::{TokioExecutor, TokioIo}, @@ -62,11 +62,7 @@ async fn accept( let socket = TokioIo::new(socket); trace!(?listener, ?socket, ?remote, "accepted"); - let called = app - .call(NULL_ADDR) - .await - .inspect_err(infallible) - .expect("infallible"); + let called = app.call(NULL_ADDR).await.unwrap_infallible(); let service = move |req: Request| called.clone().oneshot(req); let handler = service_fn(service); From 946ca364e032be8ca2529099b415990262c977fd Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 8 Aug 2024 17:18:30 +0000 Subject: [PATCH 012/245] Database Refactor combine service/users data w/ mod unit split sliding sync related out of service/users instrument 
database entry points remove increment crap from database interface de-wrap all database get() calls de-wrap all database insert() calls de-wrap all database remove() calls refactor database interface for async streaming add query key serializer for database implement Debug for result handle add query deserializer for database add deserialization trait for option handle start a stream utils suite de-wrap/asyncify/type-query count_one_time_keys() de-wrap/asyncify users count add admin query users command suite de-wrap/asyncify users exists de-wrap/partially asyncify user filter related asyncify/de-wrap users device/keys related asyncify/de-wrap user auth/misc related asyncify/de-wrap users blurhash asyncify/de-wrap account_data get; merge Data into Service partial asyncify/de-wrap uiaa; merge Data into Service partially asyncify/de-wrap transaction_ids get; merge Data into Service partially asyncify/de-wrap key_backups; merge Data into Service asyncify/de-wrap pusher service getters; merge Data into Service asyncify/de-wrap rooms alias getters/some iterators asyncify/de-wrap rooms directory getters/iterator partially asyncify/de-wrap rooms lazy-loading partially asyncify/de-wrap rooms metadata asyncify/dewrap rooms outlier asyncify/dewrap rooms pdu_metadata dewrap/partially asyncify rooms read receipt de-wrap rooms search service de-wrap/partially asyncify rooms user service partial de-wrap rooms state_compressor de-wrap rooms state_cache de-wrap room state et al de-wrap rooms timeline service additional users device/keys related de-wrap/asyncify sender asyncify services refactor database to TryFuture/TryStream refactor services for TryFuture/TryStream asyncify api handlers additional asyncification for admin module abstract stream related; support reverse streams additional stream conversions asyncify state-res related Signed-off-by: Jason Volk --- Cargo.lock | 53 +- Cargo.toml | 7 +- clippy.toml | 2 +- src/admin/Cargo.toml | 3 +- src/admin/check/commands.rs | 9 +- 
src/admin/debug/commands.rs | 100 +- src/admin/federation/commands.rs | 13 +- src/admin/media/commands.rs | 4 +- src/admin/processor.rs | 2 +- src/admin/query/account_data.rs | 6 +- src/admin/query/appservice.rs | 6 +- src/admin/query/globals.rs | 6 +- src/admin/query/presence.rs | 5 +- src/admin/query/pusher.rs | 2 +- src/admin/query/room_alias.rs | 13 +- src/admin/query/room_state_cache.rs | 93 +- src/admin/query/sending.rs | 9 +- src/admin/query/users.rs | 349 ++++- src/admin/room/alias.rs | 126 +- src/admin/room/commands.rs | 47 +- src/admin/room/directory.rs | 22 +- src/admin/room/info.rs | 48 +- src/admin/room/mod.rs | 6 + src/admin/room/moderation.rs | 347 +++-- src/admin/user/commands.rs | 192 +-- src/admin/utils.rs | 22 +- src/api/Cargo.toml | 2 +- src/api/client/account.rs | 182 +-- src/api/client/alias.rs | 28 +- src/api/client/backup.rs | 230 ++-- src/api/client/config.rs | 40 +- src/api/client/context.rs | 108 +- src/api/client/device.rs | 46 +- src/api/client/directory.rs | 188 +-- src/api/client/filter.rs | 25 +- src/api/client/keys.rs | 186 +-- src/api/client/membership.rs | 486 ++++--- src/api/client/message.rs | 196 +-- src/api/client/presence.rs | 14 +- src/api/client/profile.rs | 200 ++- src/api/client/push.rs | 178 +-- src/api/client/read_marker.rs | 90 +- src/api/client/relations.rs | 90 +- src/api/client/report.rs | 20 +- src/api/client/room.rs | 105 +- src/api/client/search.rs | 108 +- src/api/client/session.rs | 75 +- src/api/client/state.rs | 63 +- src/api/client/sync.rs | 1059 +++++++------- src/api/client/tag.rs | 45 +- src/api/client/threads.rs | 19 +- src/api/client/to_device.rs | 52 +- src/api/client/typing.rs | 3 +- src/api/client/unstable.rs | 168 +-- src/api/client/unversioned.rs | 3 +- src/api/client/user_directory.rs | 47 +- src/api/router.rs | 322 ++--- src/api/router/args.rs | 26 +- src/api/router/auth.rs | 10 +- src/api/router/handler.rs | 38 +- src/api/router/response.rs | 9 +- src/api/server/backfill.rs | 84 +- 
src/api/server/event.rs | 39 +- src/api/server/event_auth.rs | 33 +- src/api/server/get_missing_events.rs | 31 +- src/api/server/hierarchy.rs | 2 +- src/api/server/invite.rs | 38 +- src/api/server/make_join.rs | 89 +- src/api/server/make_leave.rs | 37 +- src/api/server/openid.rs | 5 +- src/api/server/query.rs | 36 +- src/api/server/send.rs | 196 +-- src/api/server/send_join.rs | 71 +- src/api/server/send_leave.rs | 20 +- src/api/server/state.rs | 65 +- src/api/server/state_ids.rs | 37 +- src/api/server/user.rs | 51 +- src/core/Cargo.toml | 1 + src/core/error/mod.rs | 4 +- src/core/pdu/mod.rs | 43 +- src/core/result/log_debug_err.rs | 18 +- src/core/result/log_err.rs | 20 +- src/core/utils/algorithm.rs | 25 - src/core/utils/mod.rs | 32 +- src/core/utils/set.rs | 47 + src/core/utils/stream/cloned.rs | 20 + src/core/utils/stream/expect.rs | 17 + src/core/utils/stream/ignore.rs | 21 + src/core/utils/stream/iter_stream.rs | 27 + src/core/utils/stream/mod.rs | 13 + src/core/utils/stream/ready.rs | 109 ++ src/core/utils/stream/try_ready.rs | 35 + src/core/utils/tests.rs | 130 ++ src/database/Cargo.toml | 3 + src/database/database.rs | 2 +- src/database/de.rs | 261 ++++ src/database/deserialized.rs | 34 + src/database/engine.rs | 2 +- src/database/handle.rs | 89 +- src/database/iter.rs | 110 -- src/database/keyval.rs | 83 ++ src/database/map.rs | 262 ++-- src/database/map/count.rs | 36 + src/database/map/keys.rs | 21 + src/database/map/keys_from.rs | 49 + src/database/map/keys_prefix.rs | 54 + src/database/map/rev_keys.rs | 21 + src/database/map/rev_keys_from.rs | 49 + src/database/map/rev_keys_prefix.rs | 54 + src/database/map/rev_stream.rs | 29 + src/database/map/rev_stream_from.rs | 68 + src/database/map/rev_stream_prefix.rs | 74 + src/database/map/stream.rs | 28 + src/database/map/stream_from.rs | 68 + src/database/map/stream_prefix.rs | 74 + src/database/mod.rs | 28 +- src/database/ser.rs | 315 +++++ src/database/slice.rs | 57 - src/database/stream.rs | 122 ++ 
src/database/stream/items.rs | 44 + src/database/stream/items_rev.rs | 44 + src/database/stream/keys.rs | 44 + src/database/stream/keys_rev.rs | 44 + src/database/util.rs | 12 + src/service/Cargo.toml | 2 +- src/service/account_data/data.rs | 152 --- src/service/account_data/mod.rs | 164 ++- src/service/admin/console.rs | 2 +- src/service/admin/create.rs | 2 +- src/service/admin/grant.rs | 216 +-- src/service/admin/mod.rs | 104 +- src/service/appservice/data.rs | 28 +- src/service/appservice/mod.rs | 49 +- src/service/emergency/mod.rs | 30 +- src/service/globals/data.rs | 121 +- src/service/globals/migrations.rs | 739 +++------- src/service/globals/mod.rs | 8 +- src/service/key_backups/data.rs | 346 ----- src/service/key_backups/mod.rs | 336 ++++- src/service/manager.rs | 2 +- src/service/media/data.rs | 102 +- src/service/media/migrations.rs | 33 +- src/service/media/mod.rs | 15 +- src/service/media/preview.rs | 8 +- src/service/media/thumbnail.rs | 4 +- src/service/mod.rs | 1 + src/service/presence/data.rs | 113 +- src/service/presence/mod.rs | 63 +- src/service/presence/presence.rs | 12 +- src/service/pusher/data.rs | 77 -- src/service/pusher/mod.rs | 124 +- src/service/resolver/actual.rs | 6 +- src/service/rooms/alias/data.rs | 125 -- src/service/rooms/alias/mod.rs | 147 +- src/service/rooms/auth_chain/data.rs | 21 +- src/service/rooms/auth_chain/mod.rs | 45 +- src/service/rooms/directory/data.rs | 39 - src/service/rooms/directory/mod.rs | 40 +- src/service/rooms/event_handler/mod.rs | 1202 ++++++++-------- .../rooms/event_handler/parse_incoming_pdu.rs | 6 +- src/service/rooms/lazy_loading/data.rs | 65 - src/service/rooms/lazy_loading/mod.rs | 112 +- src/service/rooms/metadata/data.rs | 110 -- src/service/rooms/metadata/mod.rs | 95 +- src/service/rooms/outlier/data.rs | 42 - src/service/rooms/outlier/mod.rs | 55 +- src/service/rooms/pdu_metadata/data.rs | 80 +- src/service/rooms/pdu_metadata/mod.rs | 177 +-- src/service/rooms/read_receipt/data.rs | 152 +-- 
src/service/rooms/read_receipt/mod.rs | 49 +- src/service/rooms/search/data.rs | 79 +- src/service/rooms/search/mod.rs | 17 +- src/service/rooms/short/data.rs | 198 ++- src/service/rooms/short/mod.rs | 36 +- src/service/rooms/spaces/mod.rs | 174 +-- src/service/rooms/state/data.rs | 71 +- src/service/rooms/state/mod.rs | 272 ++-- src/service/rooms/state_accessor/data.rs | 156 +-- src/service/rooms/state_accessor/mod.rs | 356 ++--- src/service/rooms/state_cache/data.rs | 646 ++------- src/service/rooms/state_cache/mod.rs | 471 +++++-- src/service/rooms/state_compressor/data.rs | 20 +- src/service/rooms/state_compressor/mod.rs | 85 +- src/service/rooms/threads/data.rs | 78 +- src/service/rooms/threads/mod.rs | 38 +- src/service/rooms/timeline/data.rs | 333 +++-- src/service/rooms/timeline/mod.rs | 656 +++++---- src/service/rooms/typing/mod.rs | 33 +- src/service/rooms/user/data.rs | 146 +- src/service/rooms/user/mod.rs | 46 +- src/service/sending/data.rs | 164 ++- src/service/sending/mod.rs | 113 +- src/service/sending/sender.rs | 278 ++-- src/service/server_keys/mod.rs | 26 +- src/service/services.rs | 4 +- src/service/sync/mod.rs | 233 ++++ src/service/transaction_ids/data.rs | 44 - src/service/transaction_ids/mod.rs | 44 +- src/service/uiaa/data.rs | 87 -- src/service/uiaa/mod.rs | 315 +++-- src/service/updates/mod.rs | 90 +- src/service/users/data.rs | 1098 --------------- src/service/users/mod.rs | 1213 +++++++++++------ 203 files changed, 12032 insertions(+), 10539 deletions(-) delete mode 100644 src/core/utils/algorithm.rs create mode 100644 src/core/utils/set.rs create mode 100644 src/core/utils/stream/cloned.rs create mode 100644 src/core/utils/stream/expect.rs create mode 100644 src/core/utils/stream/ignore.rs create mode 100644 src/core/utils/stream/iter_stream.rs create mode 100644 src/core/utils/stream/mod.rs create mode 100644 src/core/utils/stream/ready.rs create mode 100644 src/core/utils/stream/try_ready.rs create mode 100644 src/database/de.rs 
create mode 100644 src/database/deserialized.rs delete mode 100644 src/database/iter.rs create mode 100644 src/database/keyval.rs create mode 100644 src/database/map/count.rs create mode 100644 src/database/map/keys.rs create mode 100644 src/database/map/keys_from.rs create mode 100644 src/database/map/keys_prefix.rs create mode 100644 src/database/map/rev_keys.rs create mode 100644 src/database/map/rev_keys_from.rs create mode 100644 src/database/map/rev_keys_prefix.rs create mode 100644 src/database/map/rev_stream.rs create mode 100644 src/database/map/rev_stream_from.rs create mode 100644 src/database/map/rev_stream_prefix.rs create mode 100644 src/database/map/stream.rs create mode 100644 src/database/map/stream_from.rs create mode 100644 src/database/map/stream_prefix.rs create mode 100644 src/database/ser.rs delete mode 100644 src/database/slice.rs create mode 100644 src/database/stream.rs create mode 100644 src/database/stream/items.rs create mode 100644 src/database/stream/items_rev.rs create mode 100644 src/database/stream/keys.rs create mode 100644 src/database/stream/keys_rev.rs delete mode 100644 src/service/account_data/data.rs delete mode 100644 src/service/key_backups/data.rs delete mode 100644 src/service/pusher/data.rs delete mode 100644 src/service/rooms/alias/data.rs delete mode 100644 src/service/rooms/directory/data.rs delete mode 100644 src/service/rooms/lazy_loading/data.rs delete mode 100644 src/service/rooms/metadata/data.rs delete mode 100644 src/service/rooms/outlier/data.rs create mode 100644 src/service/sync/mod.rs delete mode 100644 src/service/transaction_ids/data.rs delete mode 100644 src/service/uiaa/data.rs delete mode 100644 src/service/users/data.rs diff --git a/Cargo.lock b/Cargo.lock index 6386f9685..08e0498aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -626,10 +626,11 @@ dependencies = [ "clap", "conduit_api", "conduit_core", + "conduit_database", "conduit_macros", "conduit_service", "const-str", - "futures-util", + "futures", 
"log", "ruma", "serde_json", @@ -652,7 +653,7 @@ dependencies = [ "conduit_database", "conduit_service", "const-str", - "futures-util", + "futures", "hmac", "http", "http-body-util", @@ -689,6 +690,7 @@ dependencies = [ "cyborgtime", "either", "figment", + "futures", "hardened_malloc-rs", "http", "http-body-util", @@ -726,8 +728,11 @@ version = "0.4.7" dependencies = [ "conduit_core", "const-str", + "futures", "log", "rust-rocksdb-uwu", + "serde", + "serde_json", "tokio", "tracing", ] @@ -784,7 +789,7 @@ dependencies = [ "conduit_core", "conduit_database", "const-str", - "futures-util", + "futures", "hickory-resolver", "http", "image", @@ -1283,6 +1288,20 @@ dependencies = [ "new_debug_unreachable", ] +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -1345,6 +1364,7 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", "futures-io", "futures-macro", @@ -2953,7 +2973,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" dependencies = [ "assign", "js_int", @@ -2975,7 +2995,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = 
"git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" dependencies = [ "js_int", "ruma-common", @@ -2987,7 +3007,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" dependencies = [ "as_variant", "assign", @@ -3010,7 +3030,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" dependencies = [ "as_variant", "base64 0.22.1", @@ -3040,7 +3060,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" dependencies = [ "as_variant", "indexmap 2.6.0", @@ -3064,7 +3084,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" dependencies = [ "bytes", "http", @@ -3082,7 +3102,7 
@@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" dependencies = [ "js_int", "thiserror", @@ -3091,7 +3111,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" dependencies = [ "js_int", "ruma-common", @@ -3101,7 +3121,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" dependencies = [ "cfg-if", "once_cell", @@ -3117,7 +3137,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" dependencies = [ "js_int", "ruma-common", @@ -3129,7 +3149,7 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.3.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = 
"git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" dependencies = [ "headers", "http", @@ -3142,7 +3162,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -3158,8 +3178,9 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.11.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" dependencies = [ + "futures-util", "itertools 0.12.1", "js_int", "ruma-common", diff --git a/Cargo.toml b/Cargo.toml index b75c49757..3bfb3bc81 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -210,9 +210,10 @@ features = [ "string", ] -[workspace.dependencies.futures-util] +[workspace.dependencies.futures] version = "0.3.30" default-features = false +features = ["std"] [workspace.dependencies.tokio] version = "1.40.0" @@ -314,7 +315,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = "https://github.com/girlbossceo/ruwuma" #branch = "conduwuit-changes" -rev = "9900d0676564883cfade556d6e8da2a2c9061efd" +rev = "e7db44989d68406393270d3a91815597385d3acb" features = [ "compat", "rand", @@ -463,7 +464,6 @@ version = "1.0.36" [workspace.dependencies.proc-macro2] version = "1.0.89" - # # Patches # @@ -828,6 +828,7 @@ missing_panics_doc = { level = "allow", priority = 1 } module_name_repetitions = { level = "allow", priority = 1 } no_effect_underscore_binding = { level = "allow", priority = 1 } similar_names = { level = 
"allow", priority = 1 } +single_match_else = { level = "allow", priority = 1 } struct_field_names = { level = "allow", priority = 1 } unnecessary_wraps = { level = "allow", priority = 1 } unused_async = { level = "allow", priority = 1 } diff --git a/clippy.toml b/clippy.toml index c942b93c7..08641fcc1 100644 --- a/clippy.toml +++ b/clippy.toml @@ -2,6 +2,6 @@ array-size-threshold = 4096 cognitive-complexity-threshold = 94 # TODO reduce me ALARA excessive-nesting-threshold = 11 # TODO reduce me to 4 or 5 future-size-threshold = 7745 # TODO reduce me ALARA -stack-size-threshold = 144000 # reduce me ALARA +stack-size-threshold = 196608 # reduce me ALARA too-many-lines-threshold = 700 # TODO reduce me to <= 100 type-complexity-threshold = 250 # reduce me to ~200 diff --git a/src/admin/Cargo.toml b/src/admin/Cargo.toml index d756b3cbd..f5cab4496 100644 --- a/src/admin/Cargo.toml +++ b/src/admin/Cargo.toml @@ -29,10 +29,11 @@ release_max_log_level = [ clap.workspace = true conduit-api.workspace = true conduit-core.workspace = true +conduit-database.workspace = true conduit-macros.workspace = true conduit-service.workspace = true const-str.workspace = true -futures-util.workspace = true +futures.workspace = true log.workspace = true ruma.workspace = true serde_json.workspace = true diff --git a/src/admin/check/commands.rs b/src/admin/check/commands.rs index 0a9830464..88fca462f 100644 --- a/src/admin/check/commands.rs +++ b/src/admin/check/commands.rs @@ -1,5 +1,6 @@ use conduit::Result; use conduit_macros::implement; +use futures::StreamExt; use ruma::events::room::message::RoomMessageEventContent; use crate::Command; @@ -10,14 +11,12 @@ use crate::Command; #[implement(Command, params = "<'_>")] pub(super) async fn check_all_users(&self) -> Result { let timer = tokio::time::Instant::now(); - let results = self.services.users.db.iter(); + let users = self.services.users.iter().collect::>().await; let query_time = timer.elapsed(); - let users = results.collect::>(); - let 
total = users.len(); - let err_count = users.iter().filter(|user| user.is_err()).count(); - let ok_count = users.iter().filter(|user| user.is_ok()).count(); + let err_count = users.iter().filter(|_user| false).count(); + let ok_count = users.iter().filter(|_user| true).count(); let message = format!( "Database query completed in {query_time:?}:\n\n```\nTotal entries: {total:?}\nFailure/Invalid user count: \ diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 2d9670064..65c9bc712 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -7,6 +7,7 @@ use std::{ use api::client::validate_and_add_event_id; use conduit::{debug, debug_error, err, info, trace, utils, warn, Error, PduEvent, Result}; +use futures::StreamExt; use ruma::{ api::{client::error::ErrorKind, federation::event::get_room_state}, events::room::message::RoomMessageEventContent, @@ -27,7 +28,7 @@ pub(super) async fn echo(&self, message: Vec) -> Result) -> Result { let event_id = Arc::::from(event_id); - if let Some(event) = self.services.rooms.timeline.get_pdu_json(&event_id)? 
{ + if let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await { let room_id_str = event .get("room_id") .and_then(|val| val.as_str()) @@ -43,7 +44,8 @@ pub(super) async fn get_auth_chain(&self, event_id: Box) -> Result) -> Result { + Ok(json) => { let json_text = serde_json::to_string_pretty(&json).expect("canonical json is valid json"); Ok(RoomMessageEventContent::notice_markdown(format!( "{}\n```json\n{}\n```", @@ -109,7 +114,7 @@ pub(super) async fn get_pdu(&self, event_id: Box) -> Result Ok(RoomMessageEventContent::text_plain("PDU not found locally.")), + Err(_) => Ok(RoomMessageEventContent::text_plain("PDU not found locally.")), } } @@ -157,7 +162,8 @@ pub(super) async fn get_remote_pdu_list( .send_message(RoomMessageEventContent::text_plain(format!( "Failed to get remote PDU, ignoring error: {e}" ))) - .await; + .await + .ok(); warn!("Failed to get remote PDU, ignoring error: {e}"); } else { success_count = success_count.saturating_add(1); @@ -215,7 +221,9 @@ pub(super) async fn get_remote_pdu( .services .rooms .event_handler - .parse_incoming_pdu(&response.pdu); + .parse_incoming_pdu(&response.pdu) + .await; + let (event_id, value, room_id) = match parsed_result { Ok(t) => t, Err(e) => { @@ -333,9 +341,12 @@ pub(super) async fn ping(&self, server: Box) -> Result Result { // Force E2EE device list updates for all users - for user_id in self.services.users.iter().filter_map(Result::ok) { - self.services.users.mark_device_key_update(&user_id)?; - } + self.services + .users + .stream() + .for_each(|user_id| self.services.users.mark_device_key_update(user_id)) + .await; + Ok(RoomMessageEventContent::text_plain( "Marked all devices for all users as having new keys to update", )) @@ -470,7 +481,8 @@ pub(super) async fn first_pdu_in_room(&self, room_id: Box) -> Result) -> Result) -> Result) -> Result> = HashMap::new(); let pub_key_map = RwLock::new(BTreeMap::new()); @@ -554,13 +571,21 @@ pub(super) async fn force_set_room_state_from_server( let 
mut events = Vec::with_capacity(remote_state_response.pdus.len()); for pdu in remote_state_response.pdus.clone() { - events.push(match self.services.rooms.event_handler.parse_incoming_pdu(&pdu) { - Ok(t) => t, - Err(e) => { - warn!("Could not parse PDU, ignoring: {e}"); - continue; + events.push( + match self + .services + .rooms + .event_handler + .parse_incoming_pdu(&pdu) + .await + { + Ok(t) => t, + Err(e) => { + warn!("Could not parse PDU, ignoring: {e}"); + continue; + }, }, - }); + ); } info!("Fetching required signing keys for all the state events we got"); @@ -587,13 +612,16 @@ pub(super) async fn force_set_room_state_from_server( self.services .rooms .outlier - .add_pdu_outlier(&event_id, &value)?; + .add_pdu_outlier(&event_id, &value); + if let Some(state_key) = &pdu.state_key { let shortstatekey = self .services .rooms .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key) + .await; + state.insert(shortstatekey, pdu.event_id.clone()); } } @@ -611,7 +639,7 @@ pub(super) async fn force_set_room_state_from_server( self.services .rooms .outlier - .add_pdu_outlier(&event_id, &value)?; + .add_pdu_outlier(&event_id, &value); } let new_room_state = self @@ -626,7 +654,8 @@ pub(super) async fn force_set_room_state_from_server( .services .rooms .state_compressor - .save_state(room_id.clone().as_ref(), new_room_state)?; + .save_state(room_id.clone().as_ref(), new_room_state) + .await?; let state_lock = self.services.rooms.state.mutex.lock(&room_id).await; self.services @@ -642,7 +671,8 @@ pub(super) async fn force_set_room_state_from_server( self.services .rooms .state_cache - .update_joined_count(&room_id)?; + .update_joined_count(&room_id) + .await; drop(state_lock); @@ -656,7 +686,7 @@ pub(super) async fn get_signing_keys( &self, server_name: Option>, _cached: bool, ) -> Result { let server_name = server_name.unwrap_or_else(|| 
self.services.server.config.server_name.clone().into()); - let signing_keys = self.services.globals.signing_keys_for(&server_name)?; + let signing_keys = self.services.globals.signing_keys_for(&server_name).await?; Ok(RoomMessageEventContent::notice_markdown(format!( "```rs\n{signing_keys:#?}\n```" @@ -674,7 +704,7 @@ pub(super) async fn get_verify_keys( if cached { writeln!(out, "| Key ID | VerifyKey |")?; writeln!(out, "| --- | --- |")?; - for (key_id, verify_key) in self.services.globals.verify_keys_for(&server_name)? { + for (key_id, verify_key) in self.services.globals.verify_keys_for(&server_name).await? { writeln!(out, "| {key_id} | {verify_key:?} |")?; } diff --git a/src/admin/federation/commands.rs b/src/admin/federation/commands.rs index 8917a46b9..ce95ac01b 100644 --- a/src/admin/federation/commands.rs +++ b/src/admin/federation/commands.rs @@ -1,19 +1,20 @@ use std::fmt::Write; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId, ServerName, UserId}; use crate::{admin_command, escape_html, get_room_info}; #[admin_command] pub(super) async fn disable_room(&self, room_id: Box) -> Result { - self.services.rooms.metadata.disable_room(&room_id, true)?; + self.services.rooms.metadata.disable_room(&room_id, true); Ok(RoomMessageEventContent::text_plain("Room disabled.")) } #[admin_command] pub(super) async fn enable_room(&self, room_id: Box) -> Result { - self.services.rooms.metadata.disable_room(&room_id, false)?; + self.services.rooms.metadata.disable_room(&room_id, false); Ok(RoomMessageEventContent::text_plain("Room enabled.")) } @@ -85,7 +86,7 @@ pub(super) async fn remote_user_in_rooms(&self, user_id: Box) -> Result< )); } - if !self.services.users.exists(&user_id)? 
{ + if !self.services.users.exists(&user_id).await { return Ok(RoomMessageEventContent::text_plain( "Remote user does not exist in our database.", )); @@ -96,9 +97,9 @@ pub(super) async fn remote_user_in_rooms(&self, user_id: Box) -> Result< .rooms .state_cache .rooms_joined(&user_id) - .filter_map(Result::ok) - .map(|room_id| get_room_info(self.services, &room_id)) - .collect(); + .then(|room_id| get_room_info(self.services, room_id)) + .collect() + .await; if rooms.is_empty() { return Ok(RoomMessageEventContent::text_plain("User is not in any rooms.")); diff --git a/src/admin/media/commands.rs b/src/admin/media/commands.rs index 3c4bf2ef8..82ac162eb 100644 --- a/src/admin/media/commands.rs +++ b/src/admin/media/commands.rs @@ -36,7 +36,7 @@ pub(super) async fn delete( let mut mxc_urls = Vec::with_capacity(4); // parsing the PDU for any MXC URLs begins here - if let Some(event_json) = self.services.rooms.timeline.get_pdu_json(&event_id)? { + if let Ok(event_json) = self.services.rooms.timeline.get_pdu_json(&event_id).await { if let Some(content_key) = event_json.get("content") { debug!("Event ID has \"content\"."); let content_obj = content_key.as_object(); @@ -300,7 +300,7 @@ pub(super) async fn delete_all_from_server( #[admin_command] pub(super) async fn get_file_info(&self, mxc: OwnedMxcUri) -> Result { let mxc: Mxc<'_> = mxc.as_str().try_into()?; - let metadata = self.services.media.get_metadata(&mxc); + let metadata = self.services.media.get_metadata(&mxc).await; Ok(RoomMessageEventContent::notice_markdown(format!("```\n{metadata:#?}\n```"))) } diff --git a/src/admin/processor.rs b/src/admin/processor.rs index 4f60f56e9..3c1895ffd 100644 --- a/src/admin/processor.rs +++ b/src/admin/processor.rs @@ -17,7 +17,7 @@ use conduit::{ utils::string::{collect_stream, common_prefix}, warn, Error, Result, }; -use futures_util::future::FutureExt; +use futures::future::FutureExt; use ruma::{ events::{ relation::InReplyTo, diff --git a/src/admin/query/account_data.rs 
b/src/admin/query/account_data.rs index e18c298a3..896bf95cf 100644 --- a/src/admin/query/account_data.rs +++ b/src/admin/query/account_data.rs @@ -44,7 +44,8 @@ pub(super) async fn process(subcommand: AccountDataCommand, context: &Command<'_ let timer = tokio::time::Instant::now(); let results = services .account_data - .changes_since(room_id.as_deref(), &user_id, since)?; + .changes_since(room_id.as_deref(), &user_id, since) + .await?; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -59,7 +60,8 @@ pub(super) async fn process(subcommand: AccountDataCommand, context: &Command<'_ let timer = tokio::time::Instant::now(); let results = services .account_data - .get(room_id.as_deref(), &user_id, kind)?; + .get(room_id.as_deref(), &user_id, kind) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/appservice.rs b/src/admin/query/appservice.rs index 683c228f7..4b97ef4eb 100644 --- a/src/admin/query/appservice.rs +++ b/src/admin/query/appservice.rs @@ -29,7 +29,9 @@ pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_> let results = services .appservice .db - .get_registration(appservice_id.as_ref()); + .get_registration(appservice_id.as_ref()) + .await; + let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -38,7 +40,7 @@ pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_> }, AppserviceCommand::All => { let timer = tokio::time::Instant::now(); - let results = services.appservice.all(); + let results = services.appservice.all().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/globals.rs b/src/admin/query/globals.rs index 5f271c2c4..150a213cd 100644 --- a/src/admin/query/globals.rs +++ b/src/admin/query/globals.rs @@ -29,7 +29,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: 
&Command<'_>) - match subcommand { GlobalsCommand::DatabaseVersion => { let timer = tokio::time::Instant::now(); - let results = services.globals.db.database_version(); + let results = services.globals.db.database_version().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -47,7 +47,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) - }, GlobalsCommand::LastCheckForUpdatesId => { let timer = tokio::time::Instant::now(); - let results = services.updates.last_check_for_updates_id(); + let results = services.updates.last_check_for_updates_id().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -67,7 +67,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) - origin, } => { let timer = tokio::time::Instant::now(); - let results = services.globals.db.verify_keys_for(&origin); + let results = services.globals.db.verify_keys_for(&origin).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/presence.rs b/src/admin/query/presence.rs index 145ecd9b1..6189270cc 100644 --- a/src/admin/query/presence.rs +++ b/src/admin/query/presence.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, UserId}; use crate::Command; @@ -30,7 +31,7 @@ pub(super) async fn process(subcommand: PresenceCommand, context: &Command<'_>) user_id, } => { let timer = tokio::time::Instant::now(); - let results = services.presence.db.get_presence(&user_id)?; + let results = services.presence.db.get_presence(&user_id).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -42,7 +43,7 @@ pub(super) async fn process(subcommand: PresenceCommand, context: &Command<'_>) } => { let timer = tokio::time::Instant::now(); let results = 
services.presence.db.presence_since(since); - let presence_since: Vec<(_, _, _)> = results.collect(); + let presence_since: Vec<(_, _, _)> = results.collect().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/pusher.rs b/src/admin/query/pusher.rs index 637c57b65..a1bd32f99 100644 --- a/src/admin/query/pusher.rs +++ b/src/admin/query/pusher.rs @@ -21,7 +21,7 @@ pub(super) async fn process(subcommand: PusherCommand, context: &Command<'_>) -> user_id, } => { let timer = tokio::time::Instant::now(); - let results = services.pusher.get_pushers(&user_id)?; + let results = services.pusher.get_pushers(&user_id).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/room_alias.rs b/src/admin/query/room_alias.rs index 1809e26a0..05fac42cc 100644 --- a/src/admin/query/room_alias.rs +++ b/src/admin/query/room_alias.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId, RoomId}; use crate::Command; @@ -31,7 +32,7 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>) alias, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.alias.resolve_local_alias(&alias); + let results = services.rooms.alias.resolve_local_alias(&alias).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -43,7 +44,7 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>) } => { let timer = tokio::time::Instant::now(); let results = services.rooms.alias.local_aliases_for_room(&room_id); - let aliases: Vec<_> = results.collect(); + let aliases: Vec<_> = results.collect().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -52,8 +53,12 @@ pub(super) async fn process(subcommand: RoomAliasCommand, 
context: &Command<'_>) }, RoomAliasCommand::AllLocalAliases => { let timer = tokio::time::Instant::now(); - let results = services.rooms.alias.all_local_aliases(); - let aliases: Vec<_> = results.collect(); + let aliases = services + .rooms + .alias + .all_local_aliases() + .collect::>() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/room_state_cache.rs b/src/admin/query/room_state_cache.rs index 4215cf8d6..e32517fb1 100644 --- a/src/admin/query/room_state_cache.rs +++ b/src/admin/query/room_state_cache.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, RoomId, ServerName, UserId}; use crate::Command; @@ -86,7 +87,11 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let result = services.rooms.state_cache.server_in_room(&server, &room_id); + let result = services + .rooms + .state_cache + .server_in_room(&server, &room_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -97,7 +102,13 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.room_servers(&room_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .room_servers(&room_id) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -108,7 +119,13 @@ pub(super) async fn process( server, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.server_rooms(&server).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .server_rooms(&server) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -119,7 +136,13 @@ pub(super) 
async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.room_members(&room_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .room_members(&room_id) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -134,7 +157,9 @@ pub(super) async fn process( .rooms .state_cache .local_users_in_room(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -149,7 +174,9 @@ pub(super) async fn process( .rooms .state_cache .active_local_users_in_room(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -160,7 +187,7 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.state_cache.room_joined_count(&room_id); + let results = services.rooms.state_cache.room_joined_count(&room_id).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -171,7 +198,11 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.state_cache.room_invited_count(&room_id); + let results = services + .rooms + .state_cache + .room_invited_count(&room_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -182,11 +213,13 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services + let results: Vec<_> = services .rooms .state_cache .room_useroncejoined(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -197,11 +230,13 @@ pub(super) 
async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services + let results: Vec<_> = services .rooms .state_cache .room_members_invited(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -216,7 +251,8 @@ pub(super) async fn process( let results = services .rooms .state_cache - .get_invite_count(&room_id, &user_id); + .get_invite_count(&room_id, &user_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -231,7 +267,8 @@ pub(super) async fn process( let results = services .rooms .state_cache - .get_left_count(&room_id, &user_id); + .get_left_count(&room_id, &user_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -242,7 +279,13 @@ pub(super) async fn process( user_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.rooms_joined(&user_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .rooms_joined(&user_id) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -253,7 +296,12 @@ pub(super) async fn process( user_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.rooms_invited(&user_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .rooms_invited(&user_id) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -264,7 +312,12 @@ pub(super) async fn process( user_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.rooms_left(&user_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .rooms_left(&user_id) + .collect() + .await; let query_time = 
timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -276,7 +329,11 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.state_cache.invite_state(&user_id, &room_id); + let results = services + .rooms + .state_cache + .invite_state(&user_id, &room_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/sending.rs b/src/admin/query/sending.rs index 6d54bddfd..eaab1f5ee 100644 --- a/src/admin/query/sending.rs +++ b/src/admin/query/sending.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, ServerName, UserId}; use service::sending::Destination; @@ -68,7 +69,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - SendingCommand::ActiveRequests => { let timer = tokio::time::Instant::now(); let results = services.sending.db.active_requests(); - let active_requests: Result> = results.collect(); + let active_requests = results.collect::>().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -133,7 +134,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - }, }; - let queued_requests = results.collect::>>(); + let queued_requests = results.collect::>().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -199,7 +200,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - }, }; - let active_requests = results.collect::>>(); + let active_requests = results.collect::>().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -210,7 +211,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - server_name, } => { let timer = tokio::time::Instant::now(); - let results = 
services.sending.db.get_latest_educount(&server_name); + let results = services.sending.db.get_latest_educount(&server_name).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/users.rs b/src/admin/query/users.rs index fee12fbfc..0792e4840 100644 --- a/src/admin/query/users.rs +++ b/src/admin/query/users.rs @@ -1,29 +1,344 @@ use clap::Subcommand; use conduit::Result; -use ruma::events::room::message::RoomMessageEventContent; +use futures::stream::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, OwnedDeviceId, OwnedRoomId, OwnedUserId}; -use crate::Command; +use crate::{admin_command, admin_command_dispatch}; +#[admin_command_dispatch] #[derive(Debug, Subcommand)] /// All the getters and iterators from src/database/key_value/users.rs pub(crate) enum UsersCommand { - Iter, + CountUsers, + + IterUsers, + + PasswordHash { + user_id: OwnedUserId, + }, + + ListDevices { + user_id: OwnedUserId, + }, + + ListDevicesMetadata { + user_id: OwnedUserId, + }, + + GetDeviceMetadata { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetDevicesVersion { + user_id: OwnedUserId, + }, + + CountOneTimeKeys { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetDeviceKeys { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetUserSigningKey { + user_id: OwnedUserId, + }, + + GetMasterKey { + user_id: OwnedUserId, + }, + + GetToDeviceEvents { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetLatestBackup { + user_id: OwnedUserId, + }, + + GetLatestBackupVersion { + user_id: OwnedUserId, + }, + + GetBackupAlgorithm { + user_id: OwnedUserId, + version: String, + }, + + GetAllBackups { + user_id: OwnedUserId, + version: String, + }, + + GetRoomBackups { + user_id: OwnedUserId, + version: String, + room_id: OwnedRoomId, + }, + + GetBackupSession { + user_id: OwnedUserId, + version: String, + room_id: OwnedRoomId, + session_id: String, + }, +} + 
+#[admin_command] +async fn get_backup_session( + &self, user_id: OwnedUserId, version: String, room_id: OwnedRoomId, session_id: String, +) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_session(&user_id, &version, &room_id, &session_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_room_backups( + &self, user_id: OwnedUserId, version: String, room_id: OwnedRoomId, +) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_room(&user_id, &version, &room_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_all_backups(&self, user_id: OwnedUserId, version: String) -> Result { + let timer = tokio::time::Instant::now(); + let result = self.services.key_backups.get_all(&user_id, &version).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_backup_algorithm(&self, user_id: OwnedUserId, version: String) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_backup(&user_id, &version) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_latest_backup_version(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_latest_backup_version(&user_id) + .await; + let query_time = timer.elapsed(); + + 
Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_latest_backup(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self.services.key_backups.get_latest_backup(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) } -/// All the getters and iterators in key_value/users.rs -pub(super) async fn process(subcommand: UsersCommand, context: &Command<'_>) -> Result { - let services = context.services; +#[admin_command] +async fn iter_users(&self) -> Result { + let timer = tokio::time::Instant::now(); + let result: Vec = self.services.users.stream().map(Into::into).collect().await; + + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn count_users(&self) -> Result { + let timer = tokio::time::Instant::now(); + let result = self.services.users.count().await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn password_hash(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self.services.users.password_hash(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn list_devices(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let devices = self + .services + .users + .all_device_ids(&user_id) + .map(ToOwned::to_owned) + .collect::>() + .await; + + let query_time = 
timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{devices:#?}\n```" + ))) +} + +#[admin_command] +async fn list_devices_metadata(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let devices = self + .services + .users + .all_devices_metadata(&user_id) + .collect::>() + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{devices:#?}\n```" + ))) +} + +#[admin_command] +async fn get_device_metadata(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result { + let timer = tokio::time::Instant::now(); + let device = self + .services + .users + .get_device_metadata(&user_id, &device_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{device:#?}\n```" + ))) +} + +#[admin_command] +async fn get_devices_version(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let device = self.services.users.get_devicelist_version(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{device:#?}\n```" + ))) +} + +#[admin_command] +async fn count_one_time_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + .count_one_time_keys(&user_id, &device_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_device_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + 
.get_device_keys(&user_id, &device_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_user_signing_key(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self.services.users.get_user_signing_key(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_master_key(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + .get_master_key(None, &user_id, &|_| true) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} - match subcommand { - UsersCommand::Iter => { - let timer = tokio::time::Instant::now(); - let results = services.users.db.iter(); - let users = results.collect::>(); - let query_time = timer.elapsed(); +#[admin_command] +async fn get_to_device_events( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, +) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + .get_to_device_events(&user_id, &device_id) + .collect::>() + .await; + let query_time = timer.elapsed(); - Ok(RoomMessageEventContent::notice_markdown(format!( - "Query completed in {query_time:?}:\n\n```rs\n{users:#?}\n```" - ))) - }, - } + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) } diff --git a/src/admin/room/alias.rs b/src/admin/room/alias.rs index 415b8a083..34b6c42ec 100644 --- a/src/admin/room/alias.rs +++ b/src/admin/room/alias.rs @@ -2,7 +2,8 @@ use std::fmt::Write; use clap::Subcommand; use 
conduit::Result; -use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId, RoomId}; +use futures::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; use crate::{escape_html, Command}; @@ -66,8 +67,8 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> force, room_id, .. - } => match (force, services.rooms.alias.resolve_local_alias(&room_alias)) { - (true, Ok(Some(id))) => match services + } => match (force, services.rooms.alias.resolve_local_alias(&room_alias).await) { + (true, Ok(id)) => match services .rooms .alias .set_alias(&room_alias, &room_id, server_user) @@ -77,10 +78,10 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> ))), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))), }, - (false, Ok(Some(id))) => Ok(RoomMessageEventContent::text_plain(format!( + (false, Ok(id)) => Ok(RoomMessageEventContent::text_plain(format!( "Refusing to overwrite in use alias for {id}, use -f or --force to overwrite" ))), - (_, Ok(None)) => match services + (_, Err(_)) => match services .rooms .alias .set_alias(&room_alias, &room_id, server_user) @@ -88,12 +89,11 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> Ok(()) => Ok(RoomMessageEventContent::text_plain("Successfully set alias")), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))), }, - (_, Err(err)) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))), }, RoomAliasCommand::Remove { .. 
- } => match services.rooms.alias.resolve_local_alias(&room_alias) { - Ok(Some(id)) => match services + } => match services.rooms.alias.resolve_local_alias(&room_alias).await { + Ok(id) => match services .rooms .alias .remove_alias(&room_alias, server_user) @@ -102,15 +102,13 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> Ok(()) => Ok(RoomMessageEventContent::text_plain(format!("Removed alias from {id}"))), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))), }, - Ok(None) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))), + Err(_) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), }, RoomAliasCommand::Which { .. - } => match services.rooms.alias.resolve_local_alias(&room_alias) { - Ok(Some(id)) => Ok(RoomMessageEventContent::text_plain(format!("Alias resolves to {id}"))), - Ok(None) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))), + } => match services.rooms.alias.resolve_local_alias(&room_alias).await { + Ok(id) => Ok(RoomMessageEventContent::text_plain(format!("Alias resolves to {id}"))), + Err(_) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), }, RoomAliasCommand::List { .. @@ -125,63 +123,59 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> .rooms .alias .local_aliases_for_room(&room_id) - .collect::, _>>(); - match aliases { - Ok(aliases) => { - let plain_list = aliases.iter().fold(String::new(), |mut output, alias| { - writeln!(output, "- {alias}").expect("should be able to write to string buffer"); - output - }); - - let html_list = aliases.iter().fold(String::new(), |mut output, alias| { - writeln!(output, "
  • {}
  • ", escape_html(alias.as_ref())) - .expect("should be able to write to string buffer"); - output - }); - - let plain = format!("Aliases for {room_id}:\n{plain_list}"); - let html = format!("Aliases for {room_id}:\n
      {html_list}
    "); - Ok(RoomMessageEventContent::text_html(plain, html)) - }, - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list aliases: {err}"))), - } + .map(Into::into) + .collect::>() + .await; + + let plain_list = aliases.iter().fold(String::new(), |mut output, alias| { + writeln!(output, "- {alias}").expect("should be able to write to string buffer"); + output + }); + + let html_list = aliases.iter().fold(String::new(), |mut output, alias| { + writeln!(output, "
  • {}
  • ", escape_html(alias.as_ref())) + .expect("should be able to write to string buffer"); + output + }); + + let plain = format!("Aliases for {room_id}:\n{plain_list}"); + let html = format!("Aliases for {room_id}:\n
      {html_list}
    "); + Ok(RoomMessageEventContent::text_html(plain, html)) } else { let aliases = services .rooms .alias .all_local_aliases() - .collect::, _>>(); - match aliases { - Ok(aliases) => { - let server_name = services.globals.server_name(); - let plain_list = aliases - .iter() - .fold(String::new(), |mut output, (alias, id)| { - writeln!(output, "- `{alias}` -> #{id}:{server_name}") - .expect("should be able to write to string buffer"); - output - }); - - let html_list = aliases - .iter() - .fold(String::new(), |mut output, (alias, id)| { - writeln!( - output, - "
  • {} -> #{}:{}
  • ", - escape_html(alias.as_ref()), - escape_html(id.as_ref()), - server_name - ) - .expect("should be able to write to string buffer"); - output - }); - - let plain = format!("Aliases:\n{plain_list}"); - let html = format!("Aliases:\n
      {html_list}
    "); - Ok(RoomMessageEventContent::text_html(plain, html)) - }, - Err(e) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list room aliases: {e}"))), - } + .map(|(room_id, localpart)| (room_id.into(), localpart.into())) + .collect::>() + .await; + + let server_name = services.globals.server_name(); + let plain_list = aliases + .iter() + .fold(String::new(), |mut output, (alias, id)| { + writeln!(output, "- `{alias}` -> #{id}:{server_name}") + .expect("should be able to write to string buffer"); + output + }); + + let html_list = aliases + .iter() + .fold(String::new(), |mut output, (alias, id)| { + writeln!( + output, + "
  • {} -> #{}:{}
  • ", + escape_html(alias.as_ref()), + escape_html(id), + server_name + ) + .expect("should be able to write to string buffer"); + output + }); + + let plain = format!("Aliases:\n{plain_list}"); + let html = format!("Aliases:\n
      {html_list}
    "); + Ok(RoomMessageEventContent::text_html(plain, html)) } }, } diff --git a/src/admin/room/commands.rs b/src/admin/room/commands.rs index 2adfa7d73..1c90a9983 100644 --- a/src/admin/room/commands.rs +++ b/src/admin/room/commands.rs @@ -1,11 +1,12 @@ use conduit::Result; -use ruma::events::room::message::RoomMessageEventContent; +use futures::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId}; use crate::{admin_command, get_room_info, PAGE_SIZE}; #[admin_command] pub(super) async fn list_rooms( - &self, page: Option, exclude_disabled: bool, exclude_banned: bool, no_details: bool, + &self, page: Option, _exclude_disabled: bool, _exclude_banned: bool, no_details: bool, ) -> Result { // TODO: i know there's a way to do this with clap, but i can't seem to find it let page = page.unwrap_or(1); @@ -14,37 +15,12 @@ pub(super) async fn list_rooms( .rooms .metadata .iter_ids() - .filter_map(|room_id| { - room_id - .ok() - .filter(|room_id| { - if exclude_disabled - && self - .services - .rooms - .metadata - .is_disabled(room_id) - .unwrap_or(false) - { - return false; - } + //.filter(|room_id| async { !exclude_disabled || !self.services.rooms.metadata.is_disabled(room_id).await }) + //.filter(|room_id| async { !exclude_banned || !self.services.rooms.metadata.is_banned(room_id).await }) + .then(|room_id| get_room_info(self.services, room_id)) + .collect::>() + .await; - if exclude_banned - && self - .services - .rooms - .metadata - .is_banned(room_id) - .unwrap_or(false) - { - return false; - } - - true - }) - .map(|room_id| get_room_info(self.services, &room_id)) - }) - .collect::>(); rooms.sort_by_key(|r| r.1); rooms.reverse(); @@ -74,3 +50,10 @@ pub(super) async fn list_rooms( Ok(RoomMessageEventContent::notice_markdown(output_plain)) } + +#[admin_command] +pub(super) async fn exists(&self, room_id: OwnedRoomId) -> Result { + let result = self.services.rooms.metadata.exists(&room_id).await; + + 
Ok(RoomMessageEventContent::notice_markdown(format!("{result}"))) +} diff --git a/src/admin/room/directory.rs b/src/admin/room/directory.rs index 7bba2eb7b..7ccdea6f0 100644 --- a/src/admin/room/directory.rs +++ b/src/admin/room/directory.rs @@ -2,7 +2,8 @@ use std::fmt::Write; use clap::Subcommand; use conduit::Result; -use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId}; +use futures::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, RoomId}; use crate::{escape_html, get_room_info, Command, PAGE_SIZE}; @@ -31,15 +32,15 @@ pub(super) async fn process(command: RoomDirectoryCommand, context: &Command<'_> match command { RoomDirectoryCommand::Publish { room_id, - } => match services.rooms.directory.set_public(&room_id) { - Ok(()) => Ok(RoomMessageEventContent::text_plain("Room published")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))), + } => { + services.rooms.directory.set_public(&room_id); + Ok(RoomMessageEventContent::notice_plain("Room published")) }, RoomDirectoryCommand::Unpublish { room_id, - } => match services.rooms.directory.set_not_public(&room_id) { - Ok(()) => Ok(RoomMessageEventContent::text_plain("Room unpublished")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))), + } => { + services.rooms.directory.set_not_public(&room_id); + Ok(RoomMessageEventContent::notice_plain("Room unpublished")) }, RoomDirectoryCommand::List { page, @@ -50,9 +51,10 @@ pub(super) async fn process(command: RoomDirectoryCommand, context: &Command<'_> .rooms .directory .public_rooms() - .filter_map(Result::ok) - .map(|id: OwnedRoomId| get_room_info(services, &id)) - .collect::>(); + .then(|room_id| get_room_info(services, room_id)) + .collect::>() + .await; + rooms.sort_by_key(|r| r.1); rooms.reverse(); diff --git a/src/admin/room/info.rs b/src/admin/room/info.rs index d17a29247..fc0619e33 100644 --- a/src/admin/room/info.rs +++ 
b/src/admin/room/info.rs @@ -1,5 +1,6 @@ use clap::Subcommand; -use conduit::Result; +use conduit::{utils::ReadyExt, Result}; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, RoomId}; use crate::{admin_command, admin_command_dispatch}; @@ -32,46 +33,42 @@ async fn list_joined_members(&self, room_id: Box, local_only: bool) -> R .rooms .state_accessor .get_name(&room_id) - .ok() - .flatten() - .unwrap_or_else(|| room_id.to_string()); + .await + .unwrap_or_else(|_| room_id.to_string()); - let members = self + let member_info: Vec<_> = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|member| { + .ready_filter(|user_id| { if local_only { - member - .ok() - .filter(|user| self.services.globals.user_is_local(user)) + self.services.globals.user_is_local(user_id) } else { - member.ok() + true } - }); - - let member_info = members - .into_iter() - .map(|user_id| { - ( - user_id.clone(), + }) + .filter_map(|user_id| async move { + let user_id = user_id.to_owned(); + Some(( self.services .users .displayname(&user_id) - .unwrap_or(None) - .unwrap_or_else(|| user_id.to_string()), - ) + .await + .unwrap_or_else(|_| user_id.to_string()), + user_id, + )) }) - .collect::>(); + .collect() + .await; let output_plain = format!( "{} Members in Room \"{}\":\n```\n{}\n```", member_info.len(), room_name, member_info - .iter() - .map(|(mxid, displayname)| format!("{mxid} | {displayname}")) + .into_iter() + .map(|(displayname, mxid)| format!("{mxid} | {displayname}")) .collect::>() .join("\n") ); @@ -81,11 +78,12 @@ async fn list_joined_members(&self, room_id: Box, local_only: bool) -> R #[admin_command] async fn view_room_topic(&self, room_id: Box) -> Result { - let Some(room_topic) = self + let Ok(room_topic) = self .services .rooms .state_accessor - .get_room_topic(&room_id)? 
+ .get_room_topic(&room_id) + .await else { return Ok(RoomMessageEventContent::text_plain("Room does not have a room topic set.")); }; diff --git a/src/admin/room/mod.rs b/src/admin/room/mod.rs index 64d2af452..8c6cbeaae 100644 --- a/src/admin/room/mod.rs +++ b/src/admin/room/mod.rs @@ -6,6 +6,7 @@ mod moderation; use clap::Subcommand; use conduit::Result; +use ruma::OwnedRoomId; use self::{ alias::RoomAliasCommand, directory::RoomDirectoryCommand, info::RoomInfoCommand, moderation::RoomModerationCommand, @@ -49,4 +50,9 @@ pub(super) enum RoomCommand { #[command(subcommand)] /// - Manage the room directory Directory(RoomDirectoryCommand), + + /// - Check if we know about a room + Exists { + room_id: OwnedRoomId, + }, } diff --git a/src/admin/room/moderation.rs b/src/admin/room/moderation.rs index 70d8486b4..9a772da48 100644 --- a/src/admin/room/moderation.rs +++ b/src/admin/room/moderation.rs @@ -1,6 +1,11 @@ use api::client::leave_room; use clap::Subcommand; -use conduit::{debug, error, info, warn, Result}; +use conduit::{ + debug, error, info, + utils::{IterStream, ReadyExt}, + warn, Result, +}; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomAliasId, RoomId, RoomOrAliasId}; use crate::{admin_command, admin_command_dispatch, get_room_info}; @@ -76,7 +81,7 @@ async fn ban_room( let admin_room_alias = &self.services.globals.admin_alias; - if let Some(admin_room_id) = self.services.admin.get_admin_room()? 
{ + if let Ok(admin_room_id) = self.services.admin.get_admin_room().await { if room.to_string().eq(&admin_room_id) || room.to_string().eq(admin_room_alias) { return Ok(RoomMessageEventContent::text_plain("Not allowed to ban the admin room.")); } @@ -95,7 +100,7 @@ async fn ban_room( debug!("Room specified is a room ID, banning room ID"); - self.services.rooms.metadata.ban_room(&room_id, true)?; + self.services.rooms.metadata.ban_room(&room_id, true); room_id } else if room.is_room_alias_id() { @@ -114,7 +119,13 @@ async fn ban_room( get_alias_helper to fetch room ID remotely" ); - let room_id = if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? { + let room_id = if let Ok(room_id) = self + .services + .rooms + .alias + .resolve_local_alias(&room_alias) + .await + { room_id } else { debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); @@ -138,7 +149,7 @@ async fn ban_room( } }; - self.services.rooms.metadata.ban_room(&room_id, true)?; + self.services.rooms.metadata.ban_room(&room_id, true); room_id } else { @@ -150,56 +161,40 @@ async fn ban_room( debug!("Making all users leave the room {}", &room); if force { - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - self.services.globals.user_is_local(local_user) - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would - // fail auth check) - && (self.services.globals.user_is_local(local_user) - // since this is a force operation, assume user is an admin - // if somehow this fails - && self.services - .users - .is_admin(local_user) - .unwrap_or(true)) - }) - }) { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + while let Some(local_user) = users.next().await { debug!( - "Attempting leave for user {} in room {} (forced, 
ignoring all errors, evicting admins too)", - &local_user, &room_id + "Attempting leave for user {local_user} in room {room_id} (forced, ignoring all errors, evicting \ + admins too)", ); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { warn!(%e, "Failed to leave room"); } } } else { - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == self.services.globals.server_name() - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would fail auth check) - && (local_user.server_name() - == self.services.globals.server_name() - && !self.services - .users - .is_admin(local_user) - .unwrap_or(false)) - }) - }) { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + while let Some(local_user) = users.next().await { + if self.services.users.is_admin(local_user).await { + continue; + } + debug!("Attempting leave for user {} in room {}", &local_user, &room_id); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { error!( "Error attempting to make local user {} leave room {} during room banning: {}", &local_user, &room_id, e @@ -214,12 +209,14 @@ async fn ban_room( } // remove any local aliases, ignore errors - for ref local_alias in self + for local_alias in &self .services .rooms .alias .local_aliases_for_room(&room_id) - .filter_map(Result::ok) + .map(ToOwned::to_owned) + .collect::>() + .await { _ = self .services @@ -230,10 +227,10 @@ async fn ban_room( } // unpublish from room directory, ignore errors - _ = self.services.rooms.directory.set_not_public(&room_id); + self.services.rooms.directory.set_not_public(&room_id); if 
disable_federation { - self.services.rooms.metadata.disable_room(&room_id, true)?; + self.services.rooms.metadata.disable_room(&room_id, true); return Ok(RoomMessageEventContent::text_plain( "Room banned, removed all our local users, and disabled incoming federation with room.", )); @@ -268,7 +265,7 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu for &room in &rooms_s { match <&RoomOrAliasId>::try_from(room) { Ok(room_alias_or_id) => { - if let Some(admin_room_id) = self.services.admin.get_admin_room()? { + if let Ok(admin_room_id) = self.services.admin.get_admin_room().await { if room.to_owned().eq(&admin_room_id) || room.to_owned().eq(admin_room_alias) { info!("User specified admin room in bulk ban list, ignoring"); continue; @@ -300,43 +297,48 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu if room_alias_or_id.is_room_alias_id() { match RoomAliasId::parse(room_alias_or_id) { Ok(room_alias) => { - let room_id = - if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? 
{ - room_id - } else { - debug!( - "We don't have this room alias to a room ID locally, attempting to fetch room \ - ID over federation" - ); - - match self - .services - .rooms - .alias - .resolve_alias(&room_alias, None) - .await - { - Ok((room_id, servers)) => { - debug!( - ?room_id, - ?servers, - "Got federation response fetching room ID for {room}", - ); - room_id - }, - Err(e) => { - // don't fail if force blocking - if force { - warn!("Failed to resolve room alias {room} to a room ID: {e}"); - continue; - } - - return Ok(RoomMessageEventContent::text_plain(format!( - "Failed to resolve room alias {room} to a room ID: {e}" - ))); - }, - } - }; + let room_id = if let Ok(room_id) = self + .services + .rooms + .alias + .resolve_local_alias(&room_alias) + .await + { + room_id + } else { + debug!( + "We don't have this room alias to a room ID locally, attempting to fetch room ID \ + over federation" + ); + + match self + .services + .rooms + .alias + .resolve_alias(&room_alias, None) + .await + { + Ok((room_id, servers)) => { + debug!( + ?room_id, + ?servers, + "Got federation response fetching room ID for {room}", + ); + room_id + }, + Err(e) => { + // don't fail if force blocking + if force { + warn!("Failed to resolve room alias {room} to a room ID: {e}"); + continue; + } + + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to resolve room alias {room} to a room ID: {e}" + ))); + }, + } + }; room_ids.push(room_id); }, @@ -374,74 +376,52 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu } for room_id in room_ids { - if self - .services - .rooms - .metadata - .ban_room(&room_id, true) - .is_ok() - { - debug!("Banned {room_id} successfully"); - room_ban_count = room_ban_count.saturating_add(1); - } + self.services.rooms.metadata.ban_room(&room_id, true); + + debug!("Banned {room_id} successfully"); + room_ban_count = room_ban_count.saturating_add(1); debug!("Making all users leave the room {}", &room_id); if force 
{ - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == self.services.globals.server_name() - // additional wrapped check here is to avoid adding remote - // users who are in the admin room to the list of local - // users (would fail auth check) - && (local_user.server_name() - == self.services.globals.server_name() - // since this is a force operation, assume user is an - // admin if somehow this fails - && self.services - .users - .is_admin(local_user) - .unwrap_or(true)) - }) - }) { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + while let Some(local_user) = users.next().await { debug!( - "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", - &local_user, room_id + "Attempting leave for user {local_user} in room {room_id} (forced, ignoring all errors, evicting \ + admins too)", ); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { warn!(%e, "Failed to leave room"); } } } else { - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == self.services.globals.server_name() - // additional wrapped check here is to avoid adding remote - // users who are in the admin room to the list of local - // users (would fail auth check) - && (local_user.server_name() - == self.services.globals.server_name() - && !self.services - .users - .is_admin(local_user) - .unwrap_or(false)) - }) - }) { - debug!("Attempting leave for user {} in room {}", &local_user, &room_id); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + while 
let Some(local_user) = users.next().await { + if self.services.users.is_admin(local_user).await { + continue; + } + + debug!("Attempting leave for user {local_user} in room {room_id}"); + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { error!( - "Error attempting to make local user {} leave room {} during bulk room banning: {}", - &local_user, &room_id, e + "Error attempting to make local user {local_user} leave room {room_id} during bulk room \ + banning: {e}", ); + return Ok(RoomMessageEventContent::text_plain(format!( "Error attempting to make local user {} leave room {} during room banning (room is still \ banned but not removing any more users and not banning any more rooms): {}\nIf you would \ @@ -453,26 +433,26 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu } // remove any local aliases, ignore errors - for ref local_alias in self - .services + self.services .rooms .alias .local_aliases_for_room(&room_id) - .filter_map(Result::ok) - { - _ = self - .services - .rooms - .alias - .remove_alias(local_alias, &self.services.globals.server_user) - .await; - } + .map(ToOwned::to_owned) + .for_each(|local_alias| async move { + self.services + .rooms + .alias + .remove_alias(&local_alias, &self.services.globals.server_user) + .await + .ok(); + }) + .await; // unpublish from room directory, ignore errors - _ = self.services.rooms.directory.set_not_public(&room_id); + self.services.rooms.directory.set_not_public(&room_id); if disable_federation { - self.services.rooms.metadata.disable_room(&room_id, true)?; + self.services.rooms.metadata.disable_room(&room_id, true); } } @@ -503,7 +483,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box) -> debug!("Room specified is a room ID, unbanning room ID"); - self.services.rooms.metadata.ban_room(&room_id, false)?; + self.services.rooms.metadata.ban_room(&room_id, false); room_id } else if room.is_room_alias_id() { @@ -522,7 +502,13 @@ async fn 
unban_room(&self, enable_federation: bool, room: Box) -> get_alias_helper to fetch room ID remotely" ); - let room_id = if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? { + let room_id = if let Ok(room_id) = self + .services + .rooms + .alias + .resolve_local_alias(&room_alias) + .await + { room_id } else { debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); @@ -546,7 +532,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box) -> } }; - self.services.rooms.metadata.ban_room(&room_id, false)?; + self.services.rooms.metadata.ban_room(&room_id, false); room_id } else { @@ -557,7 +543,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box) -> }; if enable_federation { - self.services.rooms.metadata.disable_room(&room_id, false)?; + self.services.rooms.metadata.disable_room(&room_id, false); return Ok(RoomMessageEventContent::text_plain("Room unbanned.")); } @@ -569,45 +555,42 @@ async fn unban_room(&self, enable_federation: bool, room: Box) -> #[admin_command] async fn list_banned_rooms(&self, no_details: bool) -> Result { - let rooms = self + let room_ids = self .services .rooms .metadata .list_banned_rooms() - .collect::, _>>(); + .map(Into::into) + .collect::>() + .await; - match rooms { - Ok(room_ids) => { - if room_ids.is_empty() { - return Ok(RoomMessageEventContent::text_plain("No rooms are banned.")); - } - - let mut rooms = room_ids - .into_iter() - .map(|room_id| get_room_info(self.services, &room_id)) - .collect::>(); - rooms.sort_by_key(|r| r.1); - rooms.reverse(); - - let output_plain = format!( - "Rooms Banned ({}):\n```\n{}\n```", - rooms.len(), - rooms - .iter() - .map(|(id, members, name)| if no_details { - format!("{id}") - } else { - format!("{id}\tMembers: {members}\tName: {name}") - }) - .collect::>() - .join("\n") - ); - - Ok(RoomMessageEventContent::notice_markdown(output_plain)) - }, - Err(e) => { - error!("Failed to list banned 
rooms: {e}"); - Ok(RoomMessageEventContent::text_plain(format!("Unable to list banned rooms: {e}"))) - }, + if room_ids.is_empty() { + return Ok(RoomMessageEventContent::text_plain("No rooms are banned.")); } + + let mut rooms = room_ids + .iter() + .stream() + .then(|room_id| get_room_info(self.services, room_id)) + .collect::>() + .await; + + rooms.sort_by_key(|r| r.1); + rooms.reverse(); + + let output_plain = format!( + "Rooms Banned ({}):\n```\n{}\n```", + rooms.len(), + rooms + .iter() + .map(|(id, members, name)| if no_details { + format!("{id}") + } else { + format!("{id}\tMembers: {members}\tName: {name}") + }) + .collect::>() + .join("\n") + ); + + Ok(RoomMessageEventContent::notice_markdown(output_plain)) } diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index 20691f1a2..1b086856a 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -1,7 +1,9 @@ use std::{collections::BTreeMap, fmt::Write as _}; use api::client::{full_user_deactivate, join_room_by_id_helper, leave_room}; -use conduit::{error, info, utils, warn, PduBuilder, Result}; +use conduit::{error, info, is_equal_to, utils, warn, PduBuilder, Result}; +use conduit_api::client::{leave_all_rooms, update_avatar_url, update_displayname}; +use futures::StreamExt; use ruma::{ events::{ room::{ @@ -25,16 +27,19 @@ const AUTO_GEN_PASSWORD_LENGTH: usize = 25; #[admin_command] pub(super) async fn list_users(&self) -> Result { - match self.services.users.list_local_users() { - Ok(users) => { - let mut plain_msg = format!("Found {} local user account(s):\n```\n", users.len()); - plain_msg += users.join("\n").as_str(); - plain_msg += "\n```"; + let users = self + .services + .users + .list_local_users() + .map(ToString::to_string) + .collect::>() + .await; - Ok(RoomMessageEventContent::notice_markdown(plain_msg)) - }, - Err(e) => Ok(RoomMessageEventContent::text_plain(e.to_string())), - } + let mut plain_msg = format!("Found {} local user account(s):\n```\n", 
users.len()); + plain_msg += users.join("\n").as_str(); + plain_msg += "\n```"; + + Ok(RoomMessageEventContent::notice_markdown(plain_msg)) } #[admin_command] @@ -42,7 +47,7 @@ pub(super) async fn create_user(&self, username: String, password: Option )); } - self.services.users.deactivate_account(&user_id)?; + self.services.users.deactivate_account(&user_id).await?; if !no_leave_rooms { self.services @@ -175,17 +184,22 @@ pub(super) async fn deactivate(&self, no_leave_rooms: bool, user_id: String) -> .send_message(RoomMessageEventContent::text_plain(format!( "Making {user_id} leave all rooms after deactivation..." ))) - .await; + .await + .ok(); let all_joined_rooms: Vec = self .services .rooms .state_cache .rooms_joined(&user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - full_user_deactivate(self.services, &user_id, all_joined_rooms).await?; + full_user_deactivate(self.services, &user_id, &all_joined_rooms).await?; + update_displayname(self.services, &user_id, None, &all_joined_rooms).await?; + update_avatar_url(self.services, &user_id, None, None, &all_joined_rooms).await?; + leave_all_rooms(self.services, &user_id).await; } Ok(RoomMessageEventContent::text_plain(format!( @@ -238,15 +252,16 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> let mut admins = Vec::new(); for username in usernames { - match parse_active_local_user_id(self.services, username) { + match parse_active_local_user_id(self.services, username).await { Ok(user_id) => { - if self.services.users.is_admin(&user_id)? 
&& !force { + if self.services.users.is_admin(&user_id).await && !force { self.services .admin .send_message(RoomMessageEventContent::text_plain(format!( "{username} is an admin and --force is not set, skipping over" ))) - .await; + .await + .ok(); admins.push(username); continue; } @@ -258,7 +273,8 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> .send_message(RoomMessageEventContent::text_plain(format!( "{username} is the server service account, skipping over" ))) - .await; + .await + .ok(); continue; } @@ -270,7 +286,8 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> .send_message(RoomMessageEventContent::text_plain(format!( "{username} is not a valid username, skipping over: {e}" ))) - .await; + .await + .ok(); continue; }, } @@ -279,7 +296,7 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> let mut deactivation_count: usize = 0; for user_id in user_ids { - match self.services.users.deactivate_account(&user_id) { + match self.services.users.deactivate_account(&user_id).await { Ok(()) => { deactivation_count = deactivation_count.saturating_add(1); if !no_leave_rooms { @@ -289,16 +306,26 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> .rooms .state_cache .rooms_joined(&user_id) - .filter_map(Result::ok) - .collect(); - full_user_deactivate(self.services, &user_id, all_joined_rooms).await?; + .map(Into::into) + .collect() + .await; + + full_user_deactivate(self.services, &user_id, &all_joined_rooms).await?; + update_displayname(self.services, &user_id, None, &all_joined_rooms) + .await + .ok(); + update_avatar_url(self.services, &user_id, None, None, &all_joined_rooms) + .await + .ok(); + leave_all_rooms(self.services, &user_id).await; } }, Err(e) => { self.services .admin .send_message(RoomMessageEventContent::text_plain(format!("Failed deactivating user: {e}"))) - .await; + .await + .ok(); }, } } @@ -326,9 +353,9 @@ pub(super) 
async fn list_joined_rooms(&self, user_id: String) -> Result(&room_id, &StateEventType::RoomPowerLevels, "") + .await + .ok(); let user_can_demote_self = room_power_levels .as_ref() @@ -417,9 +443,9 @@ pub(super) async fn force_demote( .services .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomCreate, "")? - .as_ref() - .is_some_and(|event| event.sender == user_id); + .room_state_get(&room_id, &StateEventType::RoomCreate, "") + .await + .is_ok_and(|event| event.sender == user_id); if !user_can_demote_self { return Ok(RoomMessageEventContent::notice_markdown( @@ -473,15 +499,16 @@ pub(super) async fn make_user_admin(&self, user_id: String) -> Result, tag: String, ) -> Result { - let user_id = parse_active_local_user_id(self.services, &user_id)?; + let user_id = parse_active_local_user_id(self.services, &user_id).await?; let event = self .services .account_data - .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; + .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag) + .await; let mut tags_event = event.map_or_else( - || TagEvent { + |_| TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, @@ -494,12 +521,15 @@ pub(super) async fn put_room_tag( .tags .insert(tag.clone().into(), TagInfo::new()); - self.services.account_data.update( - Some(&room_id), - &user_id, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + self.services + .account_data + .update( + Some(&room_id), + &user_id, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(RoomMessageEventContent::text_plain(format!( "Successfully updated room account data for {user_id} and room {room_id} with tag {tag}" @@ -510,15 +540,16 @@ pub(super) async fn put_room_tag( pub(super) async fn delete_room_tag( &self, user_id: String, room_id: Box, tag: String, ) -> Result { - let user_id = parse_active_local_user_id(self.services, 
&user_id)?; + let user_id = parse_active_local_user_id(self.services, &user_id).await?; let event = self .services .account_data - .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; + .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag) + .await; let mut tags_event = event.map_or_else( - || TagEvent { + |_| TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, @@ -528,12 +559,15 @@ pub(super) async fn delete_room_tag( tags_event.content.tags.remove(&tag.clone().into()); - self.services.account_data.update( - Some(&room_id), - &user_id, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + self.services + .account_data + .update( + Some(&room_id), + &user_id, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(RoomMessageEventContent::text_plain(format!( "Successfully updated room account data for {user_id} and room {room_id}, deleting room tag {tag}" @@ -542,15 +576,16 @@ pub(super) async fn delete_room_tag( #[admin_command] pub(super) async fn get_room_tags(&self, user_id: String, room_id: Box) -> Result { - let user_id = parse_active_local_user_id(self.services, &user_id)?; + let user_id = parse_active_local_user_id(self.services, &user_id).await?; let event = self .services .account_data - .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; + .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag) + .await; let tags_event = event.map_or_else( - || TagEvent { + |_| TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, @@ -566,11 +601,12 @@ pub(super) async fn get_room_tags(&self, user_id: String, room_id: Box) #[admin_command] pub(super) async fn redact_event(&self, event_id: Box) -> Result { - let Some(event) = self + let Ok(event) = self .services .rooms .timeline - .get_non_outlier_pdu(&event_id)? 
+ .get_non_outlier_pdu(&event_id) + .await else { return Ok(RoomMessageEventContent::text_plain("Event does not exist in our database.")); }; diff --git a/src/admin/utils.rs b/src/admin/utils.rs index 8d3d15ae4..ba98bbeac 100644 --- a/src/admin/utils.rs +++ b/src/admin/utils.rs @@ -8,23 +8,21 @@ pub(crate) fn escape_html(s: &str) -> String { .replace('>', ">") } -pub(crate) fn get_room_info(services: &Services, id: &RoomId) -> (OwnedRoomId, u64, String) { +pub(crate) async fn get_room_info(services: &Services, room_id: &RoomId) -> (OwnedRoomId, u64, String) { ( - id.into(), + room_id.into(), services .rooms .state_cache - .room_joined_count(id) - .ok() - .flatten() + .room_joined_count(room_id) + .await .unwrap_or(0), services .rooms .state_accessor - .get_name(id) - .ok() - .flatten() - .unwrap_or_else(|| id.to_string()), + .get_name(room_id) + .await + .unwrap_or_else(|_| room_id.to_string()), ) } @@ -46,14 +44,14 @@ pub(crate) fn parse_local_user_id(services: &Services, user_id: &str) -> Result< } /// Parses user ID that is an active (not guest or deactivated) local user -pub(crate) fn parse_active_local_user_id(services: &Services, user_id: &str) -> Result { +pub(crate) async fn parse_active_local_user_id(services: &Services, user_id: &str) -> Result { let user_id = parse_local_user_id(services, user_id)?; - if !services.users.exists(&user_id)? { + if !services.users.exists(&user_id).await { return Err!("User {user_id:?} does not exist on this server."); } - if services.users.is_deactivated(&user_id)? { + if services.users.is_deactivated(&user_id).await? 
{ return Err!("User {user_id:?} is deactivated."); } diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index 2b89c3e82..6e37cb407 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -45,7 +45,7 @@ conduit-core.workspace = true conduit-database.workspace = true conduit-service.workspace = true const-str.workspace = true -futures-util.workspace = true +futures.workspace = true hmac.workspace = true http.workspace = true http-body-util.workspace = true diff --git a/src/api/client/account.rs b/src/api/client/account.rs index cee86f80a..63d02f8f8 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -2,7 +2,8 @@ use std::fmt::Write; use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{debug_info, error, info, utils, warn, Error, PduBuilder, Result}; +use conduit::{debug_info, error, info, is_equal_to, utils, utils::ReadyExt, warn, Error, PduBuilder, Result}; +use futures::{FutureExt, StreamExt}; use register::RegistrationKind; use ruma::{ api::client::{ @@ -55,7 +56,7 @@ pub(crate) async fn get_register_available_route( .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; // Check if username is creative enough - if services.users.exists(&user_id)? { + if services.users.exists(&user_id).await { return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); } @@ -125,7 +126,7 @@ pub(crate) async fn register_route( // forbid guests from registering if there is not a real admin user yet. give // generic user error. - if is_guest && services.users.count()? < 2 { + if is_guest && services.users.count().await < 2 { warn!( "Guest account attempted to register before a real admin user has been registered, rejecting \ registration. 
Guest's initial device name: {:?}", @@ -142,7 +143,7 @@ pub(crate) async fn register_route( .filter(|user_id| !user_id.is_historical() && services.globals.user_is_local(user_id)) .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; - if services.users.exists(&proposed_user_id)? { + if services.users.exists(&proposed_user_id).await { return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); } @@ -162,7 +163,7 @@ pub(crate) async fn register_route( services.globals.server_name(), ) .unwrap(); - if !services.users.exists(&proposed_user_id)? { + if !services.users.exists(&proposed_user_id).await { break proposed_user_id; } }, @@ -210,12 +211,15 @@ pub(crate) async fn register_route( if !skip_auth { if let Some(auth) = &body.auth { - let (worked, uiaainfo) = services.uiaa.try_auth( - &UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"), - "".into(), - auth, - &uiaainfo, - )?; + let (worked, uiaainfo) = services + .uiaa + .try_auth( + &UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"), + "".into(), + auth, + &uiaainfo, + ) + .await?; if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -227,7 +231,7 @@ pub(crate) async fn register_route( "".into(), &uiaainfo, &json, - )?; + ); return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); @@ -255,21 +259,23 @@ pub(crate) async fn register_route( services .users - .set_displayname(&user_id, Some(displayname.clone())) - .await?; + .set_displayname(&user_id, Some(displayname.clone())); // Initial account data - services.account_data.update( - None, - &user_id, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { - content: ruma::events::push_rules::PushRulesEventContent { - global: push::Ruleset::server_default(&user_id), - }, - }) - .expect("to json 
always works"), - )?; + services + .account_data + .update( + None, + &user_id, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { + content: ruma::events::push_rules::PushRulesEventContent { + global: push::Ruleset::server_default(&user_id), + }, + }) + .expect("to json always works"), + ) + .await?; // Inhibit login does not work for guests if !is_guest && body.inhibit_login { @@ -294,13 +300,16 @@ pub(crate) async fn register_route( let token = utils::random_string(TOKEN_LENGTH); // Create device for this account - services.users.create_device( - &user_id, - &device_id, - &token, - body.initial_device_display_name.clone(), - Some(client.to_string()), - )?; + services + .users + .create_device( + &user_id, + &device_id, + &token, + body.initial_device_display_name.clone(), + Some(client.to_string()), + ) + .await?; debug_info!(%user_id, %device_id, "User account was created"); @@ -318,7 +327,8 @@ pub(crate) async fn register_route( "New user \"{user_id}\" registered on this server from IP {client} and device display name \ \"{device_display_name}\"" ))) - .await; + .await + .ok(); } } else { info!("New user \"{user_id}\" registered on this server."); @@ -329,7 +339,8 @@ pub(crate) async fn register_route( .send_message(RoomMessageEventContent::notice_plain(format!( "New user \"{user_id}\" registered on this server from IP {client}" ))) - .await; + .await + .ok(); } } } @@ -346,7 +357,8 @@ pub(crate) async fn register_route( "Guest user \"{user_id}\" with device display name \"{device_display_name}\" registered on \ this server from IP {client}" ))) - .await; + .await + .ok(); } } else { #[allow(clippy::collapsible_else_if)] @@ -357,7 +369,8 @@ pub(crate) async fn register_route( "Guest user \"{user_id}\" with no device display name registered on this server from IP \ {client}", ))) - .await; + .await + .ok(); } } } @@ -365,10 +378,15 @@ pub(crate) async fn register_route( // If this is 
the first real user, grant them admin privileges except for guest // users Note: the server user, @conduit:servername, is generated first if !is_guest { - if let Some(admin_room) = services.admin.get_admin_room()? { - if services.rooms.state_cache.room_joined_count(&admin_room)? == Some(1) { + if let Ok(admin_room) = services.admin.get_admin_room().await { + if services + .rooms + .state_cache + .room_joined_count(&admin_room) + .await + .is_ok_and(is_equal_to!(1)) + { services.admin.make_user_admin(&user_id).await?; - warn!("Granting {user_id} admin privileges as the first user"); } } @@ -382,7 +400,8 @@ pub(crate) async fn register_route( if !services .rooms .state_cache - .server_in_room(services.globals.server_name(), room)? + .server_in_room(services.globals.server_name(), room) + .await { warn!("Skipping room {room} to automatically join as we have never joined before."); continue; @@ -398,6 +417,7 @@ pub(crate) async fn register_route( None, &body.appservice_info, ) + .boxed() .await { // don't return this error so we don't fail registrations @@ -461,16 +481,20 @@ pub(crate) async fn change_password_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } - // Success! + + // Success! 
} else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); @@ -482,14 +506,12 @@ pub(crate) async fn change_password_route( if body.logout_devices { // Logout all devices except the current one - for id in services + services .users .all_device_ids(sender_user) - .filter_map(Result::ok) - .filter(|id| id != sender_device) - { - services.users.remove_device(sender_user, &id)?; - } + .ready_filter(|id| id != sender_device) + .for_each(|id| services.users.remove_device(sender_user, id)) + .await; } info!("User {sender_user} changed their password."); @@ -500,7 +522,8 @@ pub(crate) async fn change_password_route( .send_message(RoomMessageEventContent::notice_plain(format!( "User {sender_user} changed their password." ))) - .await; + .await + .ok(); } Ok(change_password::v3::Response {}) @@ -520,7 +543,7 @@ pub(crate) async fn whoami_route( Ok(whoami::v3::Response { user_id: sender_user.clone(), device_id, - is_guest: services.users.is_deactivated(sender_user)? && body.appservice_info.is_none(), + is_guest: services.users.is_deactivated(sender_user).await? 
&& body.appservice_info.is_none(), }) } @@ -561,7 +584,9 @@ pub(crate) async fn deactivate_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -570,7 +595,8 @@ pub(crate) async fn deactivate_route( uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); @@ -581,10 +607,14 @@ pub(crate) async fn deactivate_route( .rooms .state_cache .rooms_joined(sender_user) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - full_user_deactivate(&services, sender_user, all_joined_rooms).await?; + super::update_displayname(&services, sender_user, None, &all_joined_rooms).await?; + super::update_avatar_url(&services, sender_user, None, None, &all_joined_rooms).await?; + + full_user_deactivate(&services, sender_user, &all_joined_rooms).await?; info!("User {sender_user} deactivated their account."); @@ -594,7 +624,8 @@ pub(crate) async fn deactivate_route( .send_message(RoomMessageEventContent::notice_plain(format!( "User {sender_user} deactivated their account." 
))) - .await; + .await + .ok(); } Ok(deactivate::v3::Response { @@ -674,34 +705,27 @@ pub(crate) async fn check_registration_token_validity( /// - Removing all profile data /// - Leaving all rooms (and forgets all of them) pub async fn full_user_deactivate( - services: &Services, user_id: &UserId, all_joined_rooms: Vec, + services: &Services, user_id: &UserId, all_joined_rooms: &[OwnedRoomId], ) -> Result<()> { - services.users.deactivate_account(user_id)?; - - super::update_displayname(services, user_id, None, all_joined_rooms.clone()).await?; - super::update_avatar_url(services, user_id, None, None, all_joined_rooms.clone()).await?; + services.users.deactivate_account(user_id).await?; + super::update_displayname(services, user_id, None, all_joined_rooms).await?; + super::update_avatar_url(services, user_id, None, None, all_joined_rooms).await?; - let all_profile_keys = services + services .users .all_profile_keys(user_id) - .filter_map(Result::ok); - - for (profile_key, _profile_value) in all_profile_keys { - if let Err(e) = services.users.set_profile_key(user_id, &profile_key, None) { - warn!("Failed removing {user_id} profile key {profile_key}: {e}"); - } - } + .ready_for_each(|(profile_key, _)| services.users.set_profile_key(user_id, &profile_key, None)) + .await; for room_id in all_joined_rooms { - let state_lock = services.rooms.state.mutex.lock(&room_id).await; + let state_lock = services.rooms.state.mutex.lock(room_id).await; let room_power_levels = services .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")? - .as_ref() - .and_then(|event| serde_json::from_str(event.content.get()).ok()?) 
- .and_then(|content: RoomPowerLevelsEventContent| content.into()); + .room_state_get_content::(room_id, &StateEventType::RoomPowerLevels, "") + .await + .ok(); let user_can_demote_self = room_power_levels .as_ref() @@ -710,9 +734,9 @@ pub async fn full_user_deactivate( }) || services .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomCreate, "")? - .as_ref() - .is_some_and(|event| event.sender == user_id); + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await + .is_ok_and(|event| event.sender == user_id); if user_can_demote_self { let mut power_levels_content = room_power_levels.unwrap_or_default(); @@ -732,7 +756,7 @@ pub async fn full_user_deactivate( timestamp: None, }, user_id, - &room_id, + room_id, &state_lock, ) .await diff --git a/src/api/client/alias.rs b/src/api/client/alias.rs index 12d6352c9..2399a3551 100644 --- a/src/api/client/alias.rs +++ b/src/api/client/alias.rs @@ -1,11 +1,9 @@ use axum::extract::State; -use conduit::{debug, Error, Result}; +use conduit::{debug, Err, Result}; +use futures::StreamExt; use rand::seq::SliceRandom; use ruma::{ - api::client::{ - alias::{create_alias, delete_alias, get_alias}, - error::ErrorKind, - }, + api::client::alias::{create_alias, delete_alias, get_alias}, OwnedServerName, RoomAliasId, RoomId, }; use service::Services; @@ -33,16 +31,17 @@ pub(crate) async fn create_alias_route( .forbidden_alias_names() .is_match(body.room_alias.alias()) { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Room alias is forbidden.")); + return Err!(Request(Forbidden("Room alias is forbidden."))); } if services .rooms .alias - .resolve_local_alias(&body.room_alias)? 
- .is_some() + .resolve_local_alias(&body.room_alias) + .await + .is_ok() { - return Err(Error::Conflict("Alias already exists.")); + return Err!(Conflict("Alias already exists.")); } services @@ -95,16 +94,16 @@ pub(crate) async fn get_alias_route( .resolve_alias(&room_alias, servers.as_ref()) .await else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found.")); + return Err!(Request(NotFound("Room with alias not found."))); }; - let servers = room_available_servers(&services, &room_id, &room_alias, &pre_servers); + let servers = room_available_servers(&services, &room_id, &room_alias, &pre_servers).await; debug!(?room_alias, ?room_id, "available servers: {servers:?}"); Ok(get_alias::v3::Response::new(room_id, servers)) } -fn room_available_servers( +async fn room_available_servers( services: &Services, room_id: &RoomId, room_alias: &RoomAliasId, pre_servers: &Option>, ) -> Vec { // find active servers in room state cache to suggest @@ -112,8 +111,9 @@ fn room_available_servers( .rooms .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; // push any servers we want in the list already (e.g. 
responded remote alias // servers, room alias server itself) diff --git a/src/api/client/backup.rs b/src/api/client/backup.rs index 4ead87776..d52da80a2 100644 --- a/src/api/client/backup.rs +++ b/src/api/client/backup.rs @@ -1,18 +1,16 @@ use axum::extract::State; +use conduit::{err, Err}; use ruma::{ - api::client::{ - backup::{ - add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version, - delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version, - get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session, - get_latest_backup_info, update_backup_version, - }, - error::ErrorKind, + api::client::backup::{ + add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version, + delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version, + get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session, + get_latest_backup_info, update_backup_version, }, UInt, }; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `POST /_matrix/client/r0/room_keys/version` /// @@ -40,7 +38,8 @@ pub(crate) async fn update_backup_version_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); services .key_backups - .update_backup(sender_user, &body.version, &body.algorithm)?; + .update_backup(sender_user, &body.version, &body.algorithm) + .await?; Ok(update_backup_version::v3::Response {}) } @@ -55,14 +54,15 @@ pub(crate) async fn get_latest_backup_info_route( let (version, algorithm) = services .key_backups - .get_latest_backup(sender_user)? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; + .get_latest_backup(sender_user) + .await + .map_err(|_| err!(Request(NotFound("Key backup does not exist."))))?; Ok(get_latest_backup_info::v3::Response { algorithm, - count: (UInt::try_from(services.key_backups.count_keys(sender_user, &version)?) + count: (UInt::try_from(services.key_backups.count_keys(sender_user, &version).await) .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &version)?, + etag: services.key_backups.get_etag(sender_user, &version).await, version, }) } @@ -76,18 +76,21 @@ pub(crate) async fn get_backup_info_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let algorithm = services .key_backups - .get_backup(sender_user, &body.version)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; + .get_backup(sender_user, &body.version) + .await + .map_err(|_| err!(Request(NotFound("Key backup does not exist at version {:?}", body.version))))?; Ok(get_backup_info::v3::Response { algorithm, - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, version: body.version.clone(), }) } @@ -105,7 +108,8 @@ pub(crate) async fn delete_backup_version_route( services .key_backups - .delete_backup(sender_user, &body.version)?; + .delete_backup(sender_user, &body.version) + .await; Ok(delete_backup_version::v3::Response {}) } @@ -123,34 +127,36 @@ pub(crate) async fn add_backup_keys_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if 
Some(&body.version) - != services - .key_backups - .get_latest_backup_version(sender_user)? - .as_ref() + if services + .key_backups + .get_latest_backup_version(sender_user) + .await + .is_ok_and(|version| version != body.version) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", - )); + return Err!(Request(InvalidParam( + "You may only manipulate the most recently created version of the backup." + ))); } for (room_id, room) in &body.rooms { for (session_id, key_data) in &room.sessions { services .key_backups - .add_key(sender_user, &body.version, room_id, session_id, key_data)?; + .add_key(sender_user, &body.version, room_id, session_id, key_data) + .await?; } } Ok(add_backup_keys::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } @@ -167,32 +173,34 @@ pub(crate) async fn add_backup_keys_for_room_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if Some(&body.version) - != services - .key_backups - .get_latest_backup_version(sender_user)? - .as_ref() + if services + .key_backups + .get_latest_backup_version(sender_user) + .await + .is_ok_and(|version| version != body.version) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", - )); + return Err!(Request(InvalidParam( + "You may only manipulate the most recently created version of the backup." 
+ ))); } for (session_id, key_data) in &body.sessions { services .key_backups - .add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?; + .add_key(sender_user, &body.version, &body.room_id, session_id, key_data) + .await?; } Ok(add_backup_keys_for_room::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } @@ -209,30 +217,32 @@ pub(crate) async fn add_backup_keys_for_session_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if Some(&body.version) - != services - .key_backups - .get_latest_backup_version(sender_user)? - .as_ref() + if services + .key_backups + .get_latest_backup_version(sender_user) + .await + .is_ok_and(|version| version != body.version) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", - )); + return Err!(Request(InvalidParam( + "You may only manipulate the most recently created version of the backup." 
+ ))); } services .key_backups - .add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?; + .add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data) + .await?; Ok(add_backup_keys_for_session::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } @@ -244,7 +254,10 @@ pub(crate) async fn get_backup_keys_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let rooms = services.key_backups.get_all(sender_user, &body.version)?; + let rooms = services + .key_backups + .get_all(sender_user, &body.version) + .await; Ok(get_backup_keys::v3::Response { rooms, @@ -261,7 +274,8 @@ pub(crate) async fn get_backup_keys_for_room_route( let sessions = services .key_backups - .get_room(sender_user, &body.version, &body.room_id)?; + .get_room(sender_user, &body.version, &body.room_id) + .await; Ok(get_backup_keys_for_room::v3::Response { sessions, @@ -278,8 +292,9 @@ pub(crate) async fn get_backup_keys_for_session_route( let key_data = services .key_backups - .get_session(sender_user, &body.version, &body.room_id, &body.session_id)? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."))?; + .get_session(sender_user, &body.version, &body.room_id, &body.session_id) + .await + .map_err(|_| err!(Request(NotFound(debug_error!("Backup key not found for this user's session.")))))?; Ok(get_backup_keys_for_session::v3::Response { key_data, @@ -296,16 +311,19 @@ pub(crate) async fn delete_backup_keys_route( services .key_backups - .delete_all_keys(sender_user, &body.version)?; + .delete_all_keys(sender_user, &body.version) + .await; Ok(delete_backup_keys::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } @@ -319,16 +337,19 @@ pub(crate) async fn delete_backup_keys_for_room_route( services .key_backups - .delete_room_keys(sender_user, &body.version, &body.room_id)?; + .delete_room_keys(sender_user, &body.version, &body.room_id) + .await; Ok(delete_backup_keys_for_room::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } @@ -342,15 +363,18 @@ pub(crate) async fn delete_backup_keys_for_session_route( services .key_backups - .delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?; + .delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id) + .await; 
Ok(delete_backup_keys_for_session::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } diff --git a/src/api/client/config.rs b/src/api/client/config.rs index 61cc97ff5..33b85136c 100644 --- a/src/api/client/config.rs +++ b/src/api/client/config.rs @@ -1,4 +1,5 @@ use axum::extract::State; +use conduit::err; use ruma::{ api::client::{ config::{get_global_account_data, get_room_account_data, set_global_account_data, set_room_account_data}, @@ -25,7 +26,8 @@ pub(crate) async fn set_global_account_data_route( &body.sender_user, &body.event_type.to_string(), body.data.json(), - )?; + ) + .await?; Ok(set_global_account_data::v3::Response {}) } @@ -42,7 +44,8 @@ pub(crate) async fn set_room_account_data_route( &body.sender_user, &body.event_type.to_string(), body.data.json(), - )?; + ) + .await?; Ok(set_room_account_data::v3::Response {}) } @@ -57,8 +60,9 @@ pub(crate) async fn get_global_account_data_route( let event: Box = services .account_data - .get(None, sender_user, body.event_type.to_string().into())? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; + .get(None, sender_user, body.event_type.to_string().into()) + .await + .map_err(|_| err!(Request(NotFound("Data not found."))))?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? @@ -79,8 +83,9 @@ pub(crate) async fn get_room_account_data_route( let event: Box = services .account_data - .get(Some(&body.room_id), sender_user, body.event_type.clone())? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; + .get(Some(&body.room_id), sender_user, body.event_type.clone()) + .await + .map_err(|_| err!(Request(NotFound("Data not found."))))?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? @@ -91,7 +96,7 @@ pub(crate) async fn get_room_account_data_route( }) } -fn set_account_data( +async fn set_account_data( services: &Services, room_id: Option<&RoomId>, sender_user: &Option, event_type: &str, data: &RawJsonValue, ) -> Result<()> { @@ -100,15 +105,18 @@ fn set_account_data( let data: serde_json::Value = serde_json::from_str(data.get()).map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; - services.account_data.update( - room_id, - sender_user, - event_type.into(), - &json!({ - "type": event_type, - "content": data, - }), - )?; + services + .account_data + .update( + room_id, + sender_user, + event_type.into(), + &json!({ + "type": event_type, + "content": data, + }), + ) + .await?; Ok(()) } diff --git a/src/api/client/context.rs b/src/api/client/context.rs index f223d4889..cc49b763f 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -1,13 +1,14 @@ use std::collections::HashSet; use axum::extract::State; +use conduit::{err, error, Err}; +use futures::StreamExt; use ruma::{ - api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, + api::client::{context::get_context, filter::LazyLoadOptions}, events::StateEventType, }; -use tracing::error; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET /_matrix/client/r0/rooms/{roomId}/context` /// @@ -35,34 +36,33 @@ pub(crate) async fn get_context_route( let base_token = services .rooms .timeline - .get_pdu_count(&body.event_id)? 
- .ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event id not found."))?; + .get_pdu_count(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Base event id not found."))))?; let base_event = services .rooms .timeline - .get_pdu(&body.event_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event not found."))?; + .get_pdu(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Base event not found."))))?; - let room_id = base_event.room_id.clone(); + let room_id = &base_event.room_id; if !services .rooms .state_accessor - .user_can_see_event(sender_user, &room_id, &body.event_id)? + .user_can_see_event(sender_user, room_id, &body.event_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this event.", - )); + return Err!(Request(Forbidden("You don't have permission to view this event."))); } - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &room_id, - &base_event.sender, - )? || lazy_load_send_redundant + if !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &base_event.sender) + .await || lazy_load_send_redundant { lazy_loaded.insert(base_event.sender.as_str().to_owned()); } @@ -75,25 +75,26 @@ pub(crate) async fn get_context_route( let events_before: Vec<_> = services .rooms .timeline - .pdus_until(sender_user, &room_id, base_token)? + .pdus_until(sender_user, room_id, base_token) + .await? 
.take(limit / 2) - .filter_map(Result::ok) // Remove buggy events - .filter(|(_, pdu)| { + .filter_map(|(count, pdu)| async move { services .rooms .state_accessor - .user_can_see_event(sender_user, &room_id, &pdu.event_id) - .unwrap_or(false) + .user_can_see_event(sender_user, room_id, &pdu.event_id) + .await + .then_some((count, pdu)) }) - .collect(); + .collect() + .await; for (_, event) in &events_before { - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &room_id, - &event.sender, - )? || lazy_load_send_redundant + if !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) + .await || lazy_load_send_redundant { lazy_loaded.insert(event.sender.as_str().to_owned()); } @@ -111,25 +112,26 @@ pub(crate) async fn get_context_route( let events_after: Vec<_> = services .rooms .timeline - .pdus_after(sender_user, &room_id, base_token)? + .pdus_after(sender_user, room_id, base_token) + .await? .take(limit / 2) - .filter_map(Result::ok) // Remove buggy events - .filter(|(_, pdu)| { + .filter_map(|(count, pdu)| async move { services .rooms .state_accessor - .user_can_see_event(sender_user, &room_id, &pdu.event_id) - .unwrap_or(false) + .user_can_see_event(sender_user, room_id, &pdu.event_id) + .await + .then_some((count, pdu)) }) - .collect(); + .collect() + .await; for (_, event) in &events_after { - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &room_id, - &event.sender, - )? || lazy_load_send_redundant + if !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) + .await || lazy_load_send_redundant { lazy_loaded.insert(event.sender.as_str().to_owned()); } @@ -142,12 +144,14 @@ pub(crate) async fn get_context_route( events_after .last() .map_or(&*body.event_id, |(_, e)| &*e.event_id), - )? 
+ ) + .await .map_or( services .rooms .state - .get_room_shortstatehash(&room_id)? + .get_room_shortstatehash(room_id) + .await .expect("All rooms have state"), |hash| hash, ); @@ -156,7 +160,8 @@ pub(crate) async fn get_context_route( .rooms .state_accessor .state_full_ids(shortstatehash) - .await?; + .await + .map_err(|e| err!(Database("State not found: {e}")))?; let end_token = events_after .last() @@ -173,18 +178,19 @@ pub(crate) async fn get_context_route( let (event_type, state_key) = services .rooms .short - .get_statekey_from_short(shortstatekey)?; + .get_statekey_from_short(shortstatekey) + .await?; if event_type != StateEventType::RoomMember { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); continue; }; state.push(pdu.to_state_event()); } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? 
else { - error!("Pdu in state not found: {}", id); + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); continue; }; diff --git a/src/api/client/device.rs b/src/api/client/device.rs index bad7f2844..93eaa393d 100644 --- a/src/api/client/device.rs +++ b/src/api/client/device.rs @@ -1,4 +1,6 @@ use axum::extract::State; +use conduit::{err, Err}; +use futures::StreamExt; use ruma::api::client::{ device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, error::ErrorKind, @@ -19,8 +21,8 @@ pub(crate) async fn get_devices_route( let devices: Vec = services .users .all_devices_metadata(sender_user) - .filter_map(Result::ok) // Filter out buggy devices - .collect(); + .collect() + .await; Ok(get_devices::v3::Response { devices, @@ -37,8 +39,9 @@ pub(crate) async fn get_device_route( let device = services .users - .get_device_metadata(sender_user, &body.body.device_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; + .get_device_metadata(sender_user, &body.body.device_id) + .await + .map_err(|_| err!(Request(NotFound("Device not found."))))?; Ok(get_device::v3::Response { device, @@ -55,14 +58,16 @@ pub(crate) async fn update_device_route( let mut device = services .users - .get_device_metadata(sender_user, &body.device_id)? 
- .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; + .get_device_metadata(sender_user, &body.device_id) + .await + .map_err(|_| err!(Request(NotFound("Device not found."))))?; device.display_name.clone_from(&body.display_name); services .users - .update_device_metadata(sender_user, &body.device_id, &device)?; + .update_device_metadata(sender_user, &body.device_id, &device) + .await?; Ok(update_device::v3::Response {}) } @@ -97,22 +102,28 @@ pub(crate) async fn delete_device_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { - return Err(Error::Uiaa(uiaainfo)); + return Err!(Uiaa(uiaainfo)); } // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; - return Err(Error::Uiaa(uiaainfo)); + .create(sender_user, sender_device, &uiaainfo, &json); + + return Err!(Uiaa(uiaainfo)); } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); + return Err!(Request(NotJson("Not json."))); } - services.users.remove_device(sender_user, &body.device_id)?; + services + .users + .remove_device(sender_user, &body.device_id) + .await; Ok(delete_device::v3::Response {}) } @@ -149,7 +160,9 @@ pub(crate) async fn delete_devices_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -158,14 +171,15 @@ pub(crate) async fn delete_devices_route( uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return 
Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } for device_id in &body.devices { - services.users.remove_device(sender_user, device_id)?; + services.users.remove_device(sender_user, device_id).await; } Ok(delete_devices::v3::Response {}) diff --git a/src/api/client/directory.rs b/src/api/client/directory.rs index 602f876a9..ea499545c 100644 --- a/src/api/client/directory.rs +++ b/src/api/client/directory.rs @@ -1,6 +1,7 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{err, info, warn, Err, Error, Result}; +use conduit::{info, warn, Err, Error, Result}; +use futures::{StreamExt, TryFutureExt}; use ruma::{ api::{ client::{ @@ -18,7 +19,7 @@ use ruma::{ }, StateEventType, }, - uint, RoomId, ServerName, UInt, UserId, + uint, OwnedRoomId, RoomId, ServerName, UInt, UserId, }; use service::Services; @@ -119,16 +120,22 @@ pub(crate) async fn set_room_visibility_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services.rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id).await { // Return 404 if the room doesn't exist return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); } - if services.users.is_deactivated(sender_user).unwrap_or(false) && body.appservice_info.is_none() { + if services + .users + .is_deactivated(sender_user) + .await + .unwrap_or(false) + && body.appservice_info.is_none() + { return Err!(Request(Forbidden("Guests cannot publish to room directories"))); } - if !user_can_publish_room(&services, sender_user, &body.room_id)? { + if !user_can_publish_room(&services, sender_user, &body.room_id).await? 
{ return Err(Error::BadRequest( ErrorKind::forbidden(), "User is not allowed to publish this room", @@ -138,7 +145,7 @@ pub(crate) async fn set_room_visibility_route( match &body.visibility { room::Visibility::Public => { if services.globals.config.lockdown_public_room_directory - && !services.users.is_admin(sender_user)? + && !services.users.is_admin(sender_user).await && body.appservice_info.is_none() { info!( @@ -164,7 +171,7 @@ pub(crate) async fn set_room_visibility_route( )); } - services.rooms.directory.set_public(&body.room_id)?; + services.rooms.directory.set_public(&body.room_id); if services.globals.config.admin_room_notices { services @@ -174,7 +181,7 @@ pub(crate) async fn set_room_visibility_route( } info!("{sender_user} made {0} public to the room directory", body.room_id); }, - room::Visibility::Private => services.rooms.directory.set_not_public(&body.room_id)?, + room::Visibility::Private => services.rooms.directory.set_not_public(&body.room_id), _ => { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -192,13 +199,13 @@ pub(crate) async fn set_room_visibility_route( pub(crate) async fn get_room_visibility_route( State(services): State, body: Ruma, ) -> Result { - if !services.rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id).await { // Return 404 if the room doesn't exist return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); } Ok(get_room_visibility::v3::Response { - visibility: if services.rooms.directory.is_public_room(&body.room_id)? 
{ + visibility: if services.rooms.directory.is_public_room(&body.room_id).await { room::Visibility::Public } else { room::Visibility::Private @@ -257,101 +264,41 @@ pub(crate) async fn get_public_rooms_filtered_helper( } } - let mut all_rooms: Vec<_> = services + let mut all_rooms: Vec = services .rooms .directory .public_rooms() - .map(|room_id| { - let room_id = room_id?; - - let chunk = PublicRoomsChunk { - canonical_alias: services - .rooms - .state_accessor - .get_canonical_alias(&room_id)?, - name: services.rooms.state_accessor.get_name(&room_id)?, - num_joined_members: services - .rooms - .state_cache - .room_joined_count(&room_id)? - .unwrap_or_else(|| { - warn!("Room {} has no member count", room_id); - 0 - }) - .try_into() - .expect("user count should not be that big"), - topic: services - .rooms - .state_accessor - .get_room_topic(&room_id) - .unwrap_or(None), - world_readable: services.rooms.state_accessor.is_world_readable(&room_id)?, - guest_can_join: services - .rooms - .state_accessor - .guest_can_join(&room_id)?, - avatar_url: services - .rooms - .state_accessor - .get_avatar(&room_id)? - .into_option() - .unwrap_or_default() - .url, - join_rule: services - .rooms - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| match c.join_rule { - JoinRule::Public => Some(PublicRoomJoinRule::Public), - JoinRule::Knock => Some(PublicRoomJoinRule::Knock), - _ => None, - }) - .map_err(|e| { - err!(Database(error!("Invalid room join rule event in database: {e}"))) - }) - }) - .transpose()? 
- .flatten() - .ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?, - room_type: services - .rooms - .state_accessor - .get_room_type(&room_id)?, - room_id, - }; - Ok(chunk) - }) - .filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms - .filter(|chunk| { + .map(ToOwned::to_owned) + .then(|room_id| public_rooms_chunk(services, room_id)) + .filter_map(|chunk| async move { if let Some(query) = filter.generic_search_term.as_ref().map(|q| q.to_lowercase()) { if let Some(name) = &chunk.name { if name.as_str().to_lowercase().contains(&query) { - return true; + return Some(chunk); } } if let Some(topic) = &chunk.topic { if topic.to_lowercase().contains(&query) { - return true; + return Some(chunk); } } if let Some(canonical_alias) = &chunk.canonical_alias { if canonical_alias.as_str().to_lowercase().contains(&query) { - return true; + return Some(chunk); } } - false - } else { - // No search term - true + return None; } + + // No search term + Some(chunk) }) // We need to collect all, so we can sort by member count - .collect(); + .collect() + .await; all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members)); @@ -394,22 +341,23 @@ pub(crate) async fn get_public_rooms_filtered_helper( /// Check whether the user can publish to the room directory via power levels of /// room history visibility event or room creator -fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result { - if let Some(event) = services +async fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result { + if let Ok(event) = services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? 
+ .room_state_get(room_id, &StateEventType::RoomPowerLevels, "") + .await { serde_json::from_str(event.content.get()) .map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels")) .map(|content: RoomPowerLevelsEventContent| { RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomHistoryVisibility) }) - } else if let Some(event) = - services - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? + } else if let Ok(event) = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await { Ok(event.sender == user_id) } else { @@ -419,3 +367,61 @@ fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId )); } } + +async fn public_rooms_chunk(services: &Services, room_id: OwnedRoomId) -> PublicRoomsChunk { + PublicRoomsChunk { + canonical_alias: services + .rooms + .state_accessor + .get_canonical_alias(&room_id) + .await + .ok(), + name: services.rooms.state_accessor.get_name(&room_id).await.ok(), + num_joined_members: services + .rooms + .state_cache + .room_joined_count(&room_id) + .await + .unwrap_or(0) + .try_into() + .expect("joined count overflows ruma UInt"), + topic: services + .rooms + .state_accessor + .get_room_topic(&room_id) + .await + .ok(), + world_readable: services + .rooms + .state_accessor + .is_world_readable(&room_id) + .await, + guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id).await, + avatar_url: services + .rooms + .state_accessor + .get_avatar(&room_id) + .await + .into_option() + .unwrap_or_default() + .url, + join_rule: services + .rooms + .state_accessor + .room_state_get_content(&room_id, &StateEventType::RoomJoinRules, "") + .map_ok(|c: RoomJoinRulesEventContent| match c.join_rule { + JoinRule::Public => PublicRoomJoinRule::Public, + JoinRule::Knock => PublicRoomJoinRule::Knock, + _ => "invite".into(), + }) + .await + .unwrap_or_default(), + room_type: services + .rooms + 
.state_accessor + .get_room_type(&room_id) + .await + .ok(), + room_id, + } +} diff --git a/src/api/client/filter.rs b/src/api/client/filter.rs index 8b2690c69..2a8ebb9c2 100644 --- a/src/api/client/filter.rs +++ b/src/api/client/filter.rs @@ -1,10 +1,8 @@ use axum::extract::State; -use ruma::api::client::{ - error::ErrorKind, - filter::{create_filter, get_filter}, -}; +use conduit::err; +use ruma::api::client::filter::{create_filter, get_filter}; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}` /// @@ -15,11 +13,13 @@ pub(crate) async fn get_filter_route( State(services): State, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let Some(filter) = services.users.get_filter(sender_user, &body.filter_id)? else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")); - }; - Ok(get_filter::v3::Response::new(filter)) + services + .users + .get_filter(sender_user, &body.filter_id) + .await + .map(get_filter::v3::Response::new) + .map_err(|_| err!(Request(NotFound("Filter not found.")))) } /// # `PUT /_matrix/client/r0/user/{userId}/filter` @@ -29,7 +29,8 @@ pub(crate) async fn create_filter_route( State(services): State, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - Ok(create_filter::v3::Response::new( - services.users.create_filter(sender_user, &body.filter)?, - )) + + let filter_id = services.users.create_filter(sender_user, &body.filter); + + Ok(create_filter::v3::Response::new(filter_id)) } diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index a426364a2..abf2a22f5 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -4,8 +4,8 @@ use std::{ }; use axum::extract::State; -use conduit::{utils, utils::math::continue_exponential_backoff_secs, Err, Error, Result}; -use futures_util::{stream::FuturesUnordered, StreamExt}; +use 
conduit::{err, utils, utils::math::continue_exponential_backoff_secs, Err, Error, Result}; +use futures::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::{ client::{ @@ -21,7 +21,10 @@ use ruma::{ use serde_json::json; use super::SESSION_ID_LENGTH; -use crate::{service::Services, Ruma}; +use crate::{ + service::{users::parse_master_key, Services}, + Ruma, +}; /// # `POST /_matrix/client/r0/keys/upload` /// @@ -39,7 +42,8 @@ pub(crate) async fn upload_keys_route( for (key_key, key_value) in &body.one_time_keys { services .users - .add_one_time_key(sender_user, sender_device, key_key, key_value)?; + .add_one_time_key(sender_user, sender_device, key_key, key_value) + .await?; } if let Some(device_keys) = &body.device_keys { @@ -47,19 +51,22 @@ pub(crate) async fn upload_keys_route( // This check is needed to assure that signatures are kept if services .users - .get_device_keys(sender_user, sender_device)? - .is_none() + .get_device_keys(sender_user, sender_device) + .await + .is_err() { services .users - .add_device_keys(sender_user, sender_device, device_keys)?; + .add_device_keys(sender_user, sender_device, device_keys) + .await; } } Ok(upload_keys::v3::Response { one_time_key_counts: services .users - .count_one_time_keys(sender_user, sender_device)?, + .count_one_time_keys(sender_user, sender_device) + .await, }) } @@ -120,7 +127,9 @@ pub(crate) async fn upload_signing_keys_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -129,20 +138,24 @@ pub(crate) async fn upload_signing_keys_route( uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return Err(Error::Uiaa(uiaainfo)); } else { return 
Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } if let Some(master_key) = &body.master_key { - services.users.add_cross_signing_keys( - sender_user, - master_key, - &body.self_signing_key, - &body.user_signing_key, - true, // notify so that other users see the new keys - )?; + services + .users + .add_cross_signing_keys( + sender_user, + master_key, + &body.self_signing_key, + &body.user_signing_key, + true, // notify so that other users see the new keys + ) + .await?; } Ok(upload_signing_keys::v3::Response {}) @@ -179,9 +192,11 @@ pub(crate) async fn upload_signatures_route( .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))? .to_owned(), ); + services .users - .sign_key(user_id, key_id, signature, sender_user)?; + .sign_key(user_id, key_id, signature, sender_user) + .await?; } } } @@ -204,56 +219,51 @@ pub(crate) async fn get_key_changes_route( let mut device_list_updates = HashSet::new(); + let from = body + .from + .parse() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?; + + let to = body + .to + .parse() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?; + device_list_updates.extend( services .users - .keys_changed( - sender_user.as_str(), - body.from - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, - Some( - body.to - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?, - ), - ) - .filter_map(Result::ok), + .keys_changed(sender_user.as_str(), from, Some(to)) + .map(ToOwned::to_owned) + .collect::>() + .await, ); - for room_id in services - .rooms - .state_cache - .rooms_joined(sender_user) - .filter_map(Result::ok) - { + let mut rooms_joined = services.rooms.state_cache.rooms_joined(sender_user).boxed(); + + while let Some(room_id) = rooms_joined.next().await { device_list_updates.extend( services .users - .keys_changed( - room_id.as_ref(), - body.from - .parse() - .map_err(|_| 
Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, - Some( - body.to - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?, - ), - ) - .filter_map(Result::ok), + .keys_changed(room_id.as_ref(), from, Some(to)) + .map(ToOwned::to_owned) + .collect::>() + .await, ); } + Ok(get_key_changes::v3::Response { changed: device_list_updates.into_iter().collect(), left: Vec::new(), // TODO }) } -pub(crate) async fn get_keys_helper bool + Send>( +pub(crate) async fn get_keys_helper( services: &Services, sender_user: Option<&UserId>, device_keys_input: &BTreeMap>, allowed_signatures: F, include_display_names: bool, -) -> Result { +) -> Result +where + F: Fn(&UserId) -> bool + Send + Sync, +{ let mut master_keys = BTreeMap::new(); let mut self_signing_keys = BTreeMap::new(); let mut user_signing_keys = BTreeMap::new(); @@ -274,56 +284,60 @@ pub(crate) async fn get_keys_helper bool + Send>( if device_ids.is_empty() { let mut container = BTreeMap::new(); - for device_id in services.users.all_device_ids(user_id) { - let device_id = device_id?; - if let Some(mut keys) = services.users.get_device_keys(user_id, &device_id)? { + let mut devices = services.users.all_device_ids(user_id).boxed(); + + while let Some(device_id) = devices.next().await { + if let Ok(mut keys) = services.users.get_device_keys(user_id, device_id).await { let metadata = services .users - .get_device_metadata(user_id, &device_id)? 
- .ok_or_else(|| Error::bad_database("all_device_keys contained nonexistent device."))?; + .get_device_metadata(user_id, device_id) + .await + .map_err(|_| err!(Database("all_device_keys contained nonexistent device.")))?; add_unsigned_device_display_name(&mut keys, metadata, include_display_names) - .map_err(|_| Error::bad_database("invalid device keys in database"))?; + .map_err(|_| err!(Database("invalid device keys in database")))?; - container.insert(device_id, keys); + container.insert(device_id.to_owned(), keys); } } + device_keys.insert(user_id.to_owned(), container); } else { for device_id in device_ids { let mut container = BTreeMap::new(); - if let Some(mut keys) = services.users.get_device_keys(user_id, device_id)? { + if let Ok(mut keys) = services.users.get_device_keys(user_id, device_id).await { let metadata = services .users - .get_device_metadata(user_id, device_id)? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Tried to get keys for nonexistent device.", - ))?; + .get_device_metadata(user_id, device_id) + .await + .map_err(|_| err!(Request(InvalidParam("Tried to get keys for nonexistent device."))))?; add_unsigned_device_display_name(&mut keys, metadata, include_display_names) - .map_err(|_| Error::bad_database("invalid device keys in database"))?; + .map_err(|_| err!(Database("invalid device keys in database")))?; + container.insert(device_id.to_owned(), keys); } + device_keys.insert(user_id.to_owned(), container); } } - if let Some(master_key) = services + if let Ok(master_key) = services .users - .get_master_key(sender_user, user_id, &allowed_signatures)? + .get_master_key(sender_user, user_id, &allowed_signatures) + .await { master_keys.insert(user_id.to_owned(), master_key); } - if let Some(self_signing_key) = - services - .users - .get_self_signing_key(sender_user, user_id, &allowed_signatures)? 
+ if let Ok(self_signing_key) = services + .users + .get_self_signing_key(sender_user, user_id, &allowed_signatures) + .await { self_signing_keys.insert(user_id.to_owned(), self_signing_key); } if Some(user_id) == sender_user { - if let Some(user_signing_key) = services.users.get_user_signing_key(user_id)? { + if let Ok(user_signing_key) = services.users.get_user_signing_key(user_id).await { user_signing_keys.insert(user_id.to_owned(), user_signing_key); } } @@ -386,23 +400,26 @@ pub(crate) async fn get_keys_helper bool + Send>( while let Some((server, response)) = futures.next().await { if let Ok(Ok(response)) = response { for (user, masterkey) in response.master_keys { - let (master_key_id, mut master_key) = services.users.parse_master_key(&user, &masterkey)?; + let (master_key_id, mut master_key) = parse_master_key(&user, &masterkey)?; - if let Some(our_master_key) = - services - .users - .get_key(&master_key_id, sender_user, &user, &allowed_signatures)? + if let Ok(our_master_key) = services + .users + .get_key(&master_key_id, sender_user, &user, &allowed_signatures) + .await { - let (_, our_master_key) = services.users.parse_master_key(&user, &our_master_key)?; + let (_, our_master_key) = parse_master_key(&user, &our_master_key)?; master_key.signatures.extend(our_master_key.signatures); } let json = serde_json::to_value(master_key).expect("to_value always works"); let raw = serde_json::from_value(json).expect("Raw::from_value always works"); - services.users.add_cross_signing_keys( - &user, &raw, &None, &None, - false, /* Dont notify. A notification would trigger another key request resulting in an - * endless loop */ - )?; + services + .users + .add_cross_signing_keys( + &user, &raw, &None, &None, + false, /* Dont notify. 
A notification would trigger another key request resulting in an + * endless loop */ + ) + .await?; master_keys.insert(user.clone(), raw); } @@ -465,9 +482,10 @@ pub(crate) async fn claim_keys_helper( let mut container = BTreeMap::new(); for (device_id, key_algorithm) in map { - if let Some(one_time_keys) = services + if let Ok(one_time_keys) = services .users - .take_one_time_key(user_id, device_id, key_algorithm)? + .take_one_time_key(user_id, device_id, key_algorithm) + .await { let mut c = BTreeMap::new(); c.insert(one_time_keys.0, one_time_keys.1); diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 470db6693..5a5d436f1 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -11,9 +11,10 @@ use conduit::{ debug, debug_error, debug_warn, err, error, info, pdu::{gen_event_id_canonical_json, PduBuilder}, trace, utils, - utils::math::continue_exponential_backoff_secs, + utils::{math::continue_exponential_backoff_secs, IterStream, ReadyExt}, warn, Err, Error, PduEvent, Result, }; +use futures::{FutureExt, StreamExt}; use ruma::{ api::{ client::{ @@ -55,9 +56,9 @@ async fn banned_room_check( services: &Services, user_id: &UserId, room_id: Option<&RoomId>, server_name: Option<&ServerName>, client_ip: IpAddr, ) -> Result<()> { - if !services.users.is_admin(user_id)? { + if !services.users.is_admin(user_id).await { if let Some(room_id) = room_id { - if services.rooms.metadata.is_banned(room_id)? 
+ if services.rooms.metadata.is_banned(room_id).await || services .globals .config @@ -79,23 +80,22 @@ async fn banned_room_check( "Automatically deactivating user {user_id} due to attempted banned room join from IP \ {client_ip}" ))) - .await; + .await + .ok(); } let all_joined_rooms: Vec = services .rooms .state_cache .rooms_joined(user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - full_user_deactivate(services, user_id, all_joined_rooms).await?; + full_user_deactivate(services, user_id, &all_joined_rooms).await?; } - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "This room is banned on this homeserver.", - )); + return Err!(Request(Forbidden("This room is banned on this homeserver."))); } } else if let Some(server_name) = server_name { if services @@ -119,23 +119,22 @@ async fn banned_room_check( "Automatically deactivating user {user_id} due to attempted banned room join from IP \ {client_ip}" ))) - .await; + .await + .ok(); } let all_joined_rooms: Vec = services .rooms .state_cache .rooms_joined(user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - full_user_deactivate(services, user_id, all_joined_rooms).await?; + full_user_deactivate(services, user_id, &all_joined_rooms).await?; } - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "This remote server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("This remote server is banned on this homeserver."))); } } } @@ -172,14 +171,16 @@ pub(crate) async fn join_room_by_id_route( .rooms .state_cache .servers_invite_via(&body.room_id) - .filter_map(Result::ok) - .collect::>(); + .map(ToOwned::to_owned) + .collect::>() + .await; servers.extend( services .rooms .state_cache - .invite_state(sender_user, &body.room_id)? 
+ .invite_state(sender_user, &body.room_id) + .await .unwrap_or_default() .iter() .filter_map(|event| serde_json::from_str(event.json().get()).ok()) @@ -202,6 +203,7 @@ pub(crate) async fn join_room_by_id_route( body.third_party_signed.as_ref(), &body.appservice_info, ) + .boxed() .await } @@ -233,14 +235,17 @@ pub(crate) async fn join_room_by_id_or_alias_route( .rooms .state_cache .servers_invite_via(&room_id) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); servers.extend( services .rooms .state_cache - .invite_state(sender_user, &room_id)? + .invite_state(sender_user, &room_id) + .await .unwrap_or_default() .iter() .filter_map(|event| serde_json::from_str(event.json().get()).ok()) @@ -270,19 +275,23 @@ pub(crate) async fn join_room_by_id_or_alias_route( if let Some(pre_servers) = &mut pre_servers { servers.append(pre_servers); } + servers.extend( services .rooms .state_cache .servers_invite_via(&room_id) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); servers.extend( services .rooms .state_cache - .invite_state(sender_user, &room_id)? + .invite_state(sender_user, &room_id) + .await .unwrap_or_default() .iter() .filter_map(|event| serde_json::from_str(event.json().get()).ok()) @@ -305,6 +314,7 @@ pub(crate) async fn join_room_by_id_or_alias_route( body.third_party_signed.as_ref(), appservice_info, ) + .boxed() .await?; Ok(join_room_by_id_or_alias::v3::Response { @@ -337,7 +347,7 @@ pub(crate) async fn invite_user_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services.users.is_admin(sender_user)? 
&& services.globals.block_non_admin_invites() { + if !services.users.is_admin(sender_user).await && services.globals.block_non_admin_invites() { info!( "User {sender_user} is not an admin and attempted to send an invite to room {}", &body.room_id @@ -375,15 +385,13 @@ pub(crate) async fn kick_user_route( services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "Cannot kick member that's not in the room.", - ))? + .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) + .await + .map_err(|_| err!(Request(BadState("Cannot kick member that's not in the room."))))? .content .get(), ) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; + .map_err(|_| err!(Database("Invalid member event in database.")))?; event.membership = MembershipState::Leave; event.reason.clone_from(&body.reason); @@ -421,10 +429,13 @@ pub(crate) async fn ban_user_route( let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; + let blurhash = services.users.blurhash(&body.user_id).await.ok(); + let event = services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? 
+ .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) + .await .map_or( Ok(RoomMemberEventContent { membership: MembershipState::Ban, @@ -432,7 +443,7 @@ pub(crate) async fn ban_user_route( avatar_url: None, is_direct: None, third_party_invite: None, - blurhash: services.users.blurhash(&body.user_id).unwrap_or_default(), + blurhash: blurhash.clone(), reason: body.reason.clone(), join_authorized_via_users_server: None, }), @@ -442,12 +453,12 @@ pub(crate) async fn ban_user_route( membership: MembershipState::Ban, displayname: None, avatar_url: None, - blurhash: services.users.blurhash(&body.user_id).unwrap_or_default(), + blurhash: blurhash.clone(), reason: body.reason.clone(), join_authorized_via_users_server: None, ..event }) - .map_err(|_| Error::bad_database("Invalid member event in database.")) + .map_err(|e| err!(Database("Invalid member event in database: {e:?}"))) }, )?; @@ -488,12 +499,13 @@ pub(crate) async fn unban_user_route( services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? - .ok_or(Error::BadRequest(ErrorKind::BadState, "Cannot unban a user who is not banned."))? + .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) + .await + .map_err(|_| err!(Request(BadState("Cannot unban a user who is not banned."))))? .content .get(), ) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; + .map_err(|e| err!(Database("Invalid member event in database: {e:?}")))?; event.membership = MembershipState::Leave; event.reason.clone_from(&body.reason); @@ -539,18 +551,16 @@ pub(crate) async fn forget_room_route( if services .rooms .state_cache - .is_joined(sender_user, &body.room_id)? 
+ .is_joined(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "You must leave the room before forgetting it", - )); + return Err!(Request(Unknown("You must leave the room before forgetting it"))); } services .rooms .state_cache - .forget(&body.room_id, sender_user)?; + .forget(&body.room_id, sender_user); Ok(forget_room::v3::Response::new()) } @@ -568,8 +578,9 @@ pub(crate) async fn joined_rooms_route( .rooms .state_cache .rooms_joined(sender_user) - .filter_map(Result::ok) - .collect(), + .map(ToOwned::to_owned) + .collect() + .await, }) } @@ -587,12 +598,10 @@ pub(crate) async fn get_member_events_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); + return Err!(Request(Forbidden("You don't have permission to view this room."))); } Ok(get_member_events::v3::Response { @@ -622,30 +631,27 @@ pub(crate) async fn joined_members_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? 
+ .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); + return Err!(Request(Forbidden("You don't have permission to view this room."))); } let joined: BTreeMap = services .rooms .state_cache .room_members(&body.room_id) - .filter_map(|user| { - let user = user.ok()?; - - Some(( - user.clone(), + .then(|user| async move { + ( + user.to_owned(), RoomMember { - display_name: services.users.displayname(&user).unwrap_or_default(), - avatar_url: services.users.avatar_url(&user).unwrap_or_default(), + display_name: services.users.displayname(user).await.ok(), + avatar_url: services.users.avatar_url(user).await.ok(), }, - )) + ) }) - .collect(); + .collect() + .await; Ok(joined_members::v3::Response { joined, @@ -658,13 +664,23 @@ pub async fn join_room_by_id_helper( ) -> Result { let state_lock = services.rooms.state.mutex.lock(room_id).await; - let user_is_guest = services.users.is_deactivated(sender_user).unwrap_or(false) && appservice_info.is_none(); + let user_is_guest = services + .users + .is_deactivated(sender_user) + .await + .unwrap_or(false) + && appservice_info.is_none(); - if matches!(services.rooms.state_accessor.guest_can_join(room_id), Ok(false)) && user_is_guest { + if user_is_guest && !services.rooms.state_accessor.guest_can_join(room_id).await { return Err!(Request(Forbidden("Guests are not allowed to join this room"))); } - if matches!(services.rooms.state_cache.is_joined(sender_user, room_id), Ok(true)) { + if services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { debug_warn!("{sender_user} is already joined in {room_id}"); return Ok(join_room_by_id::v3::Response { room_id: room_id.into(), @@ -674,15 +690,17 @@ pub async fn join_room_by_id_helper( if services .rooms .state_cache - .server_in_room(services.globals.server_name(), room_id)? 
- || servers.is_empty() + .server_in_room(services.globals.server_name(), room_id) + .await || servers.is_empty() || (servers.len() == 1 && services.globals.server_is_ours(&servers[0])) { join_room_by_id_helper_local(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) + .boxed() .await } else { // Ask a remote server if we are not participating in this room join_room_by_id_helper_remote(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) + .boxed() .await } } @@ -739,11 +757,11 @@ async fn join_room_by_id_helper_remote( "content".to_owned(), to_canonical_value(RoomMemberEventContent { membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), is_direct: None, third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user).await.ok(), reason, join_authorized_via_users_server: join_authorized_via_users_server.clone(), }) @@ -791,10 +809,11 @@ async fn join_room_by_id_helper_remote( federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), event_id: event_id.to_owned(), + omit_members: false, pdu: services .sending - .convert_to_outgoing_federation_event(join_event.clone()), - omit_members: false, + .convert_to_outgoing_federation_event(join_event.clone()) + .await, }, ) .await?; @@ -864,7 +883,11 @@ async fn join_room_by_id_helper_remote( } } - services.rooms.short.get_or_create_shortroomid(room_id)?; + services + .rooms + .short + .get_or_create_shortroomid(room_id) + .await; info!("Parsing join event"); let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) @@ -895,12 +918,13 @@ async fn join_room_by_id_helper_remote( err!(BadServerResponse("Invalid PDU in send_join response: {e:?}")) 
})?; - services.rooms.outlier.add_pdu_outlier(&event_id, &value)?; + services.rooms.outlier.add_pdu_outlier(&event_id, &value); if let Some(state_key) = &pdu.state_key { let shortstatekey = services .rooms .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key) + .await; state.insert(shortstatekey, pdu.event_id.clone()); } } @@ -916,50 +940,53 @@ async fn join_room_by_id_helper_remote( continue; }; - services.rooms.outlier.add_pdu_outlier(&event_id, &value)?; + services.rooms.outlier.add_pdu_outlier(&event_id, &value); } debug!("Running send_join auth check"); + let fetch_state = &state; + let state_fetch = |k: &'static StateEventType, s: String| async move { + let shortstatekey = services.rooms.short.get_shortstatekey(k, &s).await.ok()?; + + let event_id = fetch_state.get(&shortstatekey)?; + services.rooms.timeline.get_pdu(event_id).await.ok() + }; let auth_check = state_res::event_auth::auth_check( &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), &parsed_join_pdu, - None::, // TODO: third party invite - |k, s| { - services - .rooms - .timeline - .get_pdu( - state.get( - &services - .rooms - .short - .get_or_create_shortstatekey(&k.to_string().into(), s) - .ok()?, - )?, - ) - .ok()? 
- }, + None, // TODO: third party invite + |k, s| state_fetch(k, s.to_owned()), ) - .map_err(|e| { - warn!("Auth check failed: {e}"); - Error::BadRequest(ErrorKind::forbidden(), "Auth check failed") - })?; + .await + .map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?; if !auth_check { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Auth check failed")); + return Err!(Request(Forbidden("Auth check failed"))); } info!("Saving state from send_join"); - let (statehash_before_join, new, removed) = services.rooms.state_compressor.save_state( - room_id, - Arc::new( - state - .into_iter() - .map(|(k, id)| services.rooms.state_compressor.compress_state_event(k, &id)) - .collect::>()?, - ), - )?; + let (statehash_before_join, new, removed) = services + .rooms + .state_compressor + .save_state( + room_id, + Arc::new( + state + .into_iter() + .stream() + .then(|(k, id)| async move { + services + .rooms + .state_compressor + .compress_state_event(k, &id) + .await + }) + .collect() + .await, + ), + ) + .await?; services .rooms @@ -968,12 +995,20 @@ async fn join_room_by_id_helper_remote( .await?; info!("Updating joined counts for new room"); - services.rooms.state_cache.update_joined_count(room_id)?; + services + .rooms + .state_cache + .update_joined_count(room_id) + .await; // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't // fail. 
- let statehash_after_join = services.rooms.state.append_to_state(&parsed_join_pdu)?; + let statehash_after_join = services + .rooms + .state + .append_to_state(&parsed_join_pdu) + .await?; info!("Appending new room join event"); services @@ -993,7 +1028,7 @@ async fn join_room_by_id_helper_remote( services .rooms .state - .set_room_state(room_id, statehash_after_join, &state_lock)?; + .set_room_state(room_id, statehash_after_join, &state_lock); Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) } @@ -1005,23 +1040,15 @@ async fn join_room_by_id_helper_local( ) -> Result { debug!("We can join locally"); - let join_rules_event = services + let join_rules_event_content = services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; - - let join_rules_event_content: Option = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") - }) - }) - .transpose()?; + .room_state_get_content(room_id, &StateEventType::RoomJoinRules, "") + .await + .map(|content: RoomJoinRulesEventContent| content); let restriction_rooms = match join_rules_event_content { - Some(RoomJoinRulesEventContent { + Ok(RoomJoinRulesEventContent { join_rule: JoinRule::Restricted(restricted) | JoinRule::KnockRestricted(restricted), }) => restricted .allow @@ -1034,29 +1061,34 @@ async fn join_room_by_id_helper_local( _ => Vec::new(), }; - let local_members = services + let local_members: Vec<_> = services .rooms .state_cache .room_members(room_id) - .filter_map(Result::ok) - .filter(|user| services.globals.user_is_local(user)) - .collect::>(); + .ready_filter(|user| services.globals.user_is_local(user)) + .map(ToOwned::to_owned) + .collect() + .await; let mut join_authorized_via_users_server: Option = None; - if restriction_rooms.iter().any(|restriction_room_id| { - services - .rooms - 
.state_cache - .is_joined(sender_user, restriction_room_id) - .unwrap_or(false) - }) { + if restriction_rooms + .iter() + .stream() + .any(|restriction_room_id| { + services + .rooms + .state_cache + .is_joined(sender_user, restriction_room_id) + }) + .await + { for user in local_members { if services .rooms .state_accessor .user_can_invite(room_id, &user, sender_user, &state_lock) - .unwrap_or(false) + .await { join_authorized_via_users_server = Some(user); break; @@ -1066,11 +1098,11 @@ async fn join_room_by_id_helper_local( let event = RoomMemberEventContent { membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), is_direct: None, third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user).await.ok(), reason: reason.clone(), join_authorized_via_users_server, }; @@ -1144,11 +1176,11 @@ async fn join_room_by_id_helper_local( "content".to_owned(), to_canonical_value(RoomMemberEventContent { membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), is_direct: None, third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user).await.ok(), reason, join_authorized_via_users_server, }) @@ -1195,10 +1227,11 @@ async fn join_room_by_id_helper_local( federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), event_id: event_id.to_owned(), + omit_members: false, pdu: services .sending - .convert_to_outgoing_federation_event(join_event.clone()), - omit_members: false, + 
.convert_to_outgoing_federation_event(join_event.clone()) + .await, }, ) .await?; @@ -1369,7 +1402,7 @@ pub(crate) async fn invite_helper( services: &Services, sender_user: &UserId, user_id: &UserId, room_id: &RoomId, reason: Option, is_direct: bool, ) -> Result<()> { - if !services.users.is_admin(user_id)? && services.globals.block_non_admin_invites() { + if !services.users.is_admin(user_id).await && services.globals.block_non_admin_invites() { info!("User {sender_user} is not an admin and attempted to send an invite to room {room_id}"); return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -1381,7 +1414,7 @@ pub(crate) async fn invite_helper( let (pdu, pdu_json, invite_room_state) = { let state_lock = services.rooms.state.mutex.lock(room_id).await; let content = to_raw_value(&RoomMemberEventContent { - avatar_url: services.users.avatar_url(user_id)?, + avatar_url: services.users.avatar_url(user_id).await.ok(), displayname: None, is_direct: Some(is_direct), membership: MembershipState::Invite, @@ -1392,28 +1425,32 @@ pub(crate) async fn invite_helper( }) .expect("member event is valid value"); - let (pdu, pdu_json) = services.rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - sender_user, - room_id, - &state_lock, - )?; + let (pdu, pdu_json) = services + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + timestamp: None, + }, + sender_user, + room_id, + &state_lock, + ) + .await?; - let invite_room_state = services.rooms.state.calculate_invite_state(&pdu)?; + let invite_room_state = services.rooms.state.calculate_invite_state(&pdu).await?; drop(state_lock); (pdu, pdu_json, invite_room_state) }; - let room_version_id = 
services.rooms.state.get_room_version(room_id)?; + let room_version_id = services.rooms.state.get_room_version(room_id).await?; let response = services .sending @@ -1425,9 +1462,15 @@ pub(crate) async fn invite_helper( room_version: room_version_id.clone(), event: services .sending - .convert_to_outgoing_federation_event(pdu_json.clone()), + .convert_to_outgoing_federation_event(pdu_json.clone()) + .await, invite_room_state, - via: services.rooms.state_cache.servers_route_via(room_id).ok(), + via: services + .rooms + .state_cache + .servers_route_via(room_id) + .await + .ok(), }, ) .await?; @@ -1478,11 +1521,16 @@ pub(crate) async fn invite_helper( "Could not accept incoming PDU as timeline event.", ))?; - services.sending.send_pdu_room(room_id, &pdu_id)?; + services.sending.send_pdu_room(room_id, &pdu_id).await?; return Ok(()); } - if !services.rooms.state_cache.is_joined(sender_user, room_id)? { + if !services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { return Err(Error::BadRequest( ErrorKind::forbidden(), "You don't have permission to view this room.", @@ -1499,11 +1547,11 @@ pub(crate) async fn invite_helper( event_type: TimelineEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Invite, - displayname: services.users.displayname(user_id)?, - avatar_url: services.users.avatar_url(user_id)?, + displayname: services.users.displayname(user_id).await.ok(), + avatar_url: services.users.avatar_url(user_id).await.ok(), is_direct: Some(is_direct), third_party_invite: None, - blurhash: services.users.blurhash(user_id)?, + blurhash: services.users.blurhash(user_id).await.ok(), reason, join_authorized_via_users_server: None, }) @@ -1531,36 +1579,37 @@ pub async fn leave_all_rooms(services: &Services, user_id: &UserId) { .rooms .state_cache .rooms_joined(user_id) + .map(ToOwned::to_owned) .chain( services .rooms .state_cache .rooms_invited(user_id) - .map(|t| t.map(|(r, _)| r)), + .map(|(r, _)| r), 
) - .collect::>(); + .collect::>() + .await; for room_id in all_rooms { - let Ok(room_id) = room_id else { - continue; - }; - // ignore errors if let Err(e) = leave_room(services, user_id, &room_id, None).await { warn!(%room_id, %user_id, %e, "Failed to leave room"); } - if let Err(e) = services.rooms.state_cache.forget(&room_id, user_id) { - warn!(%room_id, %user_id, %e, "Failed to forget room"); - } + + services.rooms.state_cache.forget(&room_id, user_id); } } pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, reason: Option) -> Result<()> { + //use conduit::utils::stream::OptionStream; + use futures::TryFutureExt; + // Ask a remote server if we don't have this room if !services .rooms .state_cache - .server_in_room(services.globals.server_name(), room_id)? + .server_in_room(services.globals.server_name(), room_id) + .await { if let Err(e) = remote_leave_room(services, user_id, room_id).await { warn!("Failed to leave room {} remotely: {}", user_id, e); @@ -1570,34 +1619,42 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, let last_state = services .rooms .state_cache - .invite_state(user_id, room_id)? 
- .map_or_else(|| services.rooms.state_cache.left_state(user_id, room_id), |s| Ok(Some(s)))?; + .invite_state(user_id, room_id) + .map_err(|_| services.rooms.state_cache.left_state(user_id, room_id)) + .await + .ok(); // We always drop the invite, we can't rely on other servers - services.rooms.state_cache.update_membership( - room_id, - user_id, - RoomMemberEventContent::new(MembershipState::Leave), - user_id, - last_state, - None, - true, - )?; + services + .rooms + .state_cache + .update_membership( + room_id, + user_id, + RoomMemberEventContent::new(MembershipState::Leave), + user_id, + last_state, + None, + true, + ) + .await?; } else { let state_lock = services.rooms.state.mutex.lock(room_id).await; - let member_event = - services - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?; + let member_event = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await; // Fix for broken rooms - let member_event = match member_event { - None => { - error!("Trying to leave a room you are not a member of."); + let Ok(member_event) = member_event else { + error!("Trying to leave a room you are not a member of."); - services.rooms.state_cache.update_membership( + services + .rooms + .state_cache + .update_membership( room_id, user_id, RoomMemberEventContent::new(MembershipState::Leave), @@ -1605,16 +1662,14 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, None, None, true, - )?; - return Ok(()); - }, - Some(e) => e, + ) + .await?; + + return Ok(()); }; - let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get()).map_err(|e| { - error!("Invalid room member event in database: {}", e); - Error::bad_database("Invalid member event in database.") - })?; + let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get()) + .map_err(|e| err!(Database(error!("Invalid room member event 
in database: {e}"))))?; event.membership = MembershipState::Leave; event.reason = reason; @@ -1647,15 +1702,17 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room let invite_state = services .rooms .state_cache - .invite_state(user_id, room_id)? - .ok_or(Error::BadRequest(ErrorKind::BadState, "User is not invited."))?; + .invite_state(user_id, room_id) + .await + .map_err(|_| err!(Request(BadState("User is not invited."))))?; let mut servers: HashSet = services .rooms .state_cache .servers_invite_via(room_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; servers.extend( invite_state @@ -1760,7 +1817,8 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room event_id, pdu: services .sending - .convert_to_outgoing_federation_event(leave_event.clone()), + .convert_to_outgoing_federation_event(leave_event.clone()) + .await, }, ) .await?; diff --git a/src/api/client/message.rs b/src/api/client/message.rs index 51aee8c12..bab5fa54f 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -1,7 +1,8 @@ use std::collections::{BTreeMap, HashSet}; use axum::extract::State; -use conduit::PduCount; +use conduit::{err, utils::ReadyExt, Err, PduCount}; +use futures::{FutureExt, StreamExt}; use ruma::{ api::client::{ error::ErrorKind, @@ -9,13 +10,14 @@ use ruma::{ message::{get_message_events, send_message_event}, }, events::{MessageLikeEventType, StateEventType}, - RoomId, UserId, + UserId, }; use serde_json::{from_str, Value}; +use service::rooms::timeline::PdusIterItem; use crate::{ service::{pdu::PduBuilder, Services}, - utils, Error, PduEvent, Result, Ruma, + utils, Error, Result, Ruma, }; /// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}` @@ -30,79 +32,78 @@ use crate::{ pub(crate) async fn send_message_event_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is 
authenticated"); + let sender_user = body.sender_user.as_deref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); - - let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; + let appservice_info = body.appservice_info.as_ref(); // Forbid m.room.encrypted if encryption is disabled if MessageLikeEventType::RoomEncrypted == body.event_type && !services.globals.allow_encryption() { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Encryption has been disabled")); + return Err!(Request(Forbidden("Encryption has been disabled"))); } - if body.event_type == MessageLikeEventType::CallInvite && services.rooms.directory.is_public_room(&body.room_id)? { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Room call invites are not allowed in public rooms", - )); + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; + + if body.event_type == MessageLikeEventType::CallInvite + && services.rooms.directory.is_public_room(&body.room_id).await + { + return Err!(Request(Forbidden("Room call invites are not allowed in public rooms"))); } // Check if this is a new transaction id - if let Some(response) = services + if let Ok(response) = services .transaction_ids - .existing_txnid(sender_user, sender_device, &body.txn_id)? + .existing_txnid(sender_user, sender_device, &body.txn_id) + .await { // The client might have sent a txnid of the /sendToDevice endpoint // This txnid has no response associated with it if response.is_empty() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Tried to use txn id already used for an incompatible endpoint.", - )); + return Err!(Request(InvalidParam( + "Tried to use txn id already used for an incompatible endpoint." + ))); } - let event_id = utils::string_from_bytes(&response) - .map_err(|_| Error::bad_database("Invalid txnid bytes in database."))? 
- .try_into() - .map_err(|_| Error::bad_database("Invalid event id in txnid data."))?; return Ok(send_message_event::v3::Response { - event_id, + event_id: utils::string_from_bytes(&response) + .map(TryInto::try_into) + .map_err(|e| err!(Database("Invalid event_id in txnid data: {e:?}")))??, }); } let mut unsigned = BTreeMap::new(); unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); + let content = from_str(body.body.body.json().get()) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?; + let event_id = services .rooms .timeline .build_and_append_pdu( PduBuilder { event_type: body.event_type.to_string().into(), - content: from_str(body.body.body.json().get()) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?, + content, unsigned: Some(unsigned), state_key: None, redacts: None, - timestamp: if body.appservice_info.is_some() { - body.timestamp - } else { - None - }, + timestamp: appservice_info.and(body.timestamp), }, sender_user, &body.room_id, &state_lock, ) - .await?; + .await + .map(|event_id| (*event_id).to_owned())?; services .transaction_ids - .add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?; + .add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes()); drop(state_lock); - Ok(send_message_event::v3::Response::new((*event_id).to_owned())) + Ok(send_message_event::v3::Response { + event_id, + }) } /// # `GET /_matrix/client/r0/rooms/{roomId}/messages` @@ -117,8 +118,12 @@ pub(crate) async fn get_message_events_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - let from = match body.from.clone() { - Some(from) => PduCount::try_from_string(&from)?, + let room_id = &body.room_id; + let filter = &body.filter; + + let limit = usize::try_from(body.limit).unwrap_or(10).min(100); + let from = match body.from.as_ref() { + Some(from) => 
PduCount::try_from_string(from)?, None => match body.dir { ruma::api::Direction::Forward => PduCount::min(), ruma::api::Direction::Backward => PduCount::max(), @@ -133,30 +138,25 @@ pub(crate) async fn get_message_events_route( services .rooms .lazy_loading - .lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from) - .await?; - - let limit = usize::try_from(body.limit).unwrap_or(10).min(100); - - let next_token; + .lazy_load_confirm_delivery(sender_user, sender_device, room_id, from); let mut resp = get_message_events::v3::Response::new(); - let mut lazy_loaded = HashSet::new(); - + let next_token; match body.dir { ruma::api::Direction::Forward => { - let events_after: Vec<_> = services + let events_after: Vec = services .rooms .timeline - .pdus_after(sender_user, &body.room_id, from)? - .filter_map(Result::ok) // Filter out buggy events - .filter(|(_, pdu)| { contains_url_filter(pdu, &body.filter) && visibility_filter(&services, pdu, sender_user, &body.room_id) - - }) - .take_while(|&(k, _)| Some(k) != to) // Stop at `to` + .pdus_after(sender_user, room_id, from) + .await? + .ready_filter_map(|item| contains_url_filter(item, filter)) + .filter_map(|item| visibility_filter(&services, item, sender_user)) + .ready_take_while(|(count, _)| Some(*count) != to) // Stop at `to` .take(limit) - .collect(); + .collect() + .boxed() + .await; for (_, event) in &events_after { /* TODO: Remove the not "element_hacks" check when these are resolved: @@ -164,16 +164,18 @@ pub(crate) async fn get_message_events_route( * https://github.com/vector-im/element-web/issues/21034 */ if !cfg!(feature = "element_hacks") - && !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &body.room_id, - &event.sender, - )? 
{ + && !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) + .await + { lazy_loaded.insert(event.sender.clone()); } - lazy_loaded.insert(event.sender.clone()); + if cfg!(features = "element_hacks") { + lazy_loaded.insert(event.sender.clone()); + } } next_token = events_after.last().map(|(count, _)| count).copied(); @@ -191,17 +193,22 @@ pub(crate) async fn get_message_events_route( services .rooms .timeline - .backfill_if_required(&body.room_id, from) + .backfill_if_required(room_id, from) + .boxed() .await?; - let events_before: Vec<_> = services + + let events_before: Vec = services .rooms .timeline - .pdus_until(sender_user, &body.room_id, from)? - .filter_map(Result::ok) // Filter out buggy events - .filter(|(_, pdu)| {contains_url_filter(pdu, &body.filter) && visibility_filter(&services, pdu, sender_user, &body.room_id)}) - .take_while(|&(k, _)| Some(k) != to) // Stop at `to` + .pdus_until(sender_user, room_id, from) + .await? + .ready_filter_map(|item| contains_url_filter(item, filter)) + .filter_map(|item| visibility_filter(&services, item, sender_user)) + .ready_take_while(|(count, _)| Some(*count) != to) // Stop at `to` .take(limit) - .collect(); + .collect() + .boxed() + .await; for (_, event) in &events_before { /* TODO: Remove the not "element_hacks" check when these are resolved: @@ -209,16 +216,18 @@ pub(crate) async fn get_message_events_route( * https://github.com/vector-im/element-web/issues/21034 */ if !cfg!(feature = "element_hacks") - && !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &body.room_id, - &event.sender, - )? 
{ + && !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) + .await + { lazy_loaded.insert(event.sender.clone()); } - lazy_loaded.insert(event.sender.clone()); + if cfg!(features = "element_hacks") { + lazy_loaded.insert(event.sender.clone()); + } } next_token = events_before.last().map(|(count, _)| count).copied(); @@ -236,11 +245,11 @@ pub(crate) async fn get_message_events_route( resp.state = Vec::new(); for ll_id in &lazy_loaded { - if let Some(member_event) = - services - .rooms - .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, ll_id.as_str())? + if let Ok(member_event) = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, ll_id.as_str()) + .await { resp.state.push(member_event.to_state_event()); } @@ -249,34 +258,43 @@ pub(crate) async fn get_message_events_route( // remove the feature check when we are sure clients like element can handle it if !cfg!(feature = "element_hacks") { if let Some(next_token) = next_token { - services - .rooms - .lazy_loading - .lazy_load_mark_sent(sender_user, sender_device, &body.room_id, lazy_loaded, next_token) - .await; + services.rooms.lazy_loading.lazy_load_mark_sent( + sender_user, + sender_device, + room_id, + lazy_loaded, + next_token, + ); } } Ok(resp) } -fn visibility_filter(services: &Services, pdu: &PduEvent, user_id: &UserId, room_id: &RoomId) -> bool { +async fn visibility_filter(services: &Services, item: PdusIterItem, user_id: &UserId) -> Option { + let (_, pdu) = &item; + services .rooms .state_accessor - .user_can_see_event(user_id, room_id, &pdu.event_id) - .unwrap_or(false) + .user_can_see_event(user_id, &pdu.room_id, &pdu.event_id) + .await + .then_some(item) } -fn contains_url_filter(pdu: &PduEvent, filter: &RoomEventFilter) -> bool { +fn contains_url_filter(item: PdusIterItem, filter: &RoomEventFilter) -> Option { + let (_, pdu) = &item; + if filter.url_filter.is_none() { - 
return true; + return Some(item); } let content: Value = from_str(pdu.content.get()).unwrap(); - match filter.url_filter { + let res = match filter.url_filter { Some(UrlFilter::EventsWithoutUrl) => !content["url"].is_string(), Some(UrlFilter::EventsWithUrl) => content["url"].is_string(), None => true, - } + }; + + res.then_some(item) } diff --git a/src/api/client/presence.rs b/src/api/client/presence.rs index 8384d5aca..ba48808bd 100644 --- a/src/api/client/presence.rs +++ b/src/api/client/presence.rs @@ -28,7 +28,8 @@ pub(crate) async fn set_presence_route( services .presence - .set_presence(sender_user, &body.presence, None, None, body.status_msg.clone())?; + .set_presence(sender_user, &body.presence, None, None, body.status_msg.clone()) + .await?; Ok(set_presence::v3::Response {}) } @@ -49,14 +50,15 @@ pub(crate) async fn get_presence_route( let mut presence_event = None; - for _room_id in services + let has_shared_rooms = services .rooms .user - .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? - { - if let Some(presence) = services.presence.get_presence(&body.user_id)? 
{ + .has_shared_rooms(sender_user, &body.user_id) + .await; + + if has_shared_rooms { + if let Ok(presence) = services.presence.get_presence(&body.user_id).await { presence_event = Some(presence); - break; } } diff --git a/src/api/client/profile.rs b/src/api/client/profile.rs index bf47a3f85..495bc8ec3 100644 --- a/src/api/client/profile.rs +++ b/src/api/client/profile.rs @@ -1,5 +1,10 @@ use axum::extract::State; -use conduit::{pdu::PduBuilder, warn, Err, Error, Result}; +use conduit::{ + pdu::PduBuilder, + utils::{stream::TryIgnore, IterStream}, + warn, Err, Error, Result, +}; +use futures::{StreamExt, TryStreamExt}; use ruma::{ api::{ client::{ @@ -35,16 +40,18 @@ pub(crate) async fn set_displayname_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; - update_displayname(&services, &body.user_id, body.displayname.clone(), all_joined_rooms).await?; + update_displayname(&services, &body.user_id, body.displayname.clone(), &all_joined_rooms).await?; if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(set_display_name::v3::Response {}) @@ -72,22 +79,19 @@ pub(crate) async fn get_displayname_route( ) .await { - if !services.users.exists(&body.user_id)? 
{ + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); return Ok(get_display_name::v3::Response { displayname: response.displayname, @@ -95,14 +99,14 @@ pub(crate) async fn get_displayname_route( } } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_display_name::v3::Response { - displayname: services.users.displayname(&body.user_id)?, + displayname: services.users.displayname(&body.user_id).await.ok(), }) } @@ -124,15 +128,16 @@ pub(crate) async fn set_avatar_url_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; update_avatar_url( &services, &body.user_id, body.avatar_url.clone(), body.blurhash.clone(), - all_joined_rooms, + &all_joined_rooms, ) .await?; @@ -140,7 +145,9 @@ pub(crate) async fn set_avatar_url_route( // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await + .ok(); } Ok(set_avatar_url::v3::Response {}) @@ -168,22 +175,21 @@ pub(crate) async fn get_avatar_url_route( ) .await { - if !services.users.exists(&body.user_id)? 
{ + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); + services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); return Ok(get_avatar_url::v3::Response { avatar_url: response.avatar_url, @@ -192,15 +198,15 @@ pub(crate) async fn get_avatar_url_route( } } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_avatar_url::v3::Response { - avatar_url: services.users.avatar_url(&body.user_id)?, - blurhash: services.users.blurhash(&body.user_id)?, + avatar_url: services.users.avatar_url(&body.user_id).await.ok(), + blurhash: services.users.blurhash(&body.user_id).await.ok(), }) } @@ -226,31 +232,30 @@ pub(crate) async fn get_profile_route( ) .await { - if !services.users.exists(&body.user_id)? 
{ + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); + services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); + services .users - .set_timezone(&body.user_id, response.tz.clone()) - .await?; + .set_timezone(&body.user_id, response.tz.clone()); for (profile_key, profile_key_value) in &response.custom_profile_fields { services .users - .set_profile_key(&body.user_id, profile_key, Some(profile_key_value.clone()))?; + .set_profile_key(&body.user_id, profile_key, Some(profile_key_value.clone())); } return Ok(get_profile::v3::Response { @@ -263,104 +268,93 @@ pub(crate) async fn get_profile_route( } } - if !services.users.exists(&body.user_id)? 
{ + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_profile::v3::Response { - avatar_url: services.users.avatar_url(&body.user_id)?, - blurhash: services.users.blurhash(&body.user_id)?, - displayname: services.users.displayname(&body.user_id)?, - tz: services.users.timezone(&body.user_id)?, + avatar_url: services.users.avatar_url(&body.user_id).await.ok(), + blurhash: services.users.blurhash(&body.user_id).await.ok(), + displayname: services.users.displayname(&body.user_id).await.ok(), + tz: services.users.timezone(&body.user_id).await.ok(), custom_profile_fields: services .users .all_profile_keys(&body.user_id) - .filter_map(Result::ok) - .collect(), + .collect() + .await, }) } pub async fn update_displayname( - services: &Services, user_id: &UserId, displayname: Option, all_joined_rooms: Vec, + services: &Services, user_id: &UserId, displayname: Option, all_joined_rooms: &[OwnedRoomId], ) -> Result<()> { - let current_display_name = services.users.displayname(user_id).unwrap_or_default(); + let current_display_name = services.users.displayname(user_id).await.ok(); if displayname == current_display_name { return Ok(()); } - services - .users - .set_displayname(user_id, displayname.clone()) - .await?; + services.users.set_displayname(user_id, displayname.clone()); // Send a new join membership event into all joined rooms - let all_joined_rooms: Vec<_> = all_joined_rooms - .iter() - .map(|room_id| { - Ok::<_, Error>(( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - displayname: displayname.clone(), - join_authorized_via_users_server: None, - ..serde_json::from_str( - services - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? 
- .ok_or_else(|| { - Error::bad_database("Tried to send display name update for user not in the room.") - })? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Database contains invalid PDU."))? - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - room_id, - )) - }) - .filter_map(Result::ok) - .collect(); + let mut joined_rooms = Vec::new(); + for room_id in all_joined_rooms { + let Ok(event) = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await + else { + continue; + }; + + let pdu = PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + displayname: displayname.clone(), + join_authorized_via_users_server: None, + ..serde_json::from_str(event.content.get()).expect("Database contains invalid PDU.") + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + timestamp: None, + }; + + joined_rooms.push((pdu, room_id)); + } - update_all_rooms(services, all_joined_rooms, user_id).await; + update_all_rooms(services, joined_rooms, user_id).await; Ok(()) } pub async fn update_avatar_url( services: &Services, user_id: &UserId, avatar_url: Option, blurhash: Option, - all_joined_rooms: Vec, + all_joined_rooms: &[OwnedRoomId], ) -> Result<()> { - let current_avatar_url = services.users.avatar_url(user_id).unwrap_or_default(); - let current_blurhash = services.users.blurhash(user_id).unwrap_or_default(); + let current_avatar_url = services.users.avatar_url(user_id).await.ok(); + let current_blurhash = services.users.blurhash(user_id).await.ok(); if current_avatar_url == avatar_url && current_blurhash == blurhash { return Ok(()); } - services - .users - .set_avatar_url(user_id, avatar_url.clone()) - .await?; - services - .users - .set_blurhash(user_id, blurhash.clone()) - 
.await?; + services.users.set_avatar_url(user_id, avatar_url.clone()); + + services.users.set_blurhash(user_id, blurhash.clone()); // Send a new join membership event into all joined rooms + let avatar_url = &avatar_url; + let blurhash = &blurhash; let all_joined_rooms: Vec<_> = all_joined_rooms .iter() - .map(|room_id| { - Ok::<_, Error>(( + .try_stream() + .and_then(|room_id: &OwnedRoomId| async move { + Ok(( PduBuilder { event_type: TimelineEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { @@ -371,8 +365,9 @@ pub async fn update_avatar_url( services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? - .ok_or_else(|| { + .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await + .map_err(|_| { Error::bad_database("Tried to send avatar URL update for user not in the room.") })? .content @@ -389,8 +384,9 @@ pub async fn update_avatar_url( room_id, )) }) - .filter_map(Result::ok) - .collect(); + .ignore_err() + .collect() + .await; update_all_rooms(services, all_joined_rooms, user_id).await; diff --git a/src/api/client/push.rs b/src/api/client/push.rs index 8723e676b..390951999 100644 --- a/src/api/client/push.rs +++ b/src/api/client/push.rs @@ -29,40 +29,36 @@ pub(crate) async fn get_pushrules_all_route( let global_ruleset: Ruleset; - let Ok(event) = - services - .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) - else { - // push rules event doesn't exist, create it and return default - return recreate_push_rules_and_return(&services, sender_user); + let event = services + .account_data + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await; + + let Ok(event) = event else { + // user somehow has non-existent push rule event. 
recreate it and return server + // default silently + return recreate_push_rules_and_return(&services, sender_user).await; }; - if let Some(event) = event { - let value = serde_json::from_str::(event.get()) - .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; + let value = serde_json::from_str::(event.get()) + .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; - let Some(content_value) = value.get("content") else { - // user somehow has a push rule event with no content key, recreate it and - // return server default silently - return recreate_push_rules_and_return(&services, sender_user); - }; + let Some(content_value) = value.get("content") else { + // user somehow has a push rule event with no content key, recreate it and + // return server default silently + return recreate_push_rules_and_return(&services, sender_user).await; + }; - if content_value.to_string().is_empty() { - // user somehow has a push rule event with empty content, recreate it and return - // server default silently - return recreate_push_rules_and_return(&services, sender_user); - } + if content_value.to_string().is_empty() { + // user somehow has a push rule event with empty content, recreate it and return + // server default silently + return recreate_push_rules_and_return(&services, sender_user).await; + } - let account_data_content = serde_json::from_value::(content_value.clone().into()) - .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; + let account_data_content = serde_json::from_value::(content_value.clone().into()) + .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; - global_ruleset = account_data_content.global; - } else { - // user somehow has non-existent push rule event. 
recreate it and return server - // default silently - return recreate_push_rules_and_return(&services, sender_user); - } + global_ruleset = account_data_content.global; Ok(get_pushrules_all::v3::Response { global: global_ruleset, @@ -79,8 +75,9 @@ pub(crate) async fn get_pushrule_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? @@ -118,8 +115,9 @@ pub(crate) async fn set_pushrule_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let mut account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; @@ -155,12 +153,15 @@ pub(crate) async fn set_pushrule_route( return Err(err); } - services.account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(set_pushrule::v3::Response {}) } @@ -182,8 +183,9 @@ pub(crate) async fn get_pushrule_actions_route( let event = services .account_data - 
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? @@ -217,8 +219,9 @@ pub(crate) async fn set_pushrule_actions_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let mut account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; @@ -232,12 +235,15 @@ pub(crate) async fn set_pushrule_actions_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); } - services.account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(set_pushrule_actions::v3::Response {}) } @@ -259,8 +265,9 @@ pub(crate) async fn get_pushrule_enabled_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? 
- .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; @@ -293,8 +300,9 @@ pub(crate) async fn set_pushrule_enabled_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let mut account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; @@ -308,12 +316,15 @@ pub(crate) async fn set_pushrule_enabled_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); } - services.account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(set_pushrule_enabled::v3::Response {}) } @@ -335,8 +346,9 @@ pub(crate) async fn delete_pushrule_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? 
- .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let mut account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; @@ -357,12 +369,15 @@ pub(crate) async fn delete_pushrule_route( return Err(err); } - services.account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(delete_pushrule::v3::Response {}) } @@ -376,7 +391,7 @@ pub(crate) async fn get_pushers_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(get_pushers::v3::Response { - pushers: services.pusher.get_pushers(sender_user)?, + pushers: services.pusher.get_pushers(sender_user).await, }) } @@ -390,27 +405,30 @@ pub(crate) async fn set_pushers_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services.pusher.set_pusher(sender_user, &body.action)?; + services.pusher.set_pusher(sender_user, &body.action); Ok(set_pusher::v3::Response::default()) } /// user somehow has bad push rules, these must always exist per spec. 
/// so recreate it and return server default silently -fn recreate_push_rules_and_return( +async fn recreate_push_rules_and_return( services: &Services, sender_user: &ruma::UserId, ) -> Result { - services.account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(PushRulesEvent { - content: PushRulesEventContent { - global: Ruleset::server_default(sender_user), - }, - }) - .expect("to json always works"), - )?; + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(PushRulesEvent { + content: PushRulesEventContent { + global: Ruleset::server_default(sender_user), + }, + }) + .expect("to json always works"), + ) + .await?; Ok(get_pushrules_all::v3::Response { global: Ruleset::server_default(sender_user), diff --git a/src/api/client/read_marker.rs b/src/api/client/read_marker.rs index f40f24932..f28b2aec5 100644 --- a/src/api/client/read_marker.rs +++ b/src/api/client/read_marker.rs @@ -31,27 +31,32 @@ pub(crate) async fn set_read_marker_route( event_id: fully_read.clone(), }, }; - services.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event).expect("to json value always works"), - )?; + services + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::FullyRead, + &serde_json::to_value(fully_read_event).expect("to json value always works"), + ) + .await?; } if body.private_read_receipt.is_some() || body.read_receipt.is_some() { services .rooms .user - .reset_notification_counts(sender_user, &body.room_id)?; + .reset_notification_counts(sender_user, &body.room_id); } if let Some(event) = &body.private_read_receipt { let count = services .rooms .timeline - .get_pdu_count(event)? 
- .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; + .get_pdu_count(event) + .await + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; + let count = match count { PduCount::Backfilled(_) => { return Err(Error::BadRequest( @@ -64,7 +69,7 @@ pub(crate) async fn set_read_marker_route( services .rooms .read_receipt - .private_read_set(&body.room_id, sender_user, count)?; + .private_read_set(&body.room_id, sender_user, count); } if let Some(event) = &body.read_receipt { @@ -83,14 +88,18 @@ pub(crate) async fn set_read_marker_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(event.to_owned(), receipts); - services.rooms.read_receipt.readreceipt_update( - sender_user, - &body.room_id, - &ruma::events::receipt::ReceiptEvent { - content: ruma::events::receipt::ReceiptEventContent(receipt_content), - room_id: body.room_id.clone(), - }, - )?; + services + .rooms + .read_receipt + .readreceipt_update( + sender_user, + &body.room_id, + &ruma::events::receipt::ReceiptEvent { + content: ruma::events::receipt::ReceiptEventContent(receipt_content), + room_id: body.room_id.clone(), + }, + ) + .await; } Ok(set_read_marker::v3::Response {}) @@ -111,7 +120,7 @@ pub(crate) async fn create_receipt_route( services .rooms .user - .reset_notification_counts(sender_user, &body.room_id)?; + .reset_notification_counts(sender_user, &body.room_id); } match body.receipt_type { @@ -121,12 +130,15 @@ pub(crate) async fn create_receipt_route( event_id: body.event_id.clone(), }, }; - services.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event).expect("to json value always works"), - )?; + services + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::FullyRead, + &serde_json::to_value(fully_read_event).expect("to json value always works"), + ) + .await?; }, 
create_receipt::v3::ReceiptType::Read => { let mut user_receipts = BTreeMap::new(); @@ -143,21 +155,27 @@ pub(crate) async fn create_receipt_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(body.event_id.clone(), receipts); - services.rooms.read_receipt.readreceipt_update( - sender_user, - &body.room_id, - &ruma::events::receipt::ReceiptEvent { - content: ruma::events::receipt::ReceiptEventContent(receipt_content), - room_id: body.room_id.clone(), - }, - )?; + services + .rooms + .read_receipt + .readreceipt_update( + sender_user, + &body.room_id, + &ruma::events::receipt::ReceiptEvent { + content: ruma::events::receipt::ReceiptEventContent(receipt_content), + room_id: body.room_id.clone(), + }, + ) + .await; }, create_receipt::v3::ReceiptType::ReadPrivate => { let count = services .rooms .timeline - .get_pdu_count(&body.event_id)? - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; + .get_pdu_count(&body.event_id) + .await + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; + let count = match count { PduCount::Backfilled(_) => { return Err(Error::BadRequest( @@ -170,7 +188,7 @@ pub(crate) async fn create_receipt_route( services .rooms .read_receipt - .private_read_set(&body.room_id, sender_user, count)?; + .private_read_set(&body.room_id, sender_user, count); }, _ => return Err(Error::bad_database("Unsupported receipt type")), } diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index ae6459400..d43847300 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -9,20 +9,24 @@ use crate::{Result, Ruma}; pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - let res = 
services.rooms.pdu_metadata.paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - &Some(body.event_type.clone()), - &Some(body.rel_type.clone()), - &body.from, - &body.to, - &body.limit, - body.recurse, - body.dir, - )?; + let res = services + .rooms + .pdu_metadata + .paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + body.event_type.clone().into(), + body.rel_type.clone().into(), + body.from.as_ref(), + body.to.as_ref(), + body.limit, + body.recurse, + body.dir, + ) + .await?; Ok(get_relating_events_with_rel_type_and_event_type::v1::Response { chunk: res.chunk, @@ -36,20 +40,24 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( pub(crate) async fn get_relating_events_with_rel_type_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - let res = services.rooms.pdu_metadata.paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - &None, - &Some(body.rel_type.clone()), - &body.from, - &body.to, - &body.limit, - body.recurse, - body.dir, - )?; + let res = services + .rooms + .pdu_metadata + .paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + None, + body.rel_type.clone().into(), + body.from.as_ref(), + body.to.as_ref(), + body.limit, + body.recurse, + body.dir, + ) + .await?; Ok(get_relating_events_with_rel_type::v1::Response { chunk: res.chunk, @@ -63,18 +71,22 @@ pub(crate) async fn get_relating_events_with_rel_type_route( pub(crate) async fn get_relating_events_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - services.rooms.pdu_metadata.paginate_relations_with_filter( - sender_user, - 
&body.room_id, - &body.event_id, - &None, - &None, - &body.from, - &body.to, - &body.limit, - body.recurse, - body.dir, - ) + services + .rooms + .pdu_metadata + .paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + None, + None, + body.from.as_ref(), + body.to.as_ref(), + body.limit, + body.recurse, + body.dir, + ) + .await } diff --git a/src/api/client/report.rs b/src/api/client/report.rs index 588bd3686..a40c35a28 100644 --- a/src/api/client/report.rs +++ b/src/api/client/report.rs @@ -1,6 +1,7 @@ use std::time::Duration; use axum::extract::State; +use conduit::{utils::ReadyExt, Err}; use rand::Rng; use ruma::{ api::client::{error::ErrorKind, room::report_content}, @@ -34,11 +35,8 @@ pub(crate) async fn report_event_route( delay_response().await; // check if we know about the reported event ID or if it's invalid - let Some(pdu) = services.rooms.timeline.get_pdu(&body.event_id)? else { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Event ID is not known to us or Event ID is invalid", - )); + let Ok(pdu) = services.rooms.timeline.get_pdu(&body.event_id).await else { + return Err!(Request(NotFound("Event ID is not known to us or Event ID is invalid"))); }; is_report_valid( @@ -49,7 +47,8 @@ pub(crate) async fn report_event_route( &body.reason, body.score, &pdu, - )?; + ) + .await?; // send admin room message that we received the report with an @room ping for // urgency @@ -81,7 +80,8 @@ pub(crate) async fn report_event_route( HtmlEscape(body.reason.as_deref().unwrap_or("")) ), )) - .await; + .await + .ok(); Ok(report_content::v3::Response {}) } @@ -92,7 +92,7 @@ pub(crate) async fn report_event_route( /// check if score is in valid range /// check if report reasoning is less than or equal to 750 characters /// check if reporting user is in the reporting room -fn is_report_valid( +async fn is_report_valid( services: &Services, event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: &Option, score: Option, pdu: 
&std::sync::Arc, ) -> Result<()> { @@ -123,8 +123,8 @@ fn is_report_valid( .rooms .state_cache .room_members(room_id) - .filter_map(Result::ok) - .any(|user_id| user_id == *sender_user) + .ready_any(|user_id| user_id == sender_user) + .await { return Err(Error::BadRequest( ErrorKind::NotFound, diff --git a/src/api/client/room.rs b/src/api/client/room.rs index 0112e76dc..1edf85d80 100644 --- a/src/api/client/room.rs +++ b/src/api/client/room.rs @@ -2,6 +2,7 @@ use std::{cmp::max, collections::BTreeMap}; use axum::extract::State; use conduit::{debug_info, debug_warn, err, Err}; +use futures::{FutureExt, StreamExt}; use ruma::{ api::client::{ error::ErrorKind, @@ -74,7 +75,7 @@ pub(crate) async fn create_room_route( if !services.globals.allow_room_creation() && body.appservice_info.is_none() - && !services.users.is_admin(sender_user)? + && !services.users.is_admin(sender_user).await { return Err(Error::BadRequest(ErrorKind::forbidden(), "Room creation has been disabled.")); } @@ -86,7 +87,7 @@ pub(crate) async fn create_room_route( }; // check if room ID doesn't already exist instead of erroring on auth check - if services.rooms.short.get_shortroomid(&room_id)?.is_some() { + if services.rooms.short.get_shortroomid(&room_id).await.is_ok() { return Err(Error::BadRequest( ErrorKind::RoomInUse, "Room with that custom room ID already exists", @@ -95,7 +96,7 @@ pub(crate) async fn create_room_route( if body.visibility == room::Visibility::Public && services.globals.config.lockdown_public_room_directory - && !services.users.is_admin(sender_user)? 
+ && !services.users.is_admin(sender_user).await && body.appservice_info.is_none() { info!( @@ -118,7 +119,11 @@ pub(crate) async fn create_room_route( return Err!(Request(Forbidden("Publishing rooms to the room directory is not allowed"))); } - let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id)?; + let _short_id = services + .rooms + .short + .get_or_create_shortroomid(&room_id) + .await; let state_lock = services.rooms.state.mutex.lock(&room_id).await; let alias: Option = if let Some(alias) = &body.room_alias_name { @@ -218,6 +223,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 2. Let the room creator join @@ -229,11 +235,11 @@ pub(crate) async fn create_room_route( event_type: TimelineEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), is_direct: Some(body.is_direct), third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user).await.ok(), reason: None, join_authorized_via_users_server: None, }) @@ -247,6 +253,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 3. Power levels @@ -284,6 +291,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 4. 
Canonical room alias @@ -308,6 +316,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; } @@ -335,6 +344,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 5.2 History Visibility @@ -355,6 +365,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 5.3 Guest Access @@ -378,6 +389,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 6. Events listed in initial_state @@ -410,6 +422,7 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) + .boxed() .await?; } @@ -432,6 +445,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; } @@ -455,13 +469,17 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; } // 8. Events implied by invite (and TODO: invite_3pid) drop(state_lock); for user_id in &body.invite { - if let Err(e) = invite_helper(&services, sender_user, user_id, &room_id, None, body.is_direct).await { + if let Err(e) = invite_helper(&services, sender_user, user_id, &room_id, None, body.is_direct) + .boxed() + .await + { warn!(%e, "Failed to send invite"); } } @@ -475,7 +493,7 @@ pub(crate) async fn create_room_route( } if body.visibility == room::Visibility::Public { - services.rooms.directory.set_public(&room_id)?; + services.rooms.directory.set_public(&room_id); if services.globals.config.admin_room_notices { services @@ -505,13 +523,15 @@ pub(crate) async fn get_room_event_route( let event = services .rooms .timeline - .get_pdu(&body.event_id)? - .ok_or_else(|| err!(Request(NotFound("Event {} not found.", &body.event_id))))?; + .get_pdu(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Event {} not found.", &body.event_id))))?; if !services .rooms .state_accessor - .user_can_see_event(sender_user, &event.room_id, &body.event_id)? 
+ .user_can_see_event(sender_user, &event.room_id, &body.event_id) + .await { return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -541,7 +561,8 @@ pub(crate) async fn get_room_aliases_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -554,8 +575,9 @@ pub(crate) async fn get_room_aliases_route( .rooms .alias .local_aliases_for_room(&body.room_id) - .filter_map(Result::ok) - .collect(), + .map(ToOwned::to_owned) + .collect() + .await, }) } @@ -591,7 +613,8 @@ pub(crate) async fn upgrade_room_route( let _short_id = services .rooms .short - .get_or_create_shortroomid(&replacement_room)?; + .get_or_create_shortroomid(&replacement_room) + .await; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; @@ -629,12 +652,12 @@ pub(crate) async fn upgrade_room_route( services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")? - .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? + .room_state_get(&body.room_id, &StateEventType::RoomCreate, "") + .await + .map_err(|_| err!(Database("Found room without m.room.create event.")))? 
.content .get(), - ) - .map_err(|_| Error::bad_database("Invalid room event in database."))?; + )?; // Use the m.room.tombstone event as the predecessor let predecessor = Some(ruma::events::room::create::PreviousRoom::new( @@ -714,11 +737,11 @@ pub(crate) async fn upgrade_room_route( event_type: TimelineEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), is_direct: None, third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user).await.ok(), reason: None, join_authorized_via_users_server: None, }) @@ -739,10 +762,11 @@ pub(crate) async fn upgrade_room_route( let event_content = match services .rooms .state_accessor - .room_state_get(&body.room_id, event_type, "")? + .room_state_get(&body.room_id, event_type, "") + .await { - Some(v) => v.content.clone(), - None => continue, // Skipping missing events. + Ok(v) => v.content.clone(), + Err(_) => continue, // Skipping missing events. 
}; services @@ -765,21 +789,23 @@ pub(crate) async fn upgrade_room_route( } // Moves any local aliases to the new room - for alias in services + let mut local_aliases = services .rooms .alias .local_aliases_for_room(&body.room_id) - .filter_map(Result::ok) - { + .boxed(); + + while let Some(alias) = local_aliases.next().await { services .rooms .alias - .remove_alias(&alias, sender_user) + .remove_alias(alias, sender_user) .await?; + services .rooms .alias - .set_alias(&alias, &replacement_room, sender_user)?; + .set_alias(alias, &replacement_room, sender_user)?; } // Get the old room power levels @@ -787,12 +813,12 @@ pub(crate) async fn upgrade_room_route( services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")? - .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? + .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "") + .await + .map_err(|_| err!(Database("Found room without m.room.create event.")))? .content .get(), - ) - .map_err(|_| Error::bad_database("Invalid room event in database."))?; + )?; // Setting events_default and invite to the greater of 50 and users_default + 1 let new_level = max( @@ -800,9 +826,7 @@ pub(crate) async fn upgrade_room_route( power_levels_event_content .users_default .checked_add(int!(1)) - .ok_or_else(|| { - Error::BadRequest(ErrorKind::BadJson, "users_default power levels event content is not valid") - })?, + .ok_or_else(|| err!(Request(BadJson("users_default power levels event content is not valid"))))?, ); power_levels_event_content.events_default = new_level; power_levels_event_content.invite = new_level; @@ -921,8 +945,9 @@ async fn room_alias_check( if services .rooms .alias - .resolve_local_alias(&full_room_alias)? 
- .is_some() + .resolve_local_alias(&full_room_alias) + .await + .is_ok() { return Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists.")); } diff --git a/src/api/client/search.rs b/src/api/client/search.rs index b143bd2c7..7a061d494 100644 --- a/src/api/client/search.rs +++ b/src/api/client/search.rs @@ -1,6 +1,12 @@ use std::collections::BTreeMap; use axum::extract::State; +use conduit::{ + debug, + utils::{IterStream, ReadyExt}, + Err, +}; +use futures::{FutureExt, StreamExt}; use ruma::{ api::client::{ error::ErrorKind, @@ -13,7 +19,6 @@ use ruma::{ serde::Raw, uint, OwnedRoomId, }; -use tracing::debug; use crate::{Error, Result, Ruma}; @@ -32,14 +37,17 @@ pub(crate) async fn search_events_route( let filter = &search_criteria.filter; let include_state = &search_criteria.include_state; - let room_ids = filter.rooms.clone().unwrap_or_else(|| { + let room_ids = if let Some(room_ids) = &filter.rooms { + room_ids.clone() + } else { services .rooms .state_cache .rooms_joined(sender_user) - .filter_map(Result::ok) + .map(ToOwned::to_owned) .collect() - }); + .await + }; // Use limit or else 10, with maximum 100 let limit: usize = filter @@ -53,18 +61,21 @@ pub(crate) async fn search_events_route( if include_state.is_some_and(|include_state| include_state) { for room_id in &room_ids { - if !services.rooms.state_cache.is_joined(sender_user, room_id)? { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); + if !services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { + return Err!(Request(Forbidden("You don't have permission to view this room."))); } // check if sender_user can see state events if services .rooms .state_accessor - .user_can_see_state_events(sender_user, room_id)? 
+ .user_can_see_state_events(sender_user, room_id) + .await { let room_state = services .rooms @@ -87,10 +98,15 @@ pub(crate) async fn search_events_route( } } - let mut searches = Vec::new(); + let mut search_vecs = Vec::new(); for room_id in &room_ids { - if !services.rooms.state_cache.is_joined(sender_user, room_id)? { + if !services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { return Err(Error::BadRequest( ErrorKind::forbidden(), "You don't have permission to view this room.", @@ -100,12 +116,18 @@ pub(crate) async fn search_events_route( if let Some(search) = services .rooms .search - .search_pdus(room_id, &search_criteria.search_term)? + .search_pdus(room_id, &search_criteria.search_term) + .await { - searches.push(search.0.peekable()); + search_vecs.push(search.0); } } + let mut searches: Vec<_> = search_vecs + .iter() + .map(|vec| vec.iter().peekable()) + .collect(); + let skip: usize = match body.next_batch.as_ref().map(|s| s.parse()) { Some(Ok(s)) => s, Some(Err(_)) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid next_batch token.")), @@ -118,8 +140,8 @@ pub(crate) async fn search_events_route( for _ in 0..next_batch { if let Some(s) = searches .iter_mut() - .map(|s| (s.peek().cloned(), s)) - .max_by_key(|(peek, _)| peek.clone()) + .map(|s| (s.peek().copied(), s)) + .max_by_key(|(peek, _)| *peek) .and_then(|(_, i)| i.next()) { results.push(s); @@ -127,42 +149,38 @@ pub(crate) async fn search_events_route( } let results: Vec<_> = results - .iter() + .into_iter() .skip(skip) - .filter_map(|result| { + .stream() + .filter_map(|id| services.rooms.timeline.get_pdu_from_id(id).map(Result::ok)) + .ready_filter(|pdu| !pdu.is_redacted()) + .filter_map(|pdu| async move { services .rooms - .timeline - .get_pdu_from_id(result) - .ok()? 
- .filter(|pdu| { - !pdu.is_redacted() - && services - .rooms - .state_accessor - .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) - .unwrap_or(false) - }) - .map(|pdu| pdu.to_room_event()) - }) - .map(|result| { - Ok::<_, Error>(SearchResult { - context: EventContextResult { - end: None, - events_after: Vec::new(), - events_before: Vec::new(), - profile_info: BTreeMap::new(), - start: None, - }, - rank: None, - result: Some(result), - }) + .state_accessor + .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) + .await + .then_some(pdu) }) - .filter_map(Result::ok) .take(limit) - .collect(); + .map(|pdu| pdu.to_room_event()) + .map(|result| SearchResult { + context: EventContextResult { + end: None, + events_after: Vec::new(), + events_before: Vec::new(), + profile_info: BTreeMap::new(), + start: None, + }, + rank: None, + result: Some(result), + }) + .collect() + .boxed() + .await; let more_unloaded_results = searches.iter_mut().any(|s| s.peek().is_some()); + let next_batch = more_unloaded_results.then(|| next_batch.to_string()); Ok(search_events::v3::Response::new(ResultCategories { diff --git a/src/api/client/session.rs b/src/api/client/session.rs index 4702b0ec1..6347a2c95 100644 --- a/src/api/client/session.rs +++ b/src/api/client/session.rs @@ -1,5 +1,7 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; +use conduit::{debug, err, info, utils::ReadyExt, warn, Err}; +use futures::StreamExt; use ruma::{ api::client::{ error::ErrorKind, @@ -19,7 +21,6 @@ use ruma::{ UserId, }; use serde::Deserialize; -use tracing::{debug, info, warn}; use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; use crate::{utils, utils::hash, Error, Result, Ruma}; @@ -79,21 +80,22 @@ pub(crate) async fn login_route( UserId::parse(user) } else { warn!("Bad login type: {:?}", &body.login_info); - return Err(Error::BadRequest(ErrorKind::forbidden(), "Bad login type.")); + return Err!(Request(Forbidden("Bad login type."))); } .map_err(|_| 
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; let hash = services .users - .password_hash(&user_id)? - .ok_or(Error::BadRequest(ErrorKind::forbidden(), "Wrong username or password."))?; + .password_hash(&user_id) + .await + .map_err(|_| err!(Request(Forbidden("Wrong username or password."))))?; if hash.is_empty() { - return Err(Error::BadRequest(ErrorKind::UserDeactivated, "The user has been deactivated")); + return Err!(Request(UserDeactivated("The user has been deactivated"))); } if hash::verify_password(password, &hash).is_err() { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Wrong username or password.")); + return Err!(Request(Forbidden("Wrong username or password."))); } user_id @@ -112,15 +114,12 @@ pub(crate) async fn login_route( let username = token.claims.sub.to_lowercase(); - UserId::parse_with_server_name(username, services.globals.server_name()).map_err(|e| { - warn!("Failed to parse username from user logging in: {e}"); - Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") - })? + UserId::parse_with_server_name(username, services.globals.server_name()) + .map_err(|e| err!(Request(InvalidUsername(debug_error!(?e, "Failed to parse login username")))))? } else { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Token login is not supported (server has no jwt decoding key).", - )); + return Err!(Request(Unknown( + "Token login is not supported (server has no jwt decoding key)." 
+ ))); } }, #[allow(deprecated)] @@ -169,23 +168,32 @@ pub(crate) async fn login_route( let token = utils::random_string(TOKEN_LENGTH); // Determine if device_id was provided and exists in the db for this user - let device_exists = body.device_id.as_ref().map_or(false, |device_id| { + let device_exists = if body.device_id.is_some() { services .users .all_device_ids(&user_id) - .any(|x| x.as_ref().map_or(false, |v| v == device_id)) - }); + .ready_any(|v| v == device_id) + .await + } else { + false + }; if device_exists { - services.users.set_token(&user_id, &device_id, &token)?; + services + .users + .set_token(&user_id, &device_id, &token) + .await?; } else { - services.users.create_device( - &user_id, - &device_id, - &token, - body.initial_device_display_name.clone(), - Some(client.to_string()), - )?; + services + .users + .create_device( + &user_id, + &device_id, + &token, + body.initial_device_display_name.clone(), + Some(client.to_string()), + ) + .await?; } // send client well-known if specified so the client knows to reconfigure itself @@ -228,10 +236,13 @@ pub(crate) async fn logout_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - services.users.remove_device(sender_user, sender_device)?; + services + .users + .remove_device(sender_user, sender_device) + .await; // send device list update for user after logout - services.users.mark_device_key_update(sender_user)?; + services.users.mark_device_key_update(sender_user).await; Ok(logout::v3::Response::new()) } @@ -256,12 +267,14 @@ pub(crate) async fn logout_all_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - for device_id in services.users.all_device_ids(sender_user).flatten() { - services.users.remove_device(sender_user, &device_id)?; - } + services + .users + .all_device_ids(sender_user) + .for_each(|device_id| 
services.users.remove_device(sender_user, device_id)) + .await; // send device list update for user after logout - services.users.mark_device_key_update(sender_user)?; + services.users.mark_device_key_update(sender_user).await; Ok(logout_all::v3::Response::new()) } diff --git a/src/api/client/state.rs b/src/api/client/state.rs index fd0496639..f9a4a7636 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use axum::extract::State; -use conduit::{debug_info, error, pdu::PduBuilder, Error, Result}; +use conduit::{err, error, pdu::PduBuilder, Err, Error, Result}; use ruma::{ api::client::{ error::ErrorKind, @@ -84,12 +84,10 @@ pub(crate) async fn get_state_events_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view the room state.", - )); + return Err!(Request(Forbidden("You don't have permission to view the room state."))); } Ok(get_state_events::v3::Response { @@ -120,22 +118,25 @@ pub(crate) async fn get_state_events_for_key_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view the room state.", - )); + return Err!(Request(Forbidden("You don't have permission to view the room state."))); } let event = services .rooms .state_accessor - .room_state_get(&body.room_id, &body.event_type, &body.state_key)? 
- .ok_or_else(|| { - debug_info!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id); - Error::BadRequest(ErrorKind::NotFound, "State event not found.") + .room_state_get(&body.room_id, &body.event_type, &body.state_key) + .await + .map_err(|_| { + err!(Request(NotFound(error!( + room_id = ?body.room_id, + event_type = ?body.event_type, + "State event not found in room.", + )))) })?; + if body .format .as_ref() @@ -204,7 +205,7 @@ async fn send_state_event_for_key_helper( async fn allowed_to_send_state_event( services: &Services, room_id: &RoomId, event_type: &StateEventType, json: &Raw, -) -> Result<()> { +) -> Result { match event_type { // Forbid m.room.encryption if encryption is disabled StateEventType::RoomEncryption => { @@ -214,7 +215,7 @@ async fn allowed_to_send_state_event( }, // admin room is a sensitive room, it should not ever be made public StateEventType::RoomJoinRules => { - if let Some(admin_room_id) = services.admin.get_admin_room()? { + if let Ok(admin_room_id) = services.admin.get_admin_room().await { if admin_room_id == room_id { if let Ok(join_rule) = serde_json::from_str::(json.json().get()) { if join_rule.join_rule == JoinRule::Public { @@ -229,7 +230,7 @@ async fn allowed_to_send_state_event( }, // admin room is a sensitive room, it should not ever be made world readable StateEventType::RoomHistoryVisibility => { - if let Some(admin_room_id) = services.admin.get_admin_room()? { + if let Ok(admin_room_id) = services.admin.get_admin_room().await { if admin_room_id == room_id { if let Ok(visibility_content) = serde_json::from_str::(json.json().get()) @@ -254,23 +255,27 @@ async fn allowed_to_send_state_event( } for alias in aliases { - if !services.globals.server_is_ours(alias.server_name()) - || services - .rooms - .alias - .resolve_local_alias(&alias)? 
- .filter(|room| room == room_id) // Make sure it's the right room - .is_none() + if !services.globals.server_is_ours(alias.server_name()) { + return Err!(Request(Forbidden("canonical_alias must be for this server"))); + } + + if !services + .rooms + .alias + .resolve_local_alias(&alias) + .await + .is_ok_and(|room| room == room_id) + // Make sure it's the right room { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You are only allowed to send canonical_alias events when its aliases already exist", - )); + return Err!(Request(Forbidden( + "You are only allowed to send canonical_alias events when its aliases already exist" + ))); } } } }, _ => (), } + Ok(()) } diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index eb534205e..53d4f3c35 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -6,10 +6,14 @@ use std::{ use axum::extract::State; use conduit::{ - debug, error, - utils::math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, - warn, Err, PduCount, + debug, err, error, is_equal_to, + utils::{ + math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, + IterStream, ReadyExt, + }, + warn, PduCount, }; +use futures::{pin_mut, StreamExt}; use ruma::{ api::client::{ error::ErrorKind, @@ -108,7 +112,8 @@ pub(crate) async fn sync_events_route( if services.globals.allow_local_presence() { services .presence - .ping_presence(&sender_user, &body.set_presence)?; + .ping_presence(&sender_user, &body.set_presence) + .await?; } // Setup watchers, so if there's no response, we can wait for them @@ -124,7 +129,8 @@ pub(crate) async fn sync_events_route( Some(Filter::FilterDefinition(filter)) => filter, Some(Filter::FilterId(filter_id)) => services .users - .get_filter(&sender_user, &filter_id)? 
+ .get_filter(&sender_user, &filter_id) + .await .unwrap_or_default(), }; @@ -157,7 +163,9 @@ pub(crate) async fn sync_events_route( services .users .keys_changed(sender_user.as_ref(), since, None) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); if services.globals.allow_local_presence() { @@ -168,13 +176,14 @@ pub(crate) async fn sync_events_route( .rooms .state_cache .rooms_joined(&sender_user) - .collect::>(); + .map(ToOwned::to_owned) + .collect::>() + .await; // Coalesce database writes for the remainder of this scope. let _cork = services.db.cork_and_flush(); for room_id in all_joined_rooms { - let room_id = room_id?; if let Ok(joined_room) = load_joined_room( &services, &sender_user, @@ -203,12 +212,14 @@ pub(crate) async fn sync_events_route( .rooms .state_cache .rooms_left(&sender_user) - .collect(); + .collect() + .await; + for result in all_left_rooms { handle_left_room( &services, since, - &result?.0, + &result.0, &sender_user, &mut left_rooms, &next_batch_string, @@ -224,10 +235,10 @@ pub(crate) async fn sync_events_route( .rooms .state_cache .rooms_invited(&sender_user) - .collect(); - for result in all_invited_rooms { - let (room_id, invite_state_events) = result?; + .collect() + .await; + for (room_id, invite_state_events) in all_invited_rooms { // Get and drop the lock to wait for remaining operations to finish let insert_lock = services.rooms.timeline.mutex_insert.lock(&room_id).await; drop(insert_lock); @@ -235,7 +246,9 @@ pub(crate) async fn sync_events_route( let invite_count = services .rooms .state_cache - .get_invite_count(&room_id, &sender_user)?; + .get_invite_count(&room_id, &sender_user) + .await + .ok(); // Invited before last sync if Some(since) >= invite_count { @@ -253,22 +266,8 @@ pub(crate) async fn sync_events_route( } for user_id in left_encrypted_users { - let dont_share_encrypted_room = services - .rooms - .user - .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? 
- .filter_map(Result::ok) - .filter_map(|other_room_id| { - Some( - services - .rooms - .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") - .ok()? - .is_some(), - ) - }) - .all(|encrypted| !encrypted); + let dont_share_encrypted_room = !share_encrypted_room(&services, &sender_user, &user_id, None).await; + // If the user doesn't share an encrypted room with the target anymore, we need // to tell them if dont_share_encrypted_room { @@ -279,7 +278,8 @@ pub(crate) async fn sync_events_route( // Remove all to-device events the device received *last time* services .users - .remove_to_device_events(&sender_user, &sender_device, since)?; + .remove_to_device_events(&sender_user, &sender_device, since) + .await; let response = sync_events::v3::Response { next_batch: next_batch_string, @@ -298,7 +298,8 @@ pub(crate) async fn sync_events_route( account_data: GlobalAccountData { events: services .account_data - .changes_since(None, &sender_user, since)? + .changes_since(None, &sender_user, since) + .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Global)) .collect(), @@ -309,11 +310,14 @@ pub(crate) async fn sync_events_route( }, device_one_time_keys_count: services .users - .count_one_time_keys(&sender_user, &sender_device)?, + .count_one_time_keys(&sender_user, &sender_device) + .await, to_device: ToDevice { events: services .users - .get_to_device_events(&sender_user, &sender_device)?, + .get_to_device_events(&sender_user, &sender_device) + .collect() + .await, }, // Fallback keys are not yet supported device_unused_fallback_key_types: None, @@ -351,14 +355,16 @@ async fn handle_left_room( let left_count = services .rooms .state_cache - .get_left_count(room_id, sender_user)?; + .get_left_count(room_id, sender_user) + .await + .ok(); // Left before last sync if Some(since) >= left_count { return Ok(()); } - if !services.rooms.metadata.exists(room_id)? 
{ + if !services.rooms.metadata.exists(room_id).await { // This is just a rejected invite, not a room we know // Insert a leave event anyways let event = PduEvent { @@ -408,27 +414,29 @@ async fn handle_left_room( let since_shortstatehash = services .rooms .user - .get_token_shortstatehash(room_id, since)?; + .get_token_shortstatehash(room_id, since) + .await; let since_state_ids = match since_shortstatehash { - Some(s) => services.rooms.state_accessor.state_full_ids(s).await?, - None => HashMap::new(), + Ok(s) => services.rooms.state_accessor.state_full_ids(s).await?, + Err(_) => HashMap::new(), }; - let Some(left_event_id) = - services - .rooms - .state_accessor - .room_state_get_id(room_id, &StateEventType::RoomMember, sender_user.as_str())? + let Ok(left_event_id) = services + .rooms + .state_accessor + .room_state_get_id(room_id, &StateEventType::RoomMember, sender_user.as_str()) + .await else { error!("Left room but no left state event"); return Ok(()); }; - let Some(left_shortstatehash) = services + let Ok(left_shortstatehash) = services .rooms .state_accessor - .pdu_shortstatehash(&left_event_id)? 
+ .pdu_shortstatehash(&left_event_id) + .await else { error!(event_id = %left_event_id, "Leave event has no state"); return Ok(()); @@ -443,14 +451,15 @@ async fn handle_left_room( let leave_shortstatekey = services .rooms .short - .get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?; + .get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str()) + .await; left_state_ids.insert(leave_shortstatekey, left_event_id); let mut i: u8 = 0; for (key, id) in left_state_ids { if full_state || since_state_ids.get(&key) != Some(&id) { - let (event_type, state_key) = services.rooms.short.get_statekey_from_short(key)?; + let (event_type, state_key) = services.rooms.short.get_statekey_from_short(key).await?; if !lazy_load_enabled || event_type != StateEventType::RoomMember @@ -458,7 +467,7 @@ async fn handle_left_room( // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 || (cfg!(feature = "element_hacks") && *sender_user == state_key) { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { error!("Pdu in state not found: {}", id); continue; }; @@ -495,19 +504,25 @@ async fn handle_left_room( async fn process_presence_updates( services: &Services, presence_updates: &mut HashMap, since: u64, syncing_user: &UserId, ) -> Result<()> { + let presence_since = services.presence.presence_since(since); + // Take presence updates - for (user_id, _, presence_bytes) in services.presence.presence_since(since) { + pin_mut!(presence_since); + while let Some((user_id, _, presence_bytes)) = presence_since.next().await { if !services .rooms .state_cache - .user_sees_user(syncing_user, &user_id)? 
+ .user_sees_user(syncing_user, &user_id) + .await { continue; } let presence_event = services .presence - .from_json_bytes_to_event(&presence_bytes, &user_id)?; + .from_json_bytes_to_event(&presence_bytes, &user_id) + .await?; + match presence_updates.entry(user_id) { Entry::Vacant(slot) => { slot.insert(presence_event); @@ -551,14 +566,14 @@ async fn load_joined_room( let insert_lock = services.rooms.timeline.mutex_insert.lock(room_id).await; drop(insert_lock); - let (timeline_pdus, limited) = load_timeline(services, sender_user, room_id, sincecount, 10)?; + let (timeline_pdus, limited) = load_timeline(services, sender_user, room_id, sincecount, 10).await?; let send_notification_counts = !timeline_pdus.is_empty() || services .rooms .user - .last_notification_read(sender_user, room_id)? - > since; + .last_notification_read(sender_user, room_id) + .await > since; let mut timeline_users = HashSet::new(); for (_, event) in &timeline_pdus { @@ -568,355 +583,384 @@ async fn load_joined_room( services .rooms .lazy_loading - .lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount) - .await?; + .lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount); // Database queries: - let Some(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id)? 
else { - return Err!(Database(error!("Room {room_id} has no state"))); - }; + let current_shortstatehash = services + .rooms + .state + .get_room_shortstatehash(room_id) + .await + .map_err(|_| err!(Database(error!("Room {room_id} has no state"))))?; let since_shortstatehash = services .rooms .user - .get_token_shortstatehash(room_id, since)?; + .get_token_shortstatehash(room_id, since) + .await + .ok(); - let (heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events) = - if timeline_pdus.is_empty() && since_shortstatehash == Some(current_shortstatehash) { - // No state changes - (Vec::new(), None, None, false, Vec::new()) - } else { - // Calculates joined_member_count, invited_member_count and heroes - let calculate_counts = || { - let joined_member_count = services - .rooms - .state_cache - .room_joined_count(room_id)? - .unwrap_or(0); - let invited_member_count = services - .rooms - .state_cache - .room_invited_count(room_id)? - .unwrap_or(0); + let (heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events) = if timeline_pdus + .is_empty() + && (since_shortstatehash.is_none() || since_shortstatehash.is_some_and(is_equal_to!(current_shortstatehash))) + { + // No state changes + (Vec::new(), None, None, false, Vec::new()) + } else { + // Calculates joined_member_count, invited_member_count and heroes + let calculate_counts = || async { + let joined_member_count = services + .rooms + .state_cache + .room_joined_count(room_id) + .await + .unwrap_or(0); - // Recalculate heroes (first 5 members) - let mut heroes: Vec = Vec::with_capacity(5); + let invited_member_count = services + .rooms + .state_cache + .room_invited_count(room_id) + .await + .unwrap_or(0); - if joined_member_count.saturating_add(invited_member_count) <= 5 { - // Go through all PDUs and for each member event, check if the user is still - // joined or invited until we have 5 or we reach the end + if 
joined_member_count.saturating_add(invited_member_count) > 5 { + return Ok::<_, Error>((Some(joined_member_count), Some(invited_member_count), Vec::new())); + } - for hero in services - .rooms - .timeline - .all_pdus(sender_user, room_id)? - .filter_map(Result::ok) // Ignore all broken pdus - .filter(|(_, pdu)| pdu.kind == TimelineEventType::RoomMember) - .map(|(_, pdu)| { - let content: RoomMemberEventContent = serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; - - if let Some(state_key) = &pdu.state_key { - let user_id = UserId::parse(state_key.clone()) - .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; - - // The membership was and still is invite or join - if matches!(content.membership, MembershipState::Join | MembershipState::Invite) - && (services.rooms.state_cache.is_joined(&user_id, room_id)? - || services.rooms.state_cache.is_invited(&user_id, room_id)?) - { - Ok::<_, Error>(Some(user_id)) - } else { - Ok(None) - } - } else { - Ok(None) - } - }) - .filter_map(Result::ok) - // Filter for possible heroes - .flatten() - { - if heroes.contains(&hero) || hero == sender_user { - continue; - } + // Go through all PDUs and for each member event, check if the user is still + // joined or invited until we have 5 or we reach the end - heroes.push(hero); + // Recalculate heroes (first 5 members) + let heroes = services + .rooms + .timeline + .all_pdus(sender_user, room_id) + .await? 
+ .ready_filter(|(_, pdu)| pdu.kind == TimelineEventType::RoomMember) + .filter_map(|(_, pdu)| async move { + let Ok(content) = serde_json::from_str::(pdu.content.get()) else { + return None; + }; + + let Some(state_key) = &pdu.state_key else { + return None; + }; + + let Ok(user_id) = UserId::parse(state_key) else { + return None; + }; + + if user_id == sender_user { + return None; } - } - Ok::<_, Error>((Some(joined_member_count), Some(invited_member_count), heroes)) - }; + // The membership was and still is invite or join + if !matches!(content.membership, MembershipState::Join | MembershipState::Invite) { + return None; + } - let since_sender_member: Option = since_shortstatehash - .and_then(|shortstatehash| { - services + if !services .rooms - .state_accessor - .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) - .transpose() + .state_cache + .is_joined(&user_id, room_id) + .await && services + .rooms + .state_cache + .is_invited(&user_id, room_id) + .await + { + return None; + } + + Some(user_id) }) - .transpose()? 
- .and_then(|pdu| { - serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }); + .collect::>() + .await; + + Ok::<_, Error>(( + Some(joined_member_count), + Some(invited_member_count), + heroes.into_iter().collect::>(), + )) + }; - let joined_since_last_sync = - since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); + let since_sender_member: Option = if let Some(short) = since_shortstatehash { + services + .rooms + .state_accessor + .state_get(short, &StateEventType::RoomMember, sender_user.as_str()) + .await + .and_then(|pdu| serde_json::from_str(pdu.content.get()).map_err(Into::into)) + .ok() + } else { + None + }; - if since_shortstatehash.is_none() || joined_since_last_sync { - // Probably since = 0, we will do an initial sync + let joined_since_last_sync = + since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); - let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; + if since_shortstatehash.is_none() || joined_since_last_sync { + // Probably since = 0, we will do an initial sync - let current_state_ids = services - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; + let (joined_member_count, invited_member_count, heroes) = calculate_counts().await?; - let mut state_events = Vec::new(); - let mut lazy_loaded = HashSet::new(); + let current_state_ids = services + .rooms + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; - let mut i: u8 = 0; - for (shortstatekey, id) in current_state_ids { - let (event_type, state_key) = services - .rooms - .short - .get_statekey_from_short(shortstatekey)?; + let mut state_events = Vec::new(); + let mut lazy_loaded = HashSet::new(); - if event_type != StateEventType::RoomMember { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? 
else { - error!("Pdu in state not found: {}", id); - continue; - }; - state_events.push(pdu); + let mut i: u8 = 0; + for (shortstatekey, id) in current_state_ids { + let (event_type, state_key) = services + .rooms + .short + .get_statekey_from_short(shortstatekey) + .await?; - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } else if !lazy_load_enabled + if event_type != StateEventType::RoomMember { + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); + continue; + }; + state_events.push(pdu); + + i = i.wrapping_add(1); + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } else if !lazy_load_enabled || full_state || timeline_users.contains(&state_key) // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 || (cfg!(feature = "element_hacks") && *sender_user == state_key) - { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); - continue; - }; + { + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); + continue; + }; - // This check is in case a bad user ID made it into the database - if let Ok(uid) = UserId::parse(&state_key) { - lazy_loaded.insert(uid); - } - state_events.push(pdu); + // This check is in case a bad user ID made it into the database + if let Ok(uid) = UserId::parse(&state_key) { + lazy_loaded.insert(uid); + } + state_events.push(pdu); - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } + i = i.wrapping_add(1); + if i % 100 == 0 { + tokio::task::yield_now().await; } } + } - // Reset lazy loading because this is an initial sync - services - .rooms - .lazy_loading - .lazy_load_reset(sender_user, sender_device, room_id)?; + // Reset lazy loading because this is an initial sync + services + .rooms + .lazy_loading + .lazy_load_reset(sender_user, sender_device, 
room_id) + .await; + + // The state_events above should contain all timeline_users, let's mark them as + // lazy loaded. + services.rooms.lazy_loading.lazy_load_mark_sent( + sender_user, + sender_device, + room_id, + lazy_loaded, + next_batchcount, + ); - // The state_events above should contain all timeline_users, let's mark them as - // lazy loaded. - services - .rooms - .lazy_loading - .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) - .await; + (heroes, joined_member_count, invited_member_count, true, state_events) + } else { + // Incremental /sync + let since_shortstatehash = since_shortstatehash.expect("missing since_shortstatehash on incremental sync"); - (heroes, joined_member_count, invited_member_count, true, state_events) - } else { - // Incremental /sync - let since_shortstatehash = since_shortstatehash.unwrap(); + let mut delta_state_events = Vec::new(); - let mut delta_state_events = Vec::new(); + if since_shortstatehash != current_shortstatehash { + let current_state_ids = services + .rooms + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; - if since_shortstatehash != current_shortstatehash { - let current_state_ids = services - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; - let since_state_ids = services - .rooms - .state_accessor - .state_full_ids(since_shortstatehash) - .await?; + let since_state_ids = services + .rooms + .state_accessor + .state_full_ids(since_shortstatehash) + .await?; - for (key, id) in current_state_ids { - if full_state || since_state_ids.get(&key) != Some(&id) { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? 
else { - error!("Pdu in state not found: {}", id); - continue; - }; + for (key, id) in current_state_ids { + if full_state || since_state_ids.get(&key) != Some(&id) { + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); + continue; + }; - delta_state_events.push(pdu); - tokio::task::yield_now().await; - } + delta_state_events.push(pdu); + tokio::task::yield_now().await; } } + } - let encrypted_room = services - .rooms - .state_accessor - .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? - .is_some(); + let encrypted_room = services + .rooms + .state_accessor + .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "") + .await + .is_ok(); - let since_encryption = services.rooms.state_accessor.state_get( - since_shortstatehash, - &StateEventType::RoomEncryption, - "", - )?; + let since_encryption = services + .rooms + .state_accessor + .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "") + .await; - // Calculations: - let new_encrypted_room = encrypted_room && since_encryption.is_none(); + // Calculations: + let new_encrypted_room = encrypted_room && since_encryption.is_err(); - let send_member_count = delta_state_events - .iter() - .any(|event| event.kind == TimelineEventType::RoomMember); + let send_member_count = delta_state_events + .iter() + .any(|event| event.kind == TimelineEventType::RoomMember); - if encrypted_room { - for state_event in &delta_state_events { - if state_event.kind != TimelineEventType::RoomMember { - continue; - } + if encrypted_room { + for state_event in &delta_state_events { + if state_event.kind != TimelineEventType::RoomMember { + continue; + } - if let Some(state_key) = &state_event.state_key { - let user_id = UserId::parse(state_key.clone()) - .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; + if let Some(state_key) = &state_event.state_key { + let user_id = UserId::parse(state_key.clone()) + .map_err(|_| 
Error::bad_database("Invalid UserId in member PDU."))?; - if user_id == sender_user { - continue; - } + if user_id == sender_user { + continue; + } - let new_membership = - serde_json::from_str::(state_event.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database."))? - .membership; + let new_membership = serde_json::from_str::(state_event.content.get()) + .map_err(|_| Error::bad_database("Invalid PDU in database."))? + .membership; - match new_membership { - MembershipState::Join => { - // A new user joined an encrypted room - if !share_encrypted_room(services, sender_user, &user_id, room_id)? { - device_list_updates.insert(user_id); - } - }, - MembershipState::Leave => { - // Write down users that have left encrypted rooms we are in - left_encrypted_users.insert(user_id); - }, - _ => {}, - } + match new_membership { + MembershipState::Join => { + // A new user joined an encrypted room + if !share_encrypted_room(services, sender_user, &user_id, Some(room_id)).await { + device_list_updates.insert(user_id); + } + }, + MembershipState::Leave => { + // Write down users that have left encrypted rooms we are in + left_encrypted_users.insert(user_id); + }, + _ => {}, } } } + } - if joined_since_last_sync && encrypted_room || new_encrypted_room { - // If the user is in a new encrypted room, give them all joined users - device_list_updates.extend( - services - .rooms - .state_cache - .room_members(room_id) - .flatten() - .filter(|user_id| { - // Don't send key updates from the sender to the sender - sender_user != user_id - }) - .filter(|user_id| { - // Only send keys if the sender doesn't share an encrypted room with the target - // already - !share_encrypted_room(services, sender_user, user_id, room_id).unwrap_or(false) - }), - ); - } + if joined_since_last_sync && encrypted_room || new_encrypted_room { + // If the user is in a new encrypted room, give them all joined users + device_list_updates.extend( + services + .rooms + .state_cache + 
.room_members(room_id) + .ready_filter(|user_id| { + // Don't send key updates from the sender to the sender + sender_user != *user_id + }) + .filter_map(|user_id| async move { + // Only send keys if the sender doesn't share an encrypted room with the target + // already + (!share_encrypted_room(services, sender_user, user_id, Some(room_id)).await) + .then_some(user_id.to_owned()) + }) + .collect::>() + .await, + ); + } - let (joined_member_count, invited_member_count, heroes) = if send_member_count { - calculate_counts()? - } else { - (None, None, Vec::new()) - }; + let (joined_member_count, invited_member_count, heroes) = if send_member_count { + calculate_counts().await? + } else { + (None, None, Vec::new()) + }; - let mut state_events = delta_state_events; - let mut lazy_loaded = HashSet::new(); - - // Mark all member events we're returning as lazy-loaded - for pdu in &state_events { - if pdu.kind == TimelineEventType::RoomMember { - match UserId::parse( - pdu.state_key - .as_ref() - .expect("State event has state key") - .clone(), - ) { - Ok(state_key_userid) => { - lazy_loaded.insert(state_key_userid); - }, - Err(e) => error!("Invalid state key for member event: {}", e), - } + let mut state_events = delta_state_events; + let mut lazy_loaded = HashSet::new(); + + // Mark all member events we're returning as lazy-loaded + for pdu in &state_events { + if pdu.kind == TimelineEventType::RoomMember { + match UserId::parse( + pdu.state_key + .as_ref() + .expect("State event has state key") + .clone(), + ) { + Ok(state_key_userid) => { + lazy_loaded.insert(state_key_userid); + }, + Err(e) => error!("Invalid state key for member event: {}", e), } } + } - // Fetch contextual member state events for events from the timeline, and - // mark them as lazy-loaded as well. 
- for (_, event) in &timeline_pdus { - if lazy_loaded.contains(&event.sender) { - continue; - } + // Fetch contextual member state events for events from the timeline, and + // mark them as lazy-loaded as well. + for (_, event) in &timeline_pdus { + if lazy_loaded.contains(&event.sender) { + continue; + } - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - room_id, - &event.sender, - )? || lazy_load_send_redundant + if !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) + .await || lazy_load_send_redundant + { + if let Ok(member_event) = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, event.sender.as_str()) + .await { - if let Some(member_event) = services.rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomMember, - event.sender.as_str(), - )? { - lazy_loaded.insert(event.sender.clone()); - state_events.push(member_event); - } + lazy_loaded.insert(event.sender.clone()); + state_events.push(member_event); } } + } - services - .rooms - .lazy_loading - .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) - .await; + services.rooms.lazy_loading.lazy_load_mark_sent( + sender_user, + sender_device, + room_id, + lazy_loaded, + next_batchcount, + ); - ( - heroes, - joined_member_count, - invited_member_count, - joined_since_last_sync, - state_events, - ) - } - }; + ( + heroes, + joined_member_count, + invited_member_count, + joined_since_last_sync, + state_events, + ) + } + }; // Look for device list updates in this room device_list_updates.extend( services .users .keys_changed(room_id.as_ref(), since, None) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); let notification_count = if send_notification_counts { @@ -924,7 +968,8 @@ async fn load_joined_room( services .rooms .user - .notification_count(sender_user, room_id)? 
+ .notification_count(sender_user, room_id) + .await .try_into() .expect("notification count can't go that high"), ) @@ -937,7 +982,8 @@ async fn load_joined_room( services .rooms .user - .highlight_count(sender_user, room_id)? + .highlight_count(sender_user, room_id) + .await .try_into() .expect("highlight count can't go that high"), ) @@ -966,9 +1012,9 @@ async fn load_joined_room( .rooms .read_receipt .readreceipts_since(room_id, since) - .filter_map(Result::ok) // Filter out buggy events .map(|(_, _, v)| v) - .collect(); + .collect() + .await; if services.rooms.typing.last_typing_update(room_id).await? > since { edus.push( @@ -985,13 +1031,15 @@ async fn load_joined_room( services .rooms .user - .associate_token_shortstatehash(room_id, next_batch, current_shortstatehash)?; + .associate_token_shortstatehash(room_id, next_batch, current_shortstatehash) + .await; Ok(JoinedRoom { account_data: RoomAccountData { events: services .account_data - .changes_since(Some(room_id), sender_user, since)? + .changes_since(Some(room_id), sender_user, since) + .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) .collect(), @@ -1023,41 +1071,37 @@ async fn load_joined_room( }) } -fn load_timeline( +async fn load_timeline( services: &Services, sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: u64, ) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { let timeline_pdus; let limited = if services .rooms .timeline - .last_timeline_count(sender_user, room_id)? + .last_timeline_count(sender_user, room_id) + .await? > roomsincecount { let mut non_timeline_pdus = services .rooms .timeline - .pdus_until(sender_user, room_id, PduCount::max())? - .filter_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) - .take_while(|(pducount, _)| pducount > &roomsincecount); + .pdus_until(sender_user, room_id, PduCount::max()) + .await? 
+ .ready_take_while(|(pducount, _)| pducount > &roomsincecount); // Take the last events for the timeline timeline_pdus = non_timeline_pdus .by_ref() .take(usize_from_u64_truncated(limit)) .collect::>() + .await .into_iter() .rev() .collect::>(); // They /sync response doesn't always return all messages, so we say the output // is limited unless there are events in non_timeline_pdus - non_timeline_pdus.next().is_some() + non_timeline_pdus.next().await.is_some() } else { timeline_pdus = Vec::new(); false @@ -1065,26 +1109,23 @@ fn load_timeline( Ok((timeline_pdus, limited)) } -fn share_encrypted_room( - services: &Services, sender_user: &UserId, user_id: &UserId, ignore_room: &RoomId, -) -> Result { - Ok(services +async fn share_encrypted_room( + services: &Services, sender_user: &UserId, user_id: &UserId, ignore_room: Option<&RoomId>, +) -> bool { + services .rooms .user - .get_shared_rooms(vec![sender_user.to_owned(), user_id.to_owned()])? - .filter_map(Result::ok) - .filter(|room_id| room_id != ignore_room) - .filter_map(|other_room_id| { - Some( - services - .rooms - .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") - .ok()? 
- .is_some(), - ) + .get_shared_rooms(sender_user, user_id) + .ready_filter(|&room_id| Some(room_id) != ignore_room) + .any(|other_room_id| async move { + services + .rooms + .state_accessor + .room_state_get(other_room_id, &StateEventType::RoomEncryption, "") + .await + .is_ok() }) - .any(|encrypted| encrypted)) + .await } /// POST `/_matrix/client/unstable/org.matrix.msc3575/sync` @@ -1114,7 +1155,7 @@ pub(crate) async fn sync_events_v4_route( if globalsince != 0 && !services - .users + .sync .remembered(sender_user.clone(), sender_device.clone(), conn_id.clone()) { debug!("Restarting sync stream because it was gone from the database"); @@ -1127,41 +1168,43 @@ pub(crate) async fn sync_events_v4_route( if globalsince == 0 { services - .users + .sync .forget_sync_request_connection(sender_user.clone(), sender_device.clone(), conn_id.clone()); } // Get sticky parameters from cache let known_rooms = services - .users + .sync .update_sync_request_with_cache(sender_user.clone(), sender_device.clone(), &mut body); let all_joined_rooms = services .rooms .state_cache .rooms_joined(&sender_user) - .filter_map(Result::ok) - .collect::>(); + .map(ToOwned::to_owned) + .collect::>() + .await; let all_invited_rooms = services .rooms .state_cache .rooms_invited(&sender_user) - .filter_map(Result::ok) .map(|r| r.0) - .collect::>(); + .collect::>() + .await; let all_rooms = all_joined_rooms .iter() - .cloned() - .chain(all_invited_rooms.clone()) + .chain(all_invited_rooms.iter()) + .map(Clone::clone) .collect(); if body.extensions.to_device.enabled.unwrap_or(false) { services .users - .remove_to_device_events(&sender_user, &sender_device, globalsince)?; + .remove_to_device_events(&sender_user, &sender_device, globalsince) + .await; } let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in @@ -1179,7 +1222,8 @@ pub(crate) async fn sync_events_v4_route( if body.extensions.account_data.enabled.unwrap_or(false) { account_data.global 
= services .account_data - .changes_since(None, &sender_user, globalsince)? + .changes_since(None, &sender_user, globalsince) + .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Global)) .collect(); @@ -1190,7 +1234,8 @@ pub(crate) async fn sync_events_v4_route( room.clone(), services .account_data - .changes_since(Some(&room), &sender_user, globalsince)? + .changes_since(Some(&room), &sender_user, globalsince) + .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) .collect(), @@ -1205,40 +1250,42 @@ pub(crate) async fn sync_events_v4_route( services .users .keys_changed(sender_user.as_ref(), globalsince, None) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); for room_id in &all_joined_rooms { - let Some(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id)? else { - error!("Room {} has no state", room_id); + let Ok(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id).await else { + error!("Room {room_id} has no state"); continue; }; let since_shortstatehash = services .rooms .user - .get_token_shortstatehash(room_id, globalsince)?; + .get_token_shortstatehash(room_id, globalsince) + .await + .ok(); - let since_sender_member: Option = since_shortstatehash - .and_then(|shortstatehash| { - services - .rooms - .state_accessor - .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) - .transpose() - }) - .transpose()? 
- .and_then(|pdu| { - serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }); + let since_sender_member: Option = if let Some(short) = since_shortstatehash { + services + .rooms + .state_accessor + .state_get(short, &StateEventType::RoomMember, sender_user.as_str()) + .await + .and_then(|pdu| serde_json::from_str(pdu.content.get()).map_err(Into::into)) + .ok() + } else { + None + }; let encrypted_room = services .rooms .state_accessor - .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? - .is_some(); + .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "") + .await + .is_ok(); if let Some(since_shortstatehash) = since_shortstatehash { // Skip if there are only timeline changes @@ -1246,22 +1293,24 @@ pub(crate) async fn sync_events_v4_route( continue; } - let since_encryption = services.rooms.state_accessor.state_get( - since_shortstatehash, - &StateEventType::RoomEncryption, - "", - )?; + let since_encryption = services + .rooms + .state_accessor + .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "") + .await; let joined_since_last_sync = since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); - let new_encrypted_room = encrypted_room && since_encryption.is_none(); + let new_encrypted_room = encrypted_room && since_encryption.is_err(); + if encrypted_room { let current_state_ids = services .rooms .state_accessor .state_full_ids(current_shortstatehash) .await?; + let since_state_ids = services .rooms .state_accessor @@ -1270,8 +1319,8 @@ pub(crate) async fn sync_events_v4_route( for (key, id) in current_state_ids { if since_state_ids.get(&key) != Some(&id) { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? 
else { - error!("Pdu in state not found: {}", id); + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); continue; }; if pdu.kind == TimelineEventType::RoomMember { @@ -1291,7 +1340,9 @@ pub(crate) async fn sync_events_v4_route( match new_membership { MembershipState::Join => { // A new user joined an encrypted room - if !share_encrypted_room(&services, &sender_user, &user_id, room_id)? { + if !share_encrypted_room(&services, &sender_user, &user_id, Some(room_id)) + .await + { device_list_changes.insert(user_id); } }, @@ -1306,22 +1357,25 @@ pub(crate) async fn sync_events_v4_route( } } if joined_since_last_sync || new_encrypted_room { + let sender_user = &sender_user; // If the user is in a new encrypted room, give them all joined users device_list_changes.extend( services .rooms .state_cache .room_members(room_id) - .flatten() - .filter(|user_id| { + .ready_filter(|user_id| { // Don't send key updates from the sender to the sender - &sender_user != user_id + sender_user != user_id }) - .filter(|user_id| { + .filter_map(|user_id| async move { // Only send keys if the sender doesn't share an encrypted room with the target // already - !share_encrypted_room(&services, &sender_user, user_id, room_id).unwrap_or(false) - }), + (!share_encrypted_room(&services, sender_user, user_id, Some(room_id)).await) + .then_some(user_id.to_owned()) + }) + .collect::>() + .await, ); } } @@ -1331,26 +1385,15 @@ pub(crate) async fn sync_events_v4_route( services .users .keys_changed(room_id.as_ref(), globalsince, None) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); } + for user_id in left_encrypted_users { - let dont_share_encrypted_room = services - .rooms - .user - .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? 
- .filter_map(Result::ok) - .filter_map(|other_room_id| { - Some( - services - .rooms - .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") - .ok()? - .is_some(), - ) - }) - .all(|encrypted| !encrypted); + let dont_share_encrypted_room = !share_encrypted_room(&services, &sender_user, &user_id, None).await; + // If the user doesn't share an encrypted room with the target anymore, we need // to tell them if dont_share_encrypted_room { @@ -1362,7 +1405,7 @@ pub(crate) async fn sync_events_v4_route( let mut lists = BTreeMap::new(); let mut todo_rooms = BTreeMap::new(); // and required state - for (list_id, list) in body.lists { + for (list_id, list) in &body.lists { let active_rooms = match list.filters.clone().and_then(|f| f.is_invite) { Some(true) => &all_invited_rooms, Some(false) => &all_joined_rooms, @@ -1371,23 +1414,23 @@ pub(crate) async fn sync_events_v4_route( let active_rooms = match list.filters.clone().map(|f| f.not_room_types) { Some(filter) if filter.is_empty() => active_rooms.clone(), - Some(value) => filter_rooms(active_rooms, State(services), &value, true), + Some(value) => filter_rooms(active_rooms, State(services), &value, true).await, None => active_rooms.clone(), }; let active_rooms = match list.filters.clone().map(|f| f.room_types) { Some(filter) if filter.is_empty() => active_rooms.clone(), - Some(value) => filter_rooms(&active_rooms, State(services), &value, false), + Some(value) => filter_rooms(&active_rooms, State(services), &value, false).await, None => active_rooms, }; let mut new_known_rooms = BTreeSet::new(); + let ranges = list.ranges.clone(); lists.insert( list_id.clone(), sync_events::v4::SyncList { - ops: list - .ranges + ops: ranges .into_iter() .map(|mut r| { r.0 = r.0.clamp( @@ -1396,29 +1439,34 @@ pub(crate) async fn sync_events_v4_route( ); r.1 = r.1.clamp(r.0, UInt::try_from(active_rooms.len().saturating_sub(1)).unwrap_or(UInt::MAX)); + let room_ids = if !active_rooms.is_empty() { 
active_rooms[usize_from_ruma(r.0)..=usize_from_ruma(r.1)].to_vec() } else { Vec::new() }; + new_known_rooms.extend(room_ids.iter().cloned()); for room_id in &room_ids { let todo_room = todo_rooms .entry(room_id.clone()) .or_insert((BTreeSet::new(), 0, u64::MAX)); + let limit = list .room_details .timeline_limit .map_or(10, u64::from) .min(100); + todo_room .0 .extend(list.room_details.required_state.iter().cloned()); + todo_room.1 = todo_room.1.max(limit); // 0 means unknown because it got out of date todo_room.2 = todo_room.2.min( known_rooms - .get(&list_id) + .get(list_id.as_str()) .and_then(|k| k.get(room_id)) .copied() .unwrap_or(0), @@ -1438,11 +1486,11 @@ pub(crate) async fn sync_events_v4_route( ); if let Some(conn_id) = &body.conn_id { - services.users.update_sync_known_rooms( + services.sync.update_sync_known_rooms( sender_user.clone(), sender_device.clone(), conn_id.clone(), - list_id, + list_id.clone(), new_known_rooms, globalsince, ); @@ -1451,7 +1499,7 @@ pub(crate) async fn sync_events_v4_route( let mut known_subscription_rooms = BTreeSet::new(); for (room_id, room) in &body.room_subscriptions { - if !services.rooms.metadata.exists(room_id)? 
{ + if !services.rooms.metadata.exists(room_id).await { continue; } let todo_room = todo_rooms @@ -1477,7 +1525,7 @@ pub(crate) async fn sync_events_v4_route( } if let Some(conn_id) = &body.conn_id { - services.users.update_sync_known_rooms( + services.sync.update_sync_known_rooms( sender_user.clone(), sender_device.clone(), conn_id.clone(), @@ -1488,7 +1536,7 @@ pub(crate) async fn sync_events_v4_route( } if let Some(conn_id) = &body.conn_id { - services.users.update_sync_subscriptions( + services.sync.update_sync_subscriptions( sender_user.clone(), sender_device.clone(), conn_id.clone(), @@ -1509,12 +1557,13 @@ pub(crate) async fn sync_events_v4_route( .rooms .state_cache .invite_state(&sender_user, room_id) - .unwrap_or(None); + .await + .ok(); (timeline_pdus, limited) = (Vec::new(), true); } else { (timeline_pdus, limited) = - match load_timeline(&services, &sender_user, room_id, roomsincecount, *timeline_limit) { + match load_timeline(&services, &sender_user, room_id, roomsincecount, *timeline_limit).await { Ok(value) => value, Err(err) => { warn!("Encountered missing timeline in {}, error {}", room_id, err); @@ -1527,17 +1576,20 @@ pub(crate) async fn sync_events_v4_route( room_id.clone(), services .account_data - .changes_since(Some(room_id), &sender_user, *roomsince)? + .changes_since(Some(room_id), &sender_user, *roomsince) + .await? 
.into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) .collect(), ); - let room_receipts = services + let vector: Vec<_> = services .rooms .read_receipt - .readreceipts_since(room_id, *roomsince); - let vector: Vec<_> = room_receipts.into_iter().collect(); + .readreceipts_since(room_id, *roomsince) + .collect() + .await; + let receipt_size = vector.len(); receipts .rooms @@ -1584,41 +1636,42 @@ pub(crate) async fn sync_events_v4_route( let required_state = required_state_request .iter() - .map(|state| { + .stream() + .filter_map(|state| async move { services .rooms .state_accessor .room_state_get(room_id, &state.0, &state.1) + .await + .map(|s| s.to_sync_state_event()) + .ok() }) - .filter_map(Result::ok) - .flatten() - .map(|state| state.to_sync_state_event()) - .collect(); + .collect() + .await; // Heroes let heroes = services .rooms .state_cache .room_members(room_id) - .filter_map(Result::ok) - .filter(|member| member != &sender_user) - .map(|member| { - Ok::<_, Error>( - services - .rooms - .state_accessor - .get_member(room_id, &member)? - .map(|memberevent| SlidingSyncRoomHero { - user_id: member, - name: memberevent.displayname, - avatar: memberevent.avatar_url, - }), - ) + .ready_filter(|member| member != &sender_user) + .filter_map(|member| async move { + services + .rooms + .state_accessor + .get_member(room_id, member) + .await + .map(|memberevent| SlidingSyncRoomHero { + user_id: member.to_owned(), + name: memberevent.displayname, + avatar: memberevent.avatar_url, + }) + .ok() }) - .filter_map(Result::ok) - .flatten() .take(5) - .collect::>(); + .collect::>() + .await; + let name = match heroes.len().cmp(&(1_usize)) { Ordering::Greater => { let firsts = heroes[1..] 
@@ -1626,10 +1679,12 @@ pub(crate) async fn sync_events_v4_route( .map(|h| h.name.clone().unwrap_or_else(|| h.user_id.to_string())) .collect::>() .join(", "); + let last = heroes[0] .name .clone() .unwrap_or_else(|| heroes[0].user_id.to_string()); + Some(format!("{firsts} and {last}")) }, Ordering::Equal => Some( @@ -1650,11 +1705,17 @@ pub(crate) async fn sync_events_v4_route( rooms.insert( room_id.clone(), sync_events::v4::SlidingSyncRoom { - name: services.rooms.state_accessor.get_name(room_id)?.or(name), + name: services + .rooms + .state_accessor + .get_name(room_id) + .await + .ok() + .or(name), avatar: if let Some(heroes_avatar) = heroes_avatar { ruma::JsOption::Some(heroes_avatar) } else { - match services.rooms.state_accessor.get_avatar(room_id)? { + match services.rooms.state_accessor.get_avatar(room_id).await { ruma::JsOption::Some(avatar) => ruma::JsOption::from_option(avatar.url), ruma::JsOption::Null => ruma::JsOption::Null, ruma::JsOption::Undefined => ruma::JsOption::Undefined, @@ -1668,7 +1729,8 @@ pub(crate) async fn sync_events_v4_route( services .rooms .user - .highlight_count(&sender_user, room_id)? + .highlight_count(&sender_user, room_id) + .await .try_into() .expect("notification count can't go that high"), ), @@ -1676,7 +1738,8 @@ pub(crate) async fn sync_events_v4_route( services .rooms .user - .notification_count(&sender_user, room_id)? + .notification_count(&sender_user, room_id) + .await .try_into() .expect("notification count can't go that high"), ), @@ -1689,7 +1752,8 @@ pub(crate) async fn sync_events_v4_route( services .rooms .state_cache - .room_joined_count(room_id)? + .room_joined_count(room_id) + .await .unwrap_or(0) .try_into() .unwrap_or_else(|_| uint!(0)), @@ -1698,7 +1762,8 @@ pub(crate) async fn sync_events_v4_route( services .rooms .state_cache - .room_invited_count(room_id)? 
+ .room_invited_count(room_id) + .await .unwrap_or(0) .try_into() .unwrap_or_else(|_| uint!(0)), @@ -1732,7 +1797,9 @@ pub(crate) async fn sync_events_v4_route( Some(sync_events::v4::ToDevice { events: services .users - .get_to_device_events(&sender_user, &sender_device)?, + .get_to_device_events(&sender_user, &sender_device) + .collect() + .await, next_batch: next_batch.to_string(), }) } else { @@ -1745,7 +1812,8 @@ pub(crate) async fn sync_events_v4_route( }, device_one_time_keys_count: services .users - .count_one_time_keys(&sender_user, &sender_device)?, + .count_one_time_keys(&sender_user, &sender_device) + .await, // Fallback keys are not yet supported device_unused_fallback_key_types: None, }, @@ -1759,25 +1827,26 @@ pub(crate) async fn sync_events_v4_route( }) } -fn filter_rooms( +async fn filter_rooms( rooms: &[OwnedRoomId], State(services): State, filter: &[RoomTypeFilter], negate: bool, ) -> Vec { - return rooms + rooms .iter() - .filter(|r| match services.rooms.state_accessor.get_room_type(r) { - Err(e) => { - warn!("Requested room type for {}, but could not retrieve with error {}", r, e); - false - }, - Ok(result) => { - let result = RoomTypeFilter::from(result); - if negate { - !filter.contains(&result) - } else { - filter.is_empty() || filter.contains(&result) - } - }, + .stream() + .filter_map(|r| async move { + match services.rooms.state_accessor.get_room_type(r).await { + Err(_) => false, + Ok(result) => { + let result = RoomTypeFilter::from(Some(result)); + if negate { + !filter.contains(&result) + } else { + filter.is_empty() || filter.contains(&result) + } + }, + } + .then_some(r.to_owned()) }) - .cloned() - .collect(); + .collect() + .await } diff --git a/src/api/client/tag.rs b/src/api/client/tag.rs index 301568e50..bcd0f8170 100644 --- a/src/api/client/tag.rs +++ b/src/api/client/tag.rs @@ -23,10 +23,11 @@ pub(crate) async fn update_tag_route( let event = services .account_data - .get(Some(&body.room_id), sender_user, 
RoomAccountDataEventType::Tag)?; + .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag) + .await; let mut tags_event = event.map_or_else( - || { + |_| { Ok(TagEvent { content: TagEventContent { tags: BTreeMap::new(), @@ -41,12 +42,15 @@ pub(crate) async fn update_tag_route( .tags .insert(body.tag.clone().into(), body.tag_info.clone()); - services.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + services + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(create_tag::v3::Response {}) } @@ -63,10 +67,11 @@ pub(crate) async fn delete_tag_route( let event = services .account_data - .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; + .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag) + .await; let mut tags_event = event.map_or_else( - || { + |_| { Ok(TagEvent { content: TagEventContent { tags: BTreeMap::new(), @@ -78,12 +83,15 @@ pub(crate) async fn delete_tag_route( tags_event.content.tags.remove(&body.tag.clone().into()); - services.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + services + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(delete_tag::v3::Response {}) } @@ -100,10 +108,11 @@ pub(crate) async fn get_tags_route( let event = services .account_data - .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; + .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag) + .await; let tags_event = event.map_or_else( - || { + |_| { Ok(TagEvent { content: 
TagEventContent { tags: BTreeMap::new(), diff --git a/src/api/client/threads.rs b/src/api/client/threads.rs index 8100f0e67..50f6cdfb2 100644 --- a/src/api/client/threads.rs +++ b/src/api/client/threads.rs @@ -1,4 +1,6 @@ use axum::extract::State; +use conduit::PduEvent; +use futures::StreamExt; use ruma::{ api::client::{error::ErrorKind, threads::get_threads}, uint, @@ -27,20 +29,23 @@ pub(crate) async fn get_threads_route( u64::MAX }; - let threads = services + let room_id = &body.room_id; + let threads: Vec<(u64, PduEvent)> = services .rooms .threads - .threads_until(sender_user, &body.room_id, from, &body.include)? + .threads_until(sender_user, &body.room_id, from, &body.include) + .await? .take(limit) - .filter_map(Result::ok) - .filter(|(_, pdu)| { + .filter_map(|(count, pdu)| async move { services .rooms .state_accessor - .user_can_see_event(sender_user, &body.room_id, &pdu.event_id) - .unwrap_or(false) + .user_can_see_event(sender_user, room_id, &pdu.event_id) + .await + .then_some((count, pdu)) }) - .collect::>(); + .collect() + .await; let next_batch = threads.last().map(|(count, _)| count.to_string()); diff --git a/src/api/client/to_device.rs b/src/api/client/to_device.rs index 1f557ad7b..2b37a9ec5 100644 --- a/src/api/client/to_device.rs +++ b/src/api/client/to_device.rs @@ -2,6 +2,7 @@ use std::collections::BTreeMap; use axum::extract::State; use conduit::{Error, Result}; +use futures::StreamExt; use ruma::{ api::{ client::{error::ErrorKind, to_device::send_event_to_device}, @@ -24,8 +25,9 @@ pub(crate) async fn send_event_to_device_route( // Check if this is a new transaction id if services .transaction_ids - .existing_txnid(sender_user, sender_device, &body.txn_id)? 
- .is_some() + .existing_txnid(sender_user, sender_device, &body.txn_id) + .await + .is_ok() { return Ok(send_event_to_device::v3::Response {}); } @@ -53,31 +55,35 @@ pub(crate) async fn send_event_to_device_route( continue; } + let event_type = &body.event_type.to_string(); + + let event = event + .deserialize_as() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?; + match target_device_id_maybe { DeviceIdOrAllDevices::DeviceId(target_device_id) => { - services.users.add_to_device_event( - sender_user, - target_user_id, - target_device_id, - &body.event_type.to_string(), - event - .deserialize_as() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?, - )?; + services + .users + .add_to_device_event(sender_user, target_user_id, target_device_id, event_type, event) + .await; }, DeviceIdOrAllDevices::AllDevices => { - for target_device_id in services.users.all_device_ids(target_user_id) { - services.users.add_to_device_event( - sender_user, - target_user_id, - &target_device_id?, - &body.event_type.to_string(), - event - .deserialize_as() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?, - )?; - } + let (event_type, event) = (&event_type, &event); + services + .users + .all_device_ids(target_user_id) + .for_each(|target_device_id| { + services.users.add_to_device_event( + sender_user, + target_user_id, + target_device_id, + event_type, + event.clone(), + ) + }) + .await; }, } } @@ -86,7 +92,7 @@ pub(crate) async fn send_event_to_device_route( // Save transaction id with empty data services .transaction_ids - .add_txnid(sender_user, sender_device, &body.txn_id, &[])?; + .add_txnid(sender_user, sender_device, &body.txn_id, &[]); Ok(send_event_to_device::v3::Response {}) } diff --git a/src/api/client/typing.rs b/src/api/client/typing.rs index a06648e05..932d221ed 100644 --- a/src/api/client/typing.rs +++ b/src/api/client/typing.rs @@ -16,7 +16,8 @@ pub(crate) async fn 
create_typing_event_route( if !services .rooms .state_cache - .is_joined(sender_user, &body.room_id)? + .is_joined(sender_user, &body.room_id) + .await { return Err(Error::BadRequest(ErrorKind::forbidden(), "You are not in this room.")); } diff --git a/src/api/client/unstable.rs b/src/api/client/unstable.rs index ab4703fdb..dc570295c 100644 --- a/src/api/client/unstable.rs +++ b/src/api/client/unstable.rs @@ -2,7 +2,8 @@ use std::collections::BTreeMap; use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{warn, Err}; +use conduit::Err; +use futures::StreamExt; use ruma::{ api::{ client::{ @@ -45,7 +46,7 @@ pub(crate) async fn get_mutual_rooms_route( )); } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { return Ok(mutual_rooms::unstable::Response { joined: vec![], next_batch_token: None, @@ -55,9 +56,10 @@ pub(crate) async fn get_mutual_rooms_route( let mutual_rooms: Vec = services .rooms .user - .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? - .filter_map(Result::ok) - .collect(); + .get_shared_rooms(sender_user, &body.user_id) + .map(ToOwned::to_owned) + .collect() + .await; Ok(mutual_rooms::unstable::Response { joined: mutual_rooms, @@ -99,7 +101,7 @@ pub(crate) async fn get_room_summary( let room_id = services.rooms.alias.resolve(&body.room_id_or_alias).await?; - if !services.rooms.metadata.exists(&room_id)? { + if !services.rooms.metadata.exists(&room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); } @@ -108,7 +110,7 @@ pub(crate) async fn get_room_summary( .rooms .state_accessor .is_world_readable(&room_id) - .unwrap_or(false) + .await { return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -122,50 +124,58 @@ pub(crate) async fn get_room_summary( .rooms .state_accessor .get_canonical_alias(&room_id) - .unwrap_or(None), + .await + .ok(), avatar_url: services .rooms .state_accessor - .get_avatar(&room_id)? 
+ .get_avatar(&room_id) + .await .into_option() .unwrap_or_default() .url, - guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id)?, - name: services - .rooms - .state_accessor - .get_name(&room_id) - .unwrap_or(None), + guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id).await, + name: services.rooms.state_accessor.get_name(&room_id).await.ok(), num_joined_members: services .rooms .state_cache .room_joined_count(&room_id) - .unwrap_or_default() - .unwrap_or_else(|| { - warn!("Room {room_id} has no member count"); - 0 - }) - .try_into() - .expect("user count should not be that big"), + .await + .unwrap_or(0) + .try_into()?, topic: services .rooms .state_accessor .get_room_topic(&room_id) - .unwrap_or(None), + .await + .ok(), world_readable: services .rooms .state_accessor .is_world_readable(&room_id) - .unwrap_or(false), - join_rule: services.rooms.state_accessor.get_join_rule(&room_id)?.0, - room_type: services.rooms.state_accessor.get_room_type(&room_id)?, - room_version: Some(services.rooms.state.get_room_version(&room_id)?), + .await, + join_rule: services + .rooms + .state_accessor + .get_join_rule(&room_id) + .await + .unwrap_or_default() + .0, + room_type: services + .rooms + .state_accessor + .get_room_type(&room_id) + .await + .ok(), + room_version: services.rooms.state.get_room_version(&room_id).await.ok(), membership: if let Some(sender_user) = sender_user { services .rooms .state_accessor - .get_member(&room_id, sender_user)? 
- .map_or_else(|| Some(MembershipState::Leave), |content| Some(content.membership)) + .get_member(&room_id, sender_user) + .await + .map_or_else(|_| MembershipState::Leave, |content| content.membership) + .into() } else { None }, @@ -173,7 +183,8 @@ pub(crate) async fn get_room_summary( .rooms .state_accessor .get_room_encryption(&room_id) - .unwrap_or_else(|_e| None), + .await + .ok(), }) } @@ -191,13 +202,14 @@ pub(crate) async fn delete_timezone_key_route( return Err!(Request(Forbidden("You cannot update the profile of another user"))); } - services.users.set_timezone(&body.user_id, None).await?; + services.users.set_timezone(&body.user_id, None); if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(delete_timezone_key::unstable::Response {}) @@ -217,16 +229,14 @@ pub(crate) async fn set_timezone_key_route( return Err!(Request(Forbidden("You cannot update the profile of another user"))); } - services - .users - .set_timezone(&body.user_id, body.tz.clone()) - .await?; + services.users.set_timezone(&body.user_id, body.tz.clone()); if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(set_timezone_key::unstable::Response {}) @@ -280,10 +290,11 @@ pub(crate) async fn set_profile_key_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_displayname(&services, &body.user_id, Some(profile_key_value.to_string()), all_joined_rooms).await?; + update_displayname(&services, &body.user_id, Some(profile_key_value.to_string()), &all_joined_rooms).await?; } else if body.key == "avatar_url" { let mxc = ruma::OwnedMxcUri::from(profile_key_value.to_string()); @@ -291,21 
+302,23 @@ pub(crate) async fn set_profile_key_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_avatar_url(&services, &body.user_id, Some(mxc), None, all_joined_rooms).await?; + update_avatar_url(&services, &body.user_id, Some(mxc), None, &all_joined_rooms).await?; } else { services .users - .set_profile_key(&body.user_id, &body.key, Some(profile_key_value.clone()))?; + .set_profile_key(&body.user_id, &body.key, Some(profile_key_value.clone())); } if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(set_profile_key::unstable::Response {}) @@ -335,30 +348,33 @@ pub(crate) async fn delete_profile_key_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_displayname(&services, &body.user_id, None, all_joined_rooms).await?; + update_displayname(&services, &body.user_id, None, &all_joined_rooms).await?; } else if body.key == "avatar_url" { let all_joined_rooms: Vec = services .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_avatar_url(&services, &body.user_id, None, None, all_joined_rooms).await?; + update_avatar_url(&services, &body.user_id, None, None, &all_joined_rooms).await?; } else { services .users - .set_profile_key(&body.user_id, &body.key, None)?; + .set_profile_key(&body.user_id, &body.key, None); } if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(delete_profile_key::unstable::Response {}) @@ -386,26 +402,25 @@ pub(crate) async fn get_timezone_key_route( ) .await { 
- if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); + services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); + services .users - .set_timezone(&body.user_id, response.tz.clone()) - .await?; + .set_timezone(&body.user_id, response.tz.clone()); return Ok(get_timezone_key::unstable::Response { tz: response.tz, @@ -413,14 +428,14 @@ pub(crate) async fn get_timezone_key_route( } } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_timezone_key::unstable::Response { - tz: services.users.timezone(&body.user_id)?, + tz: services.users.timezone(&body.user_id).await.ok(), }) } @@ -448,32 +463,31 @@ pub(crate) async fn get_profile_key_route( ) .await { - if !services.users.exists(&body.user_id)? 
{ + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); + services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); + services .users - .set_timezone(&body.user_id, response.tz.clone()) - .await?; + .set_timezone(&body.user_id, response.tz.clone()); if let Some(value) = response.custom_profile_fields.get(&body.key) { profile_key_value.insert(body.key.clone(), value.clone()); services .users - .set_profile_key(&body.user_id, &body.key, Some(value.clone()))?; + .set_profile_key(&body.user_id, &body.key, Some(value.clone())); } else { return Err!(Request(NotFound("The requested profile key does not exist."))); } @@ -484,13 +498,13 @@ pub(crate) async fn get_profile_key_route( } } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation - return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); + return Err!(Request(NotFound("Profile was not found."))); } - if let Some(value) = services.users.profile_key(&body.user_id, &body.key)? 
{ + if let Ok(value) = services.users.profile_key(&body.user_id, &body.key).await { profile_key_value.insert(body.key.clone(), value); } else { return Err!(Request(NotFound("The requested profile key does not exist."))); diff --git a/src/api/client/unversioned.rs b/src/api/client/unversioned.rs index d714fda54..d5bb14e5d 100644 --- a/src/api/client/unversioned.rs +++ b/src/api/client/unversioned.rs @@ -1,6 +1,7 @@ use std::collections::BTreeMap; use axum::{extract::State, response::IntoResponse, Json}; +use futures::StreamExt; use ruma::api::client::{ discovery::{ discover_homeserver::{self, HomeserverInfo, SlidingSyncProxyInfo}, @@ -173,7 +174,7 @@ pub(crate) async fn conduwuit_server_version() -> Result { /// homeserver. Endpoint is disabled if federation is disabled for privacy. This /// only includes active users (not deactivated, no guests, etc) pub(crate) async fn conduwuit_local_user_count(State(services): State) -> Result { - let user_count = services.users.list_local_users()?.len(); + let user_count = services.users.list_local_users().count().await; Ok(Json(serde_json::json!({ "count": user_count diff --git a/src/api/client/user_directory.rs b/src/api/client/user_directory.rs index 87d4062cd..8ea7f1b82 100644 --- a/src/api/client/user_directory.rs +++ b/src/api/client/user_directory.rs @@ -1,4 +1,5 @@ use axum::extract::State; +use futures::{pin_mut, StreamExt}; use ruma::{ api::client::user_directory::search_users, events::{ @@ -21,14 +22,12 @@ pub(crate) async fn search_users_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let limit = usize::try_from(body.limit).unwrap_or(10); // default limit is 10 - let mut users = services.users.iter().filter_map(|user_id| { + let users = services.users.stream().filter_map(|user_id| async { // Filter out buggy users (they should not exist, but you never know...) 
- let user_id = user_id.ok()?; - let user = search_users::v3::User { - user_id: user_id.clone(), - display_name: services.users.displayname(&user_id).ok()?, - avatar_url: services.users.avatar_url(&user_id).ok()?, + user_id: user_id.to_owned(), + display_name: services.users.displayname(user_id).await.ok(), + avatar_url: services.users.avatar_url(user_id).await.ok(), }; let user_id_matches = user @@ -56,20 +55,19 @@ pub(crate) async fn search_users_route( let user_is_in_public_rooms = services .rooms .state_cache - .rooms_joined(&user_id) - .filter_map(Result::ok) - .any(|room| { + .rooms_joined(&user.user_id) + .any(|room| async move { services .rooms .state_accessor - .room_state_get(&room, &StateEventType::RoomJoinRules, "") + .room_state_get(room, &StateEventType::RoomJoinRules, "") + .await .map_or(false, |event| { - event.map_or(false, |event| { - serde_json::from_str(event.content.get()) - .map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public) - }) + serde_json::from_str(event.content.get()) + .map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public) }) - }); + }) + .await; if user_is_in_public_rooms { user_visible = true; @@ -77,25 +75,22 @@ pub(crate) async fn search_users_route( let user_is_in_shared_rooms = services .rooms .user - .get_shared_rooms(vec![sender_user.clone(), user_id]) - .ok()? 
- .next() - .is_some(); + .has_shared_rooms(sender_user, &user.user_id) + .await; if user_is_in_shared_rooms { user_visible = true; } } - if !user_visible { - return None; - } - - Some(user) + user_visible.then_some(user) }); - let results = users.by_ref().take(limit).collect(); - let limited = users.next().is_some(); + pin_mut!(users); + + let limited = users.by_ref().next().await.is_some(); + + let results = users.take(limit).collect().await; Ok(search_users::v3::Response { results, diff --git a/src/api/router.rs b/src/api/router.rs index 4264e01df..c4275f054 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -22,101 +22,101 @@ use crate::{client, server}; pub fn build(router: Router, server: &Server) -> Router { let config = &server.config; let mut router = router - .ruma_route(client::get_timezone_key_route) - .ruma_route(client::get_profile_key_route) - .ruma_route(client::set_profile_key_route) - .ruma_route(client::delete_profile_key_route) - .ruma_route(client::set_timezone_key_route) - .ruma_route(client::delete_timezone_key_route) - .ruma_route(client::appservice_ping) - .ruma_route(client::get_supported_versions_route) - .ruma_route(client::get_register_available_route) - .ruma_route(client::register_route) - .ruma_route(client::get_login_types_route) - .ruma_route(client::login_route) - .ruma_route(client::whoami_route) - .ruma_route(client::logout_route) - .ruma_route(client::logout_all_route) - .ruma_route(client::change_password_route) - .ruma_route(client::deactivate_route) - .ruma_route(client::third_party_route) - .ruma_route(client::request_3pid_management_token_via_email_route) - .ruma_route(client::request_3pid_management_token_via_msisdn_route) - .ruma_route(client::check_registration_token_validity) - .ruma_route(client::get_capabilities_route) - .ruma_route(client::get_pushrules_all_route) - .ruma_route(client::set_pushrule_route) - .ruma_route(client::get_pushrule_route) - .ruma_route(client::set_pushrule_enabled_route) - 
.ruma_route(client::get_pushrule_enabled_route) - .ruma_route(client::get_pushrule_actions_route) - .ruma_route(client::set_pushrule_actions_route) - .ruma_route(client::delete_pushrule_route) - .ruma_route(client::get_room_event_route) - .ruma_route(client::get_room_aliases_route) - .ruma_route(client::get_filter_route) - .ruma_route(client::create_filter_route) - .ruma_route(client::create_openid_token_route) - .ruma_route(client::set_global_account_data_route) - .ruma_route(client::set_room_account_data_route) - .ruma_route(client::get_global_account_data_route) - .ruma_route(client::get_room_account_data_route) - .ruma_route(client::set_displayname_route) - .ruma_route(client::get_displayname_route) - .ruma_route(client::set_avatar_url_route) - .ruma_route(client::get_avatar_url_route) - .ruma_route(client::get_profile_route) - .ruma_route(client::set_presence_route) - .ruma_route(client::get_presence_route) - .ruma_route(client::upload_keys_route) - .ruma_route(client::get_keys_route) - .ruma_route(client::claim_keys_route) - .ruma_route(client::create_backup_version_route) - .ruma_route(client::update_backup_version_route) - .ruma_route(client::delete_backup_version_route) - .ruma_route(client::get_latest_backup_info_route) - .ruma_route(client::get_backup_info_route) - .ruma_route(client::add_backup_keys_route) - .ruma_route(client::add_backup_keys_for_room_route) - .ruma_route(client::add_backup_keys_for_session_route) - .ruma_route(client::delete_backup_keys_for_room_route) - .ruma_route(client::delete_backup_keys_for_session_route) - .ruma_route(client::delete_backup_keys_route) - .ruma_route(client::get_backup_keys_for_room_route) - .ruma_route(client::get_backup_keys_for_session_route) - .ruma_route(client::get_backup_keys_route) - .ruma_route(client::set_read_marker_route) - .ruma_route(client::create_receipt_route) - .ruma_route(client::create_typing_event_route) - .ruma_route(client::create_room_route) - .ruma_route(client::redact_event_route) - 
.ruma_route(client::report_event_route) - .ruma_route(client::create_alias_route) - .ruma_route(client::delete_alias_route) - .ruma_route(client::get_alias_route) - .ruma_route(client::join_room_by_id_route) - .ruma_route(client::join_room_by_id_or_alias_route) - .ruma_route(client::joined_members_route) - .ruma_route(client::leave_room_route) - .ruma_route(client::forget_room_route) - .ruma_route(client::joined_rooms_route) - .ruma_route(client::kick_user_route) - .ruma_route(client::ban_user_route) - .ruma_route(client::unban_user_route) - .ruma_route(client::invite_user_route) - .ruma_route(client::set_room_visibility_route) - .ruma_route(client::get_room_visibility_route) - .ruma_route(client::get_public_rooms_route) - .ruma_route(client::get_public_rooms_filtered_route) - .ruma_route(client::search_users_route) - .ruma_route(client::get_member_events_route) - .ruma_route(client::get_protocols_route) + .ruma_route(&client::get_timezone_key_route) + .ruma_route(&client::get_profile_key_route) + .ruma_route(&client::set_profile_key_route) + .ruma_route(&client::delete_profile_key_route) + .ruma_route(&client::set_timezone_key_route) + .ruma_route(&client::delete_timezone_key_route) + .ruma_route(&client::appservice_ping) + .ruma_route(&client::get_supported_versions_route) + .ruma_route(&client::get_register_available_route) + .ruma_route(&client::register_route) + .ruma_route(&client::get_login_types_route) + .ruma_route(&client::login_route) + .ruma_route(&client::whoami_route) + .ruma_route(&client::logout_route) + .ruma_route(&client::logout_all_route) + .ruma_route(&client::change_password_route) + .ruma_route(&client::deactivate_route) + .ruma_route(&client::third_party_route) + .ruma_route(&client::request_3pid_management_token_via_email_route) + .ruma_route(&client::request_3pid_management_token_via_msisdn_route) + .ruma_route(&client::check_registration_token_validity) + .ruma_route(&client::get_capabilities_route) + 
.ruma_route(&client::get_pushrules_all_route) + .ruma_route(&client::set_pushrule_route) + .ruma_route(&client::get_pushrule_route) + .ruma_route(&client::set_pushrule_enabled_route) + .ruma_route(&client::get_pushrule_enabled_route) + .ruma_route(&client::get_pushrule_actions_route) + .ruma_route(&client::set_pushrule_actions_route) + .ruma_route(&client::delete_pushrule_route) + .ruma_route(&client::get_room_event_route) + .ruma_route(&client::get_room_aliases_route) + .ruma_route(&client::get_filter_route) + .ruma_route(&client::create_filter_route) + .ruma_route(&client::create_openid_token_route) + .ruma_route(&client::set_global_account_data_route) + .ruma_route(&client::set_room_account_data_route) + .ruma_route(&client::get_global_account_data_route) + .ruma_route(&client::get_room_account_data_route) + .ruma_route(&client::set_displayname_route) + .ruma_route(&client::get_displayname_route) + .ruma_route(&client::set_avatar_url_route) + .ruma_route(&client::get_avatar_url_route) + .ruma_route(&client::get_profile_route) + .ruma_route(&client::set_presence_route) + .ruma_route(&client::get_presence_route) + .ruma_route(&client::upload_keys_route) + .ruma_route(&client::get_keys_route) + .ruma_route(&client::claim_keys_route) + .ruma_route(&client::create_backup_version_route) + .ruma_route(&client::update_backup_version_route) + .ruma_route(&client::delete_backup_version_route) + .ruma_route(&client::get_latest_backup_info_route) + .ruma_route(&client::get_backup_info_route) + .ruma_route(&client::add_backup_keys_route) + .ruma_route(&client::add_backup_keys_for_room_route) + .ruma_route(&client::add_backup_keys_for_session_route) + .ruma_route(&client::delete_backup_keys_for_room_route) + .ruma_route(&client::delete_backup_keys_for_session_route) + .ruma_route(&client::delete_backup_keys_route) + .ruma_route(&client::get_backup_keys_for_room_route) + .ruma_route(&client::get_backup_keys_for_session_route) + .ruma_route(&client::get_backup_keys_route) + 
.ruma_route(&client::set_read_marker_route) + .ruma_route(&client::create_receipt_route) + .ruma_route(&client::create_typing_event_route) + .ruma_route(&client::create_room_route) + .ruma_route(&client::redact_event_route) + .ruma_route(&client::report_event_route) + .ruma_route(&client::create_alias_route) + .ruma_route(&client::delete_alias_route) + .ruma_route(&client::get_alias_route) + .ruma_route(&client::join_room_by_id_route) + .ruma_route(&client::join_room_by_id_or_alias_route) + .ruma_route(&client::joined_members_route) + .ruma_route(&client::leave_room_route) + .ruma_route(&client::forget_room_route) + .ruma_route(&client::joined_rooms_route) + .ruma_route(&client::kick_user_route) + .ruma_route(&client::ban_user_route) + .ruma_route(&client::unban_user_route) + .ruma_route(&client::invite_user_route) + .ruma_route(&client::set_room_visibility_route) + .ruma_route(&client::get_room_visibility_route) + .ruma_route(&client::get_public_rooms_route) + .ruma_route(&client::get_public_rooms_filtered_route) + .ruma_route(&client::search_users_route) + .ruma_route(&client::get_member_events_route) + .ruma_route(&client::get_protocols_route) .route("/_matrix/client/unstable/thirdparty/protocols", get(client::get_protocols_route_unstable)) - .ruma_route(client::send_message_event_route) - .ruma_route(client::send_state_event_for_key_route) - .ruma_route(client::get_state_events_route) - .ruma_route(client::get_state_events_for_key_route) + .ruma_route(&client::send_message_event_route) + .ruma_route(&client::send_state_event_for_key_route) + .ruma_route(&client::get_state_events_route) + .ruma_route(&client::get_state_events_for_key_route) // Ruma doesn't have support for multiple paths for a single endpoint yet, and these routes // share one Ruma request / response type pair with {get,send}_state_event_for_key_route .route( @@ -140,46 +140,46 @@ pub fn build(router: Router, server: &Server) -> Router { get(client::get_state_events_for_empty_key_route) 
.put(client::send_state_event_for_empty_key_route), ) - .ruma_route(client::sync_events_route) - .ruma_route(client::sync_events_v4_route) - .ruma_route(client::get_context_route) - .ruma_route(client::get_message_events_route) - .ruma_route(client::search_events_route) - .ruma_route(client::turn_server_route) - .ruma_route(client::send_event_to_device_route) - .ruma_route(client::create_content_route) - .ruma_route(client::get_content_thumbnail_route) - .ruma_route(client::get_content_route) - .ruma_route(client::get_content_as_filename_route) - .ruma_route(client::get_media_preview_route) - .ruma_route(client::get_media_config_route) - .ruma_route(client::get_devices_route) - .ruma_route(client::get_device_route) - .ruma_route(client::update_device_route) - .ruma_route(client::delete_device_route) - .ruma_route(client::delete_devices_route) - .ruma_route(client::get_tags_route) - .ruma_route(client::update_tag_route) - .ruma_route(client::delete_tag_route) - .ruma_route(client::upload_signing_keys_route) - .ruma_route(client::upload_signatures_route) - .ruma_route(client::get_key_changes_route) - .ruma_route(client::get_pushers_route) - .ruma_route(client::set_pushers_route) - .ruma_route(client::upgrade_room_route) - .ruma_route(client::get_threads_route) - .ruma_route(client::get_relating_events_with_rel_type_and_event_type_route) - .ruma_route(client::get_relating_events_with_rel_type_route) - .ruma_route(client::get_relating_events_route) - .ruma_route(client::get_hierarchy_route) - .ruma_route(client::get_mutual_rooms_route) - .ruma_route(client::get_room_summary) + .ruma_route(&client::sync_events_route) + .ruma_route(&client::sync_events_v4_route) + .ruma_route(&client::get_context_route) + .ruma_route(&client::get_message_events_route) + .ruma_route(&client::search_events_route) + .ruma_route(&client::turn_server_route) + .ruma_route(&client::send_event_to_device_route) + .ruma_route(&client::create_content_route) + 
.ruma_route(&client::get_content_thumbnail_route) + .ruma_route(&client::get_content_route) + .ruma_route(&client::get_content_as_filename_route) + .ruma_route(&client::get_media_preview_route) + .ruma_route(&client::get_media_config_route) + .ruma_route(&client::get_devices_route) + .ruma_route(&client::get_device_route) + .ruma_route(&client::update_device_route) + .ruma_route(&client::delete_device_route) + .ruma_route(&client::delete_devices_route) + .ruma_route(&client::get_tags_route) + .ruma_route(&client::update_tag_route) + .ruma_route(&client::delete_tag_route) + .ruma_route(&client::upload_signing_keys_route) + .ruma_route(&client::upload_signatures_route) + .ruma_route(&client::get_key_changes_route) + .ruma_route(&client::get_pushers_route) + .ruma_route(&client::set_pushers_route) + .ruma_route(&client::upgrade_room_route) + .ruma_route(&client::get_threads_route) + .ruma_route(&client::get_relating_events_with_rel_type_and_event_type_route) + .ruma_route(&client::get_relating_events_with_rel_type_route) + .ruma_route(&client::get_relating_events_route) + .ruma_route(&client::get_hierarchy_route) + .ruma_route(&client::get_mutual_rooms_route) + .ruma_route(&client::get_room_summary) .route( "/_matrix/client/unstable/im.nheko.summary/rooms/:room_id_or_alias/summary", get(client::get_room_summary_legacy) ) - .ruma_route(client::well_known_support) - .ruma_route(client::well_known_client) + .ruma_route(&client::well_known_support) + .ruma_route(&client::well_known_client) .route("/_conduwuit/server_version", get(client::conduwuit_server_version)) .route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync)) .route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync)) @@ -187,35 +187,35 @@ pub fn build(router: Router, server: &Server) -> Router { if config.allow_federation { router = router - .ruma_route(server::get_server_version_route) + .ruma_route(&server::get_server_version_route) .route("/_matrix/key/v2/server", 
get(server::get_server_keys_route)) .route("/_matrix/key/v2/server/:key_id", get(server::get_server_keys_deprecated_route)) - .ruma_route(server::get_public_rooms_route) - .ruma_route(server::get_public_rooms_filtered_route) - .ruma_route(server::send_transaction_message_route) - .ruma_route(server::get_event_route) - .ruma_route(server::get_backfill_route) - .ruma_route(server::get_missing_events_route) - .ruma_route(server::get_event_authorization_route) - .ruma_route(server::get_room_state_route) - .ruma_route(server::get_room_state_ids_route) - .ruma_route(server::create_leave_event_template_route) - .ruma_route(server::create_leave_event_v1_route) - .ruma_route(server::create_leave_event_v2_route) - .ruma_route(server::create_join_event_template_route) - .ruma_route(server::create_join_event_v1_route) - .ruma_route(server::create_join_event_v2_route) - .ruma_route(server::create_invite_route) - .ruma_route(server::get_devices_route) - .ruma_route(server::get_room_information_route) - .ruma_route(server::get_profile_information_route) - .ruma_route(server::get_keys_route) - .ruma_route(server::claim_keys_route) - .ruma_route(server::get_openid_userinfo_route) - .ruma_route(server::get_hierarchy_route) - .ruma_route(server::well_known_server) - .ruma_route(server::get_content_route) - .ruma_route(server::get_content_thumbnail_route) + .ruma_route(&server::get_public_rooms_route) + .ruma_route(&server::get_public_rooms_filtered_route) + .ruma_route(&server::send_transaction_message_route) + .ruma_route(&server::get_event_route) + .ruma_route(&server::get_backfill_route) + .ruma_route(&server::get_missing_events_route) + .ruma_route(&server::get_event_authorization_route) + .ruma_route(&server::get_room_state_route) + .ruma_route(&server::get_room_state_ids_route) + .ruma_route(&server::create_leave_event_template_route) + .ruma_route(&server::create_leave_event_v1_route) + .ruma_route(&server::create_leave_event_v2_route) + 
.ruma_route(&server::create_join_event_template_route) + .ruma_route(&server::create_join_event_v1_route) + .ruma_route(&server::create_join_event_v2_route) + .ruma_route(&server::create_invite_route) + .ruma_route(&server::get_devices_route) + .ruma_route(&server::get_room_information_route) + .ruma_route(&server::get_profile_information_route) + .ruma_route(&server::get_keys_route) + .ruma_route(&server::claim_keys_route) + .ruma_route(&server::get_openid_userinfo_route) + .ruma_route(&server::get_hierarchy_route) + .ruma_route(&server::well_known_server) + .ruma_route(&server::get_content_route) + .ruma_route(&server::get_content_thumbnail_route) .route("/_conduwuit/local_user_count", get(client::conduwuit_local_user_count)); } else { router = router @@ -227,11 +227,11 @@ pub fn build(router: Router, server: &Server) -> Router { if config.allow_legacy_media { router = router - .ruma_route(client::get_media_config_legacy_route) - .ruma_route(client::get_media_preview_legacy_route) - .ruma_route(client::get_content_legacy_route) - .ruma_route(client::get_content_as_filename_legacy_route) - .ruma_route(client::get_content_thumbnail_legacy_route) + .ruma_route(&client::get_media_config_legacy_route) + .ruma_route(&client::get_media_preview_legacy_route) + .ruma_route(&client::get_content_legacy_route) + .ruma_route(&client::get_content_as_filename_legacy_route) + .ruma_route(&client::get_content_thumbnail_legacy_route) .route("/_matrix/media/v1/config", get(client::get_media_config_legacy_legacy_route)) .route("/_matrix/media/v1/upload", post(client::create_content_legacy_route)) .route( diff --git a/src/api/router/args.rs b/src/api/router/args.rs index a3d09dff5..7381a55f5 100644 --- a/src/api/router/args.rs +++ b/src/api/router/args.rs @@ -10,7 +10,10 @@ use super::{auth, auth::Auth, request, request::Request}; use crate::{service::appservice::RegistrationInfo, State}; /// Extractor for Ruma request structs -pub(crate) struct Args { +pub(crate) struct Args +where 
+ T: IncomingRequest + Send + Sync + 'static, +{ /// Request struct body pub(crate) body: T, @@ -38,7 +41,7 @@ pub(crate) struct Args { #[async_trait] impl FromRequest for Args where - T: IncomingRequest, + T: IncomingRequest + Send + Sync + 'static, { type Rejection = Error; @@ -57,7 +60,10 @@ where } } -impl Deref for Args { +impl Deref for Args +where + T: IncomingRequest + Send + Sync + 'static, +{ type Target = T; fn deref(&self) -> &Self::Target { &self.body } @@ -67,7 +73,7 @@ fn make_body( services: &Services, request: &mut Request, json_body: &mut Option, auth: &Auth, ) -> Result where - T: IncomingRequest, + T: IncomingRequest + Send + Sync + 'static, { let body = if let Some(CanonicalJsonValue::Object(json_body)) = json_body { let user_id = auth.sender_user.clone().unwrap_or_else(|| { @@ -77,15 +83,13 @@ where let uiaa_request = json_body .get("auth") - .and_then(|auth| auth.as_object()) + .and_then(CanonicalJsonValue::as_object) .and_then(|auth| auth.get("session")) - .and_then(|session| session.as_str()) + .and_then(CanonicalJsonValue::as_str) .and_then(|session| { - services.uiaa.get_uiaa_request( - &user_id, - &auth.sender_device.clone().unwrap_or_else(|| EMPTY.into()), - session, - ) + services + .uiaa + .get_uiaa_request(&user_id, auth.sender_device.as_deref(), session) }); if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 670f72ba8..8d76b4be8 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -44,8 +44,8 @@ pub(super) async fn auth( let token = if let Some(token) = token { if let Some(reg_info) = services.appservice.find_from_token(token).await { Token::Appservice(Box::new(reg_info)) - } else if let Some((user_id, device_id)) = services.users.find_from_token(token)? 
{ - Token::User((user_id, OwnedDeviceId::from(device_id))) + } else if let Ok((user_id, device_id)) = services.users.find_from_token(token).await { + Token::User((user_id, device_id)) } else { Token::Invalid } @@ -98,7 +98,7 @@ pub(super) async fn auth( )) } }, - (AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(services, request, info)?), + (AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(services, request, info).await?), (AuthScheme::None | AuthScheme::AccessTokenOptional | AuthScheme::AppserviceToken, Token::Appservice(info)) => { Ok(Auth { origin: None, @@ -150,7 +150,7 @@ pub(super) async fn auth( } } -fn auth_appservice(services: &Services, request: &Request, info: Box) -> Result { +async fn auth_appservice(services: &Services, request: &Request, info: Box) -> Result { let user_id = request .query .user_id @@ -170,7 +170,7 @@ fn auth_appservice(services: &Services, request: &Request, info: Box { + fn add_route(&'static self, router: Router, path: &str) -> Router; + fn add_routes(&'static self, router: Router) -> Router; +} + pub(in super::super) trait RouterExt { - fn ruma_route(self, handler: H) -> Self + fn ruma_route(self, handler: &'static H) -> Self where H: RumaHandler; } impl RouterExt for Router { - fn ruma_route(self, handler: H) -> Self + fn ruma_route(self, handler: &'static H) -> Self where H: RumaHandler, { @@ -27,34 +31,28 @@ impl RouterExt for Router { } } -pub(in super::super) trait RumaHandler { - fn add_routes(&self, router: Router) -> Router; - - fn add_route(&self, router: Router, path: &str) -> Router; -} - macro_rules! ruma_handler { ( $($tx:ident),* $(,)? 
) => { #[allow(non_snake_case)] - impl RumaHandler<($($tx,)* Ruma,)> for Fun + impl RumaHandler<($($tx,)* Ruma,)> for Fun where - Req: IncomingRequest + Send + 'static, - Ret: IntoResponse, - Fut: Future> + Send, - Fun: FnOnce($($tx,)* Ruma,) -> Fut + Clone + Send + Sync + 'static, - $( $tx: FromRequestParts + Send + 'static, )* + Fun: Fn($($tx,)* Ruma,) -> Fut + Send + Sync + 'static, + Fut: Future> + Send, + Req: IncomingRequest + Send + Sync, + Err: IntoResponse + Send, + ::OutgoingResponse: Send, + $( $tx: FromRequestParts + Send + Sync + 'static, )* { - fn add_routes(&self, router: Router) -> Router { + fn add_routes(&'static self, router: Router) -> Router { Req::METADATA .history .all_paths() .fold(router, |router, path| self.add_route(router, path)) } - fn add_route(&self, router: Router, path: &str) -> Router { - let handle = self.clone(); + fn add_route(&'static self, router: Router, path: &str) -> Router { + let action = |$($tx,)* req| self($($tx,)* req).map_ok(RumaResponse); let method = method_to_filter(&Req::METADATA.method); - let action = |$($tx,)* req| async { handle($($tx,)* req).await.map(RumaResponse) }; router.route(path, on(method, action)) } } diff --git a/src/api/router/response.rs b/src/api/router/response.rs index 2aaa79faa..70bbb9364 100644 --- a/src/api/router/response.rs +++ b/src/api/router/response.rs @@ -5,13 +5,18 @@ use http::StatusCode; use http_body_util::Full; use ruma::api::{client::uiaa::UiaaResponse, OutgoingResponse}; -pub(crate) struct RumaResponse(pub(crate) T); +pub(crate) struct RumaResponse(pub(crate) T) +where + T: OutgoingResponse; impl From for RumaResponse { fn from(t: Error) -> Self { Self(t.into()) } } -impl IntoResponse for RumaResponse { +impl IntoResponse for RumaResponse +where + T: OutgoingResponse, +{ fn into_response(self) -> Response { self.0 .try_into_http_response::() diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index 1b665c19d..2bbc95ca9 100644 --- a/src/api/server/backfill.rs 
+++ b/src/api/server/backfill.rs @@ -1,9 +1,13 @@ +use std::cmp; + use axum::extract::State; -use conduit::{Error, Result}; -use ruma::{ - api::{client::error::ErrorKind, federation::backfill::get_backfill}, - uint, user_id, MilliSecondsSinceUnixEpoch, +use conduit::{ + is_equal_to, + utils::{IterStream, ReadyExt}, + Err, PduCount, Result, }; +use futures::{FutureExt, StreamExt}; +use ruma::{api::federation::backfill::get_backfill, uint, user_id, MilliSecondsSinceUnixEpoch}; use crate::Ruma; @@ -19,27 +23,35 @@ pub(crate) async fn get_backfill_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .rooms .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? + .is_world_readable(&body.room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, &body.room_id) + .await { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + return Err!(Request(Forbidden("Server is not in room."))); } let until = body .v .iter() - .map(|event_id| services.rooms.timeline.get_pdu_count(event_id)) - .filter_map(|r| r.ok().flatten()) - .max() - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event not found."))?; + .stream() + .filter_map(|event_id| { + services + .rooms + .timeline + .get_pdu_count(event_id) + .map(Result::ok) + }) + .ready_fold(PduCount::Backfilled(0), cmp::max) + .await; let limit = body .limit @@ -47,31 +59,37 @@ pub(crate) async fn get_backfill_route( .try_into() .expect("UInt could not be converted to usize"); - let all_events = services + let pdus = services .rooms .timeline - .pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until)? - .take(limit); + .pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until) + .await? 
+ .take(limit) + .filter_map(|(_, pdu)| async move { + if !services + .rooms + .state_accessor + .server_can_see_event(origin, &pdu.room_id, &pdu.event_id) + .await + .is_ok_and(is_equal_to!(true)) + { + return None; + } - let events = all_events - .filter_map(Result::ok) - .filter(|(_, e)| { - matches!( - services - .rooms - .state_accessor - .server_can_see_event(origin, &e.room_id, &e.event_id,), - Ok(true), - ) + services + .rooms + .timeline + .get_pdu_json(&pdu.event_id) + .await + .ok() }) - .map(|(_, pdu)| services.rooms.timeline.get_pdu_json(&pdu.event_id)) - .filter_map(|r| r.ok().flatten()) - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect(); + .then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) + .collect() + .await; Ok(get_backfill::v1::Response { origin: services.globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdus: events, + pdus, }) } diff --git a/src/api/server/event.rs b/src/api/server/event.rs index e11a01a20..e4eac794f 100644 --- a/src/api/server/event.rs +++ b/src/api/server/event.rs @@ -1,9 +1,6 @@ use axum::extract::State; -use conduit::{Error, Result}; -use ruma::{ - api::{client::error::ErrorKind, federation::event::get_event}, - MilliSecondsSinceUnixEpoch, RoomId, -}; +use conduit::{err, Err, Result}; +use ruma::{api::federation::event::get_event, MilliSecondsSinceUnixEpoch, RoomId}; use crate::Ruma; @@ -21,34 +18,46 @@ pub(crate) async fn get_event_route( let event = services .rooms .timeline - .get_pdu_json(&body.event_id)? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; + .get_pdu_json(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Event not found."))))?; let room_id_str = event .get("room_id") .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database."))?; + .ok_or_else(|| err!(Database("Invalid event in database.")))?; let room_id = - <&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; + <&RoomId>::try_from(room_id_str).map_err(|_| err!(Database("Invalid room_id in event in database.")))?; - if !services.rooms.state_accessor.is_world_readable(room_id)? - && !services.rooms.state_cache.server_in_room(origin, room_id)? + if !services + .rooms + .state_accessor + .is_world_readable(room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, room_id) + .await { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + return Err!(Request(Forbidden("Server is not in room."))); } if !services .rooms .state_accessor - .server_can_see_event(origin, room_id, &body.event_id)? + .server_can_see_event(origin, room_id, &body.event_id) + .await? 
{ - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not allowed to see event.")); + return Err!(Request(Forbidden("Server is not allowed to see event."))); } Ok(get_event::v1::Response { origin: services.globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdu: services.sending.convert_to_outgoing_federation_event(event), + pdu: services + .sending + .convert_to_outgoing_federation_event(event) + .await, }) } diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index 4b0f6bc00..6ec00b501 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use axum::extract::State; use conduit::{Error, Result}; +use futures::StreamExt; use ruma::{ api::{client::error::ErrorKind, federation::authorization::get_event_authorization}, RoomId, @@ -22,16 +23,18 @@ pub(crate) async fn get_event_authorization_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .rooms .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? + .is_world_readable(&body.room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, &body.room_id) + .await { return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); } @@ -39,8 +42,9 @@ pub(crate) async fn get_event_authorization_route( let event = services .rooms .timeline - .get_pdu_json(&body.event_id)? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; + .get_pdu_json(&body.event_id) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; let room_id_str = event .get("room_id") @@ -50,16 +54,17 @@ pub(crate) async fn get_event_authorization_route( let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; - let auth_chain_ids = services + let auth_chain = services .rooms .auth_chain .event_ids_iter(room_id, vec![Arc::from(&*body.event_id)]) - .await?; + .await? + .filter_map(|id| async move { services.rooms.timeline.get_pdu_json(&id).await.ok() }) + .then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) + .collect() + .await; Ok(get_event_authorization::v1::Response { - auth_chain: auth_chain_ids - .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok()?) - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect(), + auth_chain, }) } diff --git a/src/api/server/get_missing_events.rs b/src/api/server/get_missing_events.rs index e2c3c93cf..7ae0ff608 100644 --- a/src/api/server/get_missing_events.rs +++ b/src/api/server/get_missing_events.rs @@ -18,16 +18,18 @@ pub(crate) async fn get_missing_events_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .rooms .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? 
+ .is_world_readable(&body.room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, &body.room_id) + .await { return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room")); } @@ -43,7 +45,12 @@ pub(crate) async fn get_missing_events_route( let mut i: usize = 0; while i < queued_events.len() && events.len() < limit { - if let Some(pdu) = services.rooms.timeline.get_pdu_json(&queued_events[i])? { + if let Ok(pdu) = services + .rooms + .timeline + .get_pdu_json(&queued_events[i]) + .await + { let room_id_str = pdu .get("room_id") .and_then(|val| val.as_str()) @@ -64,7 +71,8 @@ pub(crate) async fn get_missing_events_route( if !services .rooms .state_accessor - .server_can_see_event(origin, &body.room_id, &queued_events[i])? + .server_can_see_event(origin, &body.room_id, &queued_events[i]) + .await? { i = i.saturating_add(1); continue; @@ -81,7 +89,12 @@ pub(crate) async fn get_missing_events_route( ) .map_err(|_| Error::bad_database("Invalid prev_events in event in database."))?, ); - events.push(services.sending.convert_to_outgoing_federation_event(pdu)); + events.push( + services + .sending + .convert_to_outgoing_federation_event(pdu) + .await, + ); } i = i.saturating_add(1); } diff --git a/src/api/server/hierarchy.rs b/src/api/server/hierarchy.rs index 530ed1456..002bd7633 100644 --- a/src/api/server/hierarchy.rs +++ b/src/api/server/hierarchy.rs @@ -12,7 +12,7 @@ pub(crate) async fn get_hierarchy_route( ) -> Result { let origin = body.origin.as_ref().expect("server is authenticated"); - if services.rooms.metadata.exists(&body.room_id)? 
{ + if services.rooms.metadata.exists(&body.room_id).await { services .rooms .spaces diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index 688e026c5..9968bdf72 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -24,7 +24,8 @@ pub(crate) async fn create_invite_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .globals @@ -98,7 +99,8 @@ pub(crate) async fn create_invite_route( services .rooms .event_handler - .acl_check(invited_user.server_name(), &body.room_id)?; + .acl_check(invited_user.server_name(), &body.room_id) + .await?; ruma::signatures::hash_and_sign_event( services.globals.server_name().as_str(), @@ -128,14 +130,14 @@ pub(crate) async fn create_invite_route( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user ID."))?; - if services.rooms.metadata.is_banned(&body.room_id)? && !services.users.is_admin(&invited_user)? { + if services.rooms.metadata.is_banned(&body.room_id).await && !services.users.is_admin(&invited_user).await { return Err(Error::BadRequest( ErrorKind::forbidden(), "This room is banned on this homeserver.", )); } - if services.globals.block_non_admin_invites() && !services.users.is_admin(&invited_user)? { + if services.globals.block_non_admin_invites() && !services.users.is_admin(&invited_user).await { return Err(Error::BadRequest( ErrorKind::forbidden(), "This server does not allow room invites.", @@ -159,22 +161,28 @@ pub(crate) async fn create_invite_route( if !services .rooms .state_cache - .server_in_room(services.globals.server_name(), &body.room_id)? 
+ .server_in_room(services.globals.server_name(), &body.room_id) + .await { - services.rooms.state_cache.update_membership( - &body.room_id, - &invited_user, - RoomMemberEventContent::new(MembershipState::Invite), - &sender, - Some(invite_state), - body.via.clone(), - true, - )?; + services + .rooms + .state_cache + .update_membership( + &body.room_id, + &invited_user, + RoomMemberEventContent::new(MembershipState::Invite), + &sender, + Some(invite_state), + body.via.clone(), + true, + ) + .await?; } Ok(create_invite::v2::Response { event: services .sending - .convert_to_outgoing_federation_event(signed_event), + .convert_to_outgoing_federation_event(signed_event) + .await, }) } diff --git a/src/api/server/make_join.rs b/src/api/server/make_join.rs index 021016be2..ba081aade 100644 --- a/src/api/server/make_join.rs +++ b/src/api/server/make_join.rs @@ -1,4 +1,6 @@ use axum::extract::State; +use conduit::utils::{IterStream, ReadyExt}; +use futures::StreamExt; use ruma::{ api::{client::error::ErrorKind, federation::membership::prepare_join_event}, events::{ @@ -24,7 +26,7 @@ use crate::{ pub(crate) async fn create_join_event_template_route( State(services): State, body: Ruma, ) -> Result { - if !services.rooms.metadata.exists(&body.room_id)? 
{ + if !services.rooms.metadata.exists(&body.room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } @@ -40,7 +42,8 @@ pub(crate) async fn create_join_event_template_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if services .globals @@ -73,7 +76,7 @@ pub(crate) async fn create_join_event_template_route( } } - let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; + let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; @@ -81,22 +84,24 @@ pub(crate) async fn create_join_event_template_route( .rooms .state_cache .is_left(&body.user_id, &body.room_id) - .unwrap_or(true)) - && user_can_perform_restricted_join(&services, &body.user_id, &body.room_id, &room_version_id)? + .await) + && user_can_perform_restricted_join(&services, &body.user_id, &body.room_id, &room_version_id).await? 
{ let auth_user = services .rooms .state_cache .room_members(&body.room_id) - .filter_map(Result::ok) - .filter(|user| user.server_name() == services.globals.server_name()) - .find(|user| { + .ready_filter(|user| user.server_name() == services.globals.server_name()) + .filter(|user| { services .rooms .state_accessor .user_can_invite(&body.room_id, user, &body.user_id, &state_lock) - .unwrap_or(false) - }); + }) + .boxed() + .next() + .await + .map(ToOwned::to_owned); if auth_user.is_some() { auth_user @@ -110,7 +115,7 @@ pub(crate) async fn create_join_event_template_route( None }; - let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; + let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; if !body.ver.contains(&room_version_id) { return Err(Error::BadRequest( ErrorKind::IncompatibleRoomVersion { @@ -132,19 +137,23 @@ pub(crate) async fn create_join_event_template_route( }) .expect("member event is valid value"); - let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - timestamp: None, - }, - &body.user_id, - &body.room_id, - &state_lock, - )?; + let (_pdu, mut pdu_json) = services + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + timestamp: None, + }, + &body.user_id, + &body.room_id, + &state_lock, + ) + .await?; drop(state_lock); @@ -161,7 +170,7 @@ pub(crate) async fn create_join_event_template_route( /// This doesn't check the current user's membership. This should be done /// externally, either by using the state cache or attempting to authorize the /// event. 
-pub(crate) fn user_can_perform_restricted_join( +pub(crate) async fn user_can_perform_restricted_join( services: &Services, user_id: &UserId, room_id: &RoomId, room_version_id: &RoomVersionId, ) -> Result { use RoomVersionId::*; @@ -169,18 +178,15 @@ pub(crate) fn user_can_perform_restricted_join( let join_rules_event = services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; - - let Some(join_rules_event_content) = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str::(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event in database: {e}"); - Error::bad_database("Invalid join rules event in database") - }) + .room_state_get(room_id, &StateEventType::RoomJoinRules, "") + .await; + + let Ok(Ok(join_rules_event_content)) = join_rules_event.as_ref().map(|join_rules_event| { + serde_json::from_str::(join_rules_event.content.get()).map_err(|e| { + warn!("Invalid join rules event in database: {e}"); + Error::bad_database("Invalid join rules event in database") }) - .transpose()? - else { + }) else { return Ok(false); }; @@ -201,13 +207,10 @@ pub(crate) fn user_can_perform_restricted_join( None } }) - .any(|m| { - services - .rooms - .state_cache - .is_joined(user_id, &m.room_id) - .unwrap_or(false) - }) { + .stream() + .any(|m| services.rooms.state_cache.is_joined(user_id, &m.room_id)) + .await + { Ok(true) } else { Err(Error::BadRequest( diff --git a/src/api/server/make_leave.rs b/src/api/server/make_leave.rs index 3eb0d77ab..41ea1c80d 100644 --- a/src/api/server/make_leave.rs +++ b/src/api/server/make_leave.rs @@ -18,7 +18,7 @@ use crate::{service::pdu::PduBuilder, Ruma}; pub(crate) async fn create_leave_event_template_route( State(services): State, body: Ruma, ) -> Result { - if !services.rooms.metadata.exists(&body.room_id)? 
{ + if !services.rooms.metadata.exists(&body.room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } @@ -34,9 +34,10 @@ pub(crate) async fn create_leave_event_template_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; - let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; + let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; let content = to_raw_value(&RoomMemberEventContent { avatar_url: None, @@ -50,19 +51,23 @@ pub(crate) async fn create_leave_event_template_route( }) .expect("member event is valid value"); - let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - timestamp: None, - }, - &body.user_id, - &body.room_id, - &state_lock, - )?; + let (_pdu, mut pdu_json) = services + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + timestamp: None, + }, + &body.user_id, + &body.room_id, + &state_lock, + ) + .await?; drop(state_lock); diff --git a/src/api/server/openid.rs b/src/api/server/openid.rs index 6a1b99b75..9b54807a6 100644 --- a/src/api/server/openid.rs +++ b/src/api/server/openid.rs @@ -10,6 +10,9 @@ pub(crate) async fn get_openid_userinfo_route( State(services): State, body: Ruma, ) -> Result { Ok(get_openid_userinfo::v1::Response::new( - services.users.find_from_openid_token(&body.access_token)?, + services + .users + .find_from_openid_token(&body.access_token) + .await?, )) } diff --git a/src/api/server/query.rs b/src/api/server/query.rs index c2b78bded..348b8c6e9 100644 --- 
a/src/api/server/query.rs +++ b/src/api/server/query.rs @@ -1,7 +1,8 @@ use std::collections::BTreeMap; use axum::extract::State; -use conduit::{Error, Result}; +use conduit::{err, Error, Result}; +use futures::StreamExt; use get_profile_information::v1::ProfileField; use rand::seq::SliceRandom; use ruma::{ @@ -23,15 +24,17 @@ pub(crate) async fn get_room_information_route( let room_id = services .rooms .alias - .resolve_local_alias(&body.room_alias)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Room alias not found."))?; + .resolve_local_alias(&body.room_alias) + .await + .map_err(|_| err!(Request(NotFound("Room alias not found."))))?; let mut servers: Vec = services .rooms .state_cache .room_servers(&room_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; servers.sort_unstable(); servers.dedup(); @@ -82,30 +85,31 @@ pub(crate) async fn get_profile_information_route( match &body.field { Some(ProfileField::DisplayName) => { - displayname = services.users.displayname(&body.user_id)?; + displayname = services.users.displayname(&body.user_id).await.ok(); }, Some(ProfileField::AvatarUrl) => { - avatar_url = services.users.avatar_url(&body.user_id)?; - blurhash = services.users.blurhash(&body.user_id)?; + avatar_url = services.users.avatar_url(&body.user_id).await.ok(); + blurhash = services.users.blurhash(&body.user_id).await.ok(); }, Some(custom_field) => { - if let Some(value) = services + if let Ok(value) = services .users - .profile_key(&body.user_id, custom_field.as_str())? 
+ .profile_key(&body.user_id, custom_field.as_str()) + .await { custom_profile_fields.insert(custom_field.to_string(), value); } }, None => { - displayname = services.users.displayname(&body.user_id)?; - avatar_url = services.users.avatar_url(&body.user_id)?; - blurhash = services.users.blurhash(&body.user_id)?; - tz = services.users.timezone(&body.user_id)?; + displayname = services.users.displayname(&body.user_id).await.ok(); + avatar_url = services.users.avatar_url(&body.user_id).await.ok(); + blurhash = services.users.blurhash(&body.user_id).await.ok(); + tz = services.users.timezone(&body.user_id).await.ok(); custom_profile_fields = services .users .all_profile_keys(&body.user_id) - .filter_map(Result::ok) - .collect(); + .collect() + .await; }, } diff --git a/src/api/server/send.rs b/src/api/server/send.rs index 15f82faa7..bb4249881 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -2,7 +2,8 @@ use std::{collections::BTreeMap, net::IpAddr, time::Instant}; use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{debug, debug_warn, err, trace, warn, Err}; +use conduit::{debug, debug_warn, err, result::LogErr, trace, utils::ReadyExt, warn, Err, Error, Result}; +use futures::StreamExt; use ruma::{ api::{ client::error::ErrorKind, @@ -23,10 +24,13 @@ use tokio::sync::RwLock; use crate::{ services::Services, utils::{self}, - Error, Result, Ruma, + Ruma, }; -type ResolvedMap = BTreeMap>; +const PDU_LIMIT: usize = 50; +const EDU_LIMIT: usize = 100; + +type ResolvedMap = BTreeMap>; /// # `PUT /_matrix/federation/v1/send/{txnId}` /// @@ -44,12 +48,16 @@ pub(crate) async fn send_transaction_message_route( ))); } - if body.pdus.len() > 50_usize { - return Err!(Request(Forbidden("Not allowed to send more than 50 PDUs in one transaction"))); + if body.pdus.len() > PDU_LIMIT { + return Err!(Request(Forbidden( + "Not allowed to send more than {PDU_LIMIT} PDUs in one transaction" + ))); } - if body.edus.len() > 100_usize { - return 
Err!(Request(Forbidden("Not allowed to send more than 100 EDUs in one transaction"))); + if body.edus.len() > EDU_LIMIT { + return Err!(Request(Forbidden( + "Not allowed to send more than {EDU_LIMIT} EDUs in one transaction" + ))); } let txn_start_time = Instant::now(); @@ -62,8 +70,8 @@ pub(crate) async fn send_transaction_message_route( "Starting txn", ); - let resolved_map = handle_pdus(&services, &client, &body, origin, &txn_start_time).await?; - handle_edus(&services, &client, &body, origin).await?; + let resolved_map = handle_pdus(&services, &client, &body, origin, &txn_start_time).await; + handle_edus(&services, &client, &body, origin).await; debug!( pdus = ?body.pdus.len(), @@ -85,10 +93,10 @@ pub(crate) async fn send_transaction_message_route( async fn handle_pdus( services: &Services, _client: &IpAddr, body: &Ruma, origin: &ServerName, txn_start_time: &Instant, -) -> Result { +) -> ResolvedMap { let mut parsed_pdus = Vec::with_capacity(body.pdus.len()); for pdu in &body.pdus { - parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu) { + parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu).await { Ok(t) => t, Err(e) => { debug_warn!("Could not parse PDU: {e}"); @@ -151,38 +159,34 @@ async fn handle_pdus( } } - Ok(resolved_map) + resolved_map } async fn handle_edus( services: &Services, client: &IpAddr, body: &Ruma, origin: &ServerName, -) -> Result<()> { +) { for edu in body .edus .iter() .filter_map(|edu| serde_json::from_str::(edu.json().get()).ok()) { match edu { - Edu::Presence(presence) => handle_edu_presence(services, client, origin, presence).await?, - Edu::Receipt(receipt) => handle_edu_receipt(services, client, origin, receipt).await?, - Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await?, - Edu::DeviceListUpdate(content) => handle_edu_device_list_update(services, client, origin, content).await?, - Edu::DirectToDevice(content) => handle_edu_direct_to_device(services, client, 
origin, content).await?, - Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(services, client, origin, content).await?, + Edu::Presence(presence) => handle_edu_presence(services, client, origin, presence).await, + Edu::Receipt(receipt) => handle_edu_receipt(services, client, origin, receipt).await, + Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await, + Edu::DeviceListUpdate(content) => handle_edu_device_list_update(services, client, origin, content).await, + Edu::DirectToDevice(content) => handle_edu_direct_to_device(services, client, origin, content).await, + Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(services, client, origin, content).await, Edu::_Custom(ref _custom) => { debug_warn!(?body.edus, "received custom/unknown EDU"); }, } } - - Ok(()) } -async fn handle_edu_presence( - services: &Services, _client: &IpAddr, origin: &ServerName, presence: PresenceContent, -) -> Result<()> { +async fn handle_edu_presence(services: &Services, _client: &IpAddr, origin: &ServerName, presence: PresenceContent) { if !services.globals.allow_incoming_presence() { - return Ok(()); + return; } for update in presence.push { @@ -194,23 +198,24 @@ async fn handle_edu_presence( continue; } - services.presence.set_presence( - &update.user_id, - &update.presence, - Some(update.currently_active), - Some(update.last_active_ago), - update.status_msg.clone(), - )?; + services + .presence + .set_presence( + &update.user_id, + &update.presence, + Some(update.currently_active), + Some(update.last_active_ago), + update.status_msg.clone(), + ) + .await + .log_err() + .ok(); } - - Ok(()) } -async fn handle_edu_receipt( - services: &Services, _client: &IpAddr, origin: &ServerName, receipt: ReceiptContent, -) -> Result<()> { +async fn handle_edu_receipt(services: &Services, _client: &IpAddr, origin: &ServerName, receipt: ReceiptContent) { if !services.globals.allow_incoming_read_receipts() { - return Ok(()); + return; } for (room_id, 
room_updates) in receipt.receipts { @@ -218,6 +223,7 @@ async fn handle_edu_receipt( .rooms .event_handler .acl_check(origin, &room_id) + .await .is_err() { debug_warn!( @@ -240,8 +246,8 @@ async fn handle_edu_receipt( .rooms .state_cache .room_members(&room_id) - .filter_map(Result::ok) - .any(|member| member.server_name() == user_id.server_name()) + .ready_any(|member| member.server_name() == user_id.server_name()) + .await { for event_id in &user_updates.event_ids { let user_receipts = BTreeMap::from([(user_id.clone(), user_updates.data.clone())]); @@ -255,7 +261,8 @@ async fn handle_edu_receipt( services .rooms .read_receipt - .readreceipt_update(&user_id, &room_id, &event)?; + .readreceipt_update(&user_id, &room_id, &event) + .await; } } else { debug_warn!( @@ -266,15 +273,11 @@ async fn handle_edu_receipt( } } } - - Ok(()) } -async fn handle_edu_typing( - services: &Services, _client: &IpAddr, origin: &ServerName, typing: TypingContent, -) -> Result<()> { +async fn handle_edu_typing(services: &Services, _client: &IpAddr, origin: &ServerName, typing: TypingContent) { if !services.globals.config.allow_incoming_typing { - return Ok(()); + return; } if typing.user_id.server_name() != origin { @@ -282,26 +285,28 @@ async fn handle_edu_typing( %typing.user_id, %origin, "received typing EDU for user not belonging to origin" ); - return Ok(()); + return; } if services .rooms .event_handler .acl_check(typing.user_id.server_name(), &typing.room_id) + .await .is_err() { debug_warn!( %typing.user_id, %typing.room_id, %origin, "received typing EDU for ACL'd user's server" ); - return Ok(()); + return; } if services .rooms .state_cache - .is_joined(&typing.user_id, &typing.room_id)? 
+ .is_joined(&typing.user_id, &typing.room_id) + .await { if typing.typing { let timeout = utils::millis_since_unix_epoch().saturating_add( @@ -315,28 +320,29 @@ async fn handle_edu_typing( .rooms .typing .typing_add(&typing.user_id, &typing.room_id, timeout) - .await?; + .await + .log_err() + .ok(); } else { services .rooms .typing .typing_remove(&typing.user_id, &typing.room_id) - .await?; + .await + .log_err() + .ok(); } } else { debug_warn!( %typing.user_id, %typing.room_id, %origin, "received typing EDU for user not in room" ); - return Ok(()); } - - Ok(()) } async fn handle_edu_device_list_update( services: &Services, _client: &IpAddr, origin: &ServerName, content: DeviceListUpdateContent, -) -> Result<()> { +) { let DeviceListUpdateContent { user_id, .. @@ -347,17 +353,15 @@ async fn handle_edu_device_list_update( %user_id, %origin, "received device list update EDU for user not belonging to origin" ); - return Ok(()); + return; } - services.users.mark_device_key_update(&user_id)?; - - Ok(()) + services.users.mark_device_key_update(&user_id).await; } async fn handle_edu_direct_to_device( services: &Services, _client: &IpAddr, origin: &ServerName, content: DirectDeviceContent, -) -> Result<()> { +) { let DirectDeviceContent { sender, ev_type, @@ -370,45 +374,52 @@ async fn handle_edu_direct_to_device( %sender, %origin, "received direct to device EDU for user not belonging to origin" ); - return Ok(()); + return; } // Check if this is a new transaction id if services .transaction_ids - .existing_txnid(&sender, None, &message_id)? 
- .is_some() + .existing_txnid(&sender, None, &message_id) + .await + .is_ok() { - return Ok(()); + return; } for (target_user_id, map) in &messages { for (target_device_id_maybe, event) in map { + let Ok(event) = event + .deserialize_as() + .map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}"))))) + else { + continue; + }; + + let ev_type = ev_type.to_string(); match target_device_id_maybe { DeviceIdOrAllDevices::DeviceId(target_device_id) => { - services.users.add_to_device_event( - &sender, - target_user_id, - target_device_id, - &ev_type.to_string(), - event - .deserialize_as() - .map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}")))))?, - )?; + services + .users + .add_to_device_event(&sender, target_user_id, target_device_id, &ev_type, event) + .await; }, DeviceIdOrAllDevices::AllDevices => { - for target_device_id in services.users.all_device_ids(target_user_id) { - services.users.add_to_device_event( - &sender, - target_user_id, - &target_device_id?, - &ev_type.to_string(), - event - .deserialize_as() - .map_err(|e| err!(Request(InvalidParam("Event is invalid: {e}"))))?, - )?; - } + let (sender, ev_type, event) = (&sender, &ev_type, &event); + services + .users + .all_device_ids(target_user_id) + .for_each(|target_device_id| { + services.users.add_to_device_event( + sender, + target_user_id, + target_device_id, + ev_type, + event.clone(), + ) + }) + .await; }, } } @@ -417,14 +428,12 @@ async fn handle_edu_direct_to_device( // Save transaction id with empty data services .transaction_ids - .add_txnid(&sender, None, &message_id, &[])?; - - Ok(()) + .add_txnid(&sender, None, &message_id, &[]); } async fn handle_edu_signing_key_update( services: &Services, _client: &IpAddr, origin: &ServerName, content: SigningKeyUpdateContent, -) -> Result<()> { +) { let SigningKeyUpdateContent { user_id, master_key, @@ -436,14 +445,15 @@ async fn handle_edu_signing_key_update( %user_id, %origin, "received signing key 
update EDU from server that does not belong to user's server" ); - return Ok(()); + return; } if let Some(master_key) = master_key { services .users - .add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)?; + .add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true) + .await + .log_err() + .ok(); } - - Ok(()) } diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index c4d016f61..639fcafd0 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -3,7 +3,8 @@ use std::collections::BTreeMap; use axum::extract::State; -use conduit::{pdu::gen_event_id_canonical_json, warn, Error, Result}; +use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result}; +use futures::{FutureExt, StreamExt, TryStreamExt}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_join_event}, events::{ @@ -22,27 +23,32 @@ use crate::Ruma; async fn create_join_event( services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, ) -> Result { - if !services.rooms.metadata.exists(room_id)? { + if !services.rooms.metadata.exists(room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } // ACL check origin server - services.rooms.event_handler.acl_check(origin, room_id)?; + services + .rooms + .event_handler + .acl_check(origin, room_id) + .await?; // We need to return the state prior to joining, let's keep a reference to that // here let shortstatehash = services .rooms .state - .get_room_shortstatehash(room_id)? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event state not found."))?; + .get_room_shortstatehash(room_id) + .await + .map_err(|_| err!(Request(NotFound("Event state not found."))))?; let pub_key_map = RwLock::new(BTreeMap::new()); // let mut auth_cache = EventMap::new(); // We do not add the event_id field to the pdu here because of signature and // hashes checks - let room_version_id = services.rooms.state.get_room_version(room_id)?; + let room_version_id = services.rooms.state.get_room_version(room_id).await?; let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { // Event could not be converted to canonical json @@ -97,7 +103,8 @@ async fn create_join_event( services .rooms .event_handler - .acl_check(sender.server_name(), room_id)?; + .acl_check(sender.server_name(), room_id) + .await?; // check if origin server is trying to send for another server if sender.server_name() != origin { @@ -126,7 +133,9 @@ async fn create_join_event( if content .join_authorized_via_users_server .is_some_and(|user| services.globals.user_is_local(&user)) - && super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id).unwrap_or_default() + && super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id) + .await + .unwrap_or_default() { ruma::signatures::hash_and_sign_event( services.globals.server_name().as_str(), @@ -158,12 +167,14 @@ async fn create_join_event( .mutex_federation .lock(room_id) .await; + let pdu_id: Vec = services .rooms .event_handler .handle_incoming_pdu(&origin, room_id, &event_id, value.clone(), true, &pub_key_map) .await? 
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?; + drop(mutex_lock); let state_ids = services @@ -171,29 +182,43 @@ async fn create_join_event( .state_accessor .state_full_ids(shortstatehash) .await?; - let auth_chain_ids = services + + let state = state_ids + .iter() + .try_stream() + .and_then(|(_, event_id)| services.rooms.timeline.get_pdu_json(event_id)) + .and_then(|pdu| { + services + .sending + .convert_to_outgoing_federation_event(pdu) + .map(Ok) + }) + .try_collect() + .await?; + + let auth_chain = services .rooms .auth_chain .event_ids_iter(room_id, state_ids.values().cloned().collect()) + .await? + .map(Ok) + .and_then(|event_id| async move { services.rooms.timeline.get_pdu_json(&event_id).await }) + .and_then(|pdu| { + services + .sending + .convert_to_outgoing_federation_event(pdu) + .map(Ok) + }) + .try_collect() .await?; - services.sending.send_pdu_room(room_id, &pdu_id)?; + services.sending.send_pdu_room(room_id, &pdu_id).await?; Ok(create_join_event::v1::RoomState { - auth_chain: auth_chain_ids - .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok().flatten()) - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect(), - state: state_ids - .iter() - .filter_map(|(_, id)| services.rooms.timeline.get_pdu_json(id).ok().flatten()) - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect(), + auth_chain, + state, // Event field is required if the room version supports restricted join rules. 
- event: Some( - to_raw_value(&CanonicalJsonValue::Object(value)) - .expect("To raw json should not fail since only change was adding signature"), - ), + event: to_raw_value(&CanonicalJsonValue::Object(value)).ok(), }) } diff --git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index e77c5d78a..81f41af07 100644 --- a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -3,7 +3,7 @@ use std::collections::BTreeMap; use axum::extract::State; -use conduit::{Error, Result}; +use conduit::{utils::ReadyExt, Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_leave_event}, events::{ @@ -49,18 +49,22 @@ pub(crate) async fn create_leave_event_v2_route( async fn create_leave_event( services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, ) -> Result<()> { - if !services.rooms.metadata.exists(room_id)? { + if !services.rooms.metadata.exists(room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } // ACL check origin - services.rooms.event_handler.acl_check(origin, room_id)?; + services + .rooms + .event_handler + .acl_check(origin, room_id) + .await?; let pub_key_map = RwLock::new(BTreeMap::new()); // We do not add the event_id field to the pdu here because of signature and // hashes checks - let room_version_id = services.rooms.state.get_room_version(room_id)?; + let room_version_id = services.rooms.state.get_room_version(room_id).await?; let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { // Event could not be converted to canonical json return Err(Error::BadRequest( @@ -114,7 +118,8 @@ async fn create_leave_event( services .rooms .event_handler - .acl_check(sender.server_name(), room_id)?; + .acl_check(sender.server_name(), room_id) + .await?; if sender.server_name() != origin { return Err(Error::BadRequest( @@ -173,10 +178,9 @@ async fn create_leave_event( .rooms .state_cache .room_servers(room_id) - 
.filter_map(Result::ok) - .filter(|server| !services.globals.server_is_ours(server)); + .ready_filter(|server| !services.globals.server_is_ours(server)); - services.sending.send_pdu_servers(servers, &pdu_id)?; + services.sending.send_pdu_servers(servers, &pdu_id).await?; Ok(()) } diff --git a/src/api/server/state.rs b/src/api/server/state.rs index d215236af..37a14a3f3 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -1,8 +1,9 @@ use std::sync::Arc; use axum::extract::State; -use conduit::{Error, Result}; -use ruma::api::{client::error::ErrorKind, federation::event::get_room_state}; +use conduit::{err, result::LogErr, utils::IterStream, Err, Result}; +use futures::{FutureExt, StreamExt, TryStreamExt}; +use ruma::api::federation::event::get_room_state; use crate::Ruma; @@ -17,56 +18,66 @@ pub(crate) async fn get_room_state_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .rooms .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? + .is_world_readable(&body.room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, &body.room_id) + .await { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + return Err!(Request(Forbidden("Server is not in room."))); } let shortstatehash = services .rooms .state_accessor - .pdu_shortstatehash(&body.event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; + .pdu_shortstatehash(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("PDU state not found."))))?; let pdus = services .rooms .state_accessor .state_full_ids(shortstatehash) - .await? - .into_values() - .map(|id| { + .await + .log_err() + .map_err(|_| err!(Request(NotFound("PDU state IDs not found."))))? 
+ .values() + .try_stream() + .and_then(|id| services.rooms.timeline.get_pdu_json(id)) + .and_then(|pdu| { services .sending - .convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap()) + .convert_to_outgoing_federation_event(pdu) + .map(Ok) }) - .collect(); + .try_collect() + .await?; - let auth_chain_ids = services + let auth_chain = services .rooms .auth_chain .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) + .await? + .map(Ok) + .and_then(|id| async move { services.rooms.timeline.get_pdu_json(&id).await }) + .and_then(|pdu| { + services + .sending + .convert_to_outgoing_federation_event(pdu) + .map(Ok) + }) + .try_collect() .await?; Ok(get_room_state::v1::Response { - auth_chain: auth_chain_ids - .filter_map(|id| { - services - .rooms - .timeline - .get_pdu_json(&id) - .ok()? - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - }) - .collect(), + auth_chain, pdus, }) } diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index d22f2df4a..95ca65aa7 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -1,9 +1,11 @@ use std::sync::Arc; use axum::extract::State; -use ruma::api::{client::error::ErrorKind, federation::event::get_room_state_ids}; +use conduit::{err, Err}; +use futures::StreamExt; +use ruma::api::federation::event::get_room_state_ids; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET /_matrix/federation/v1/state_ids/{roomId}` /// @@ -17,31 +19,35 @@ pub(crate) async fn get_room_state_ids_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .rooms .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? 
+ .is_world_readable(&body.room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, &body.room_id) + .await { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + return Err!(Request(Forbidden("Server is not in room."))); } let shortstatehash = services .rooms .state_accessor - .pdu_shortstatehash(&body.event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; + .pdu_shortstatehash(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Pdu state not found."))))?; let pdu_ids = services .rooms .state_accessor .state_full_ids(shortstatehash) - .await? + .await + .map_err(|_| err!(Request(NotFound("State ids not found"))))? .into_values() .map(|id| (*id).to_owned()) .collect(); @@ -50,10 +56,13 @@ pub(crate) async fn get_room_state_ids_route( .rooms .auth_chain .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) - .await?; + .await? + .map(|id| (*id).to_owned()) + .collect() + .await; Ok(get_room_state_ids::v1::Response { - auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), + auth_chain_ids, pdu_ids, }) } diff --git a/src/api/server/user.rs b/src/api/server/user.rs index e9a400a79..0718da580 100644 --- a/src/api/server/user.rs +++ b/src/api/server/user.rs @@ -1,5 +1,6 @@ use axum::extract::State; use conduit::{Error, Result}; +use futures::{FutureExt, StreamExt, TryFutureExt}; use ruma::api::{ client::error::ErrorKind, federation::{ @@ -28,41 +29,51 @@ pub(crate) async fn get_devices_route( let origin = body.origin.as_ref().expect("server is authenticated"); + let user_id = &body.user_id; Ok(get_devices::v1::Response { - user_id: body.user_id.clone(), + user_id: user_id.clone(), stream_id: services .users - .get_devicelist_version(&body.user_id)? 
+ .get_devicelist_version(user_id) + .await .unwrap_or(0) - .try_into() - .expect("version will not grow that large"), + .try_into()?, devices: services .users - .all_devices_metadata(&body.user_id) - .filter_map(Result::ok) - .filter_map(|metadata| { - let device_id_string = metadata.device_id.as_str().to_owned(); + .all_devices_metadata(user_id) + .filter_map(|metadata| async move { + let device_id = metadata.device_id.clone(); + let device_id_clone = device_id.clone(); + let device_id_string = device_id.as_str().to_owned(); let device_display_name = if services.globals.allow_device_name_federation() { - metadata.display_name + metadata.display_name.clone() } else { Some(device_id_string) }; - Some(UserDevice { - keys: services - .users - .get_device_keys(&body.user_id, &metadata.device_id) - .ok()??, - device_id: metadata.device_id, - device_display_name, - }) + + services + .users + .get_device_keys(user_id, &device_id_clone) + .map_ok(|keys| UserDevice { + device_id, + keys, + device_display_name, + }) + .map(Result::ok) + .await }) - .collect(), + .collect() + .await, master_key: services .users - .get_master_key(None, &body.user_id, &|u| u.server_name() == origin)?, + .get_master_key(None, &body.user_id, &|u| u.server_name() == origin) + .await + .ok(), self_signing_key: services .users - .get_self_signing_key(None, &body.user_id, &|u| u.server_name() == origin)?, + .get_self_signing_key(None, &body.user_id, &|u| u.server_name() == origin) + .await + .ok(), }) } diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 713647342..cb957bc90 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -67,6 +67,7 @@ ctor.workspace = true cyborgtime.workspace = true either.workspace = true figment.workspace = true +futures.workspace = true http-body-util.workspace = true http.workspace = true image.workspace = true diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs index 48b9b58ff..79e3d5b40 100644 --- a/src/core/error/mod.rs +++ 
b/src/core/error/mod.rs @@ -86,7 +86,7 @@ pub enum Error { #[error("There was a problem with the '{0}' directive in your configuration: {1}")] Config(&'static str, Cow<'static, str>), #[error("{0}")] - Conflict(&'static str), // This is only needed for when a room alias already exists + Conflict(Cow<'static, str>), // This is only needed for when a room alias already exists #[error(transparent)] ContentDisposition(#[from] ruma::http_headers::ContentDispositionParseError), #[error("{0}")] @@ -107,6 +107,8 @@ pub enum Error { Request(ruma::api::client::error::ErrorKind, Cow<'static, str>, http::StatusCode), #[error(transparent)] Ruma(#[from] ruma::api::client::error::Error), + #[error(transparent)] + StateRes(#[from] ruma::state_res::Error), #[error("uiaa")] Uiaa(ruma::api::client::uiaa::UiaaInfo), diff --git a/src/core/pdu/mod.rs b/src/core/pdu/mod.rs index 439c831a5..cf9ffe645 100644 --- a/src/core/pdu/mod.rs +++ b/src/core/pdu/mod.rs @@ -3,8 +3,6 @@ mod count; use std::{cmp::Ordering, collections::BTreeMap, sync::Arc}; -pub use builder::PduBuilder; -pub use count::PduCount; use ruma::{ canonical_json::redact_content_in_place, events::{ @@ -23,7 +21,8 @@ use serde_json::{ value::{to_raw_value, RawValue as RawJsonValue}, }; -use crate::{err, warn, Error}; +pub use self::{builder::PduBuilder, count::PduCount}; +use crate::{err, warn, Error, Result}; #[derive(Deserialize)] struct ExtractRedactedBecause { @@ -65,11 +64,12 @@ pub struct PduEvent { impl PduEvent { #[tracing::instrument(skip(self), level = "debug")] - pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> crate::Result<()> { + pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> Result<()> { self.unsigned = None; let mut content = serde_json::from_str(self.content.get()) .map_err(|_| Error::bad_database("PDU in db has invalid content."))?; + redact_content_in_place(&mut content, &room_version_id, self.kind.to_string()) .map_err(|e| 
Error::Redaction(self.sender.server_name().to_owned(), e))?; @@ -98,31 +98,38 @@ impl PduEvent { unsigned.redacted_because.is_some() } - pub fn remove_transaction_id(&mut self) -> crate::Result<()> { - if let Some(unsigned) = &self.unsigned { - let mut unsigned: BTreeMap> = serde_json::from_str(unsigned.get()) - .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; - unsigned.remove("transaction_id"); - self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); - } + pub fn remove_transaction_id(&mut self) -> Result<()> { + let Some(unsigned) = &self.unsigned else { + return Ok(()); + }; + + let mut unsigned: BTreeMap> = + serde_json::from_str(unsigned.get()).map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; + + unsigned.remove("transaction_id"); + self.unsigned = to_raw_value(&unsigned) + .map(Some) + .expect("unsigned is valid"); Ok(()) } - pub fn add_age(&mut self) -> crate::Result<()> { + pub fn add_age(&mut self) -> Result<()> { let mut unsigned: BTreeMap> = self .unsigned .as_ref() .map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get())) - .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; + .map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; // deliberately allowing for the possibility of negative age let now: i128 = MilliSecondsSinceUnixEpoch::now().get().into(); let then: i128 = self.origin_server_ts.into(); let this_age = now.saturating_sub(then); - unsigned.insert("age".to_owned(), to_raw_value(&this_age).unwrap()); - self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); + unsigned.insert("age".to_owned(), to_raw_value(&this_age).expect("age is valid")); + self.unsigned = to_raw_value(&unsigned) + .map(Some) + .expect("unsigned is valid"); Ok(()) } @@ -369,9 +376,9 @@ impl state_res::Event for PduEvent { fn state_key(&self) -> Option<&str> { self.state_key.as_deref() } - fn prev_events(&self) -> Box + '_> { 
Box::new(self.prev_events.iter()) } + fn prev_events(&self) -> impl DoubleEndedIterator + Send + '_ { self.prev_events.iter() } - fn auth_events(&self) -> Box + '_> { Box::new(self.auth_events.iter()) } + fn auth_events(&self) -> impl DoubleEndedIterator + Send + '_ { self.auth_events.iter() } fn redacts(&self) -> Option<&Self::Id> { self.redacts.as_ref() } } @@ -395,7 +402,7 @@ impl Ord for PduEvent { /// CanonicalJsonValue>`. pub fn gen_event_id_canonical_json( pdu: &RawJsonValue, room_version_id: &RoomVersionId, -) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> { +) -> Result<(OwnedEventId, CanonicalJsonObject)> { let value: CanonicalJsonObject = serde_json::from_str(pdu.get()) .map_err(|e| err!(BadServerResponse(warn!("Error parsing incoming event: {e:?}"))))?; diff --git a/src/core/result/log_debug_err.rs b/src/core/result/log_debug_err.rs index be2000aed..8835afd19 100644 --- a/src/core/result/log_debug_err.rs +++ b/src/core/result/log_debug_err.rs @@ -1,18 +1,14 @@ -use std::fmt; +use std::fmt::Debug; use tracing::Level; use super::{DebugInspect, Result}; use crate::error; -pub trait LogDebugErr -where - E: fmt::Debug, -{ +pub trait LogDebugErr { #[must_use] fn err_debug_log(self, level: Level) -> Self; - #[inline] #[must_use] fn log_debug_err(self) -> Self where @@ -22,15 +18,9 @@ where } } -impl LogDebugErr for Result -where - E: fmt::Debug, -{ +impl LogDebugErr for Result { #[inline] - fn err_debug_log(self, level: Level) -> Self - where - Self: Sized, - { + fn err_debug_log(self, level: Level) -> Self { self.debug_inspect_err(|error| error::inspect_debug_log_level(&error, level)) } } diff --git a/src/core/result/log_err.rs b/src/core/result/log_err.rs index 079571f56..374a5e596 100644 --- a/src/core/result/log_err.rs +++ b/src/core/result/log_err.rs @@ -1,18 +1,14 @@ -use std::fmt; +use std::fmt::Display; use tracing::Level; use super::Result; use crate::error; -pub trait LogErr -where - E: fmt::Display, -{ +pub trait LogErr { #[must_use] fn 
err_log(self, level: Level) -> Self; - #[inline] #[must_use] fn log_err(self) -> Self where @@ -22,15 +18,7 @@ where } } -impl LogErr for Result -where - E: fmt::Display, -{ +impl LogErr for Result { #[inline] - fn err_log(self, level: Level) -> Self - where - Self: Sized, - { - self.inspect_err(|error| error::inspect_log_level(&error, level)) - } + fn err_log(self, level: Level) -> Self { self.inspect_err(|error| error::inspect_log_level(&error, level)) } } diff --git a/src/core/utils/algorithm.rs b/src/core/utils/algorithm.rs deleted file mode 100644 index 9bc1bc8a7..000000000 --- a/src/core/utils/algorithm.rs +++ /dev/null @@ -1,25 +0,0 @@ -use std::cmp::Ordering; - -#[allow(clippy::impl_trait_in_params)] -pub fn common_elements( - mut iterators: impl Iterator>>, check_order: impl Fn(&[u8], &[u8]) -> Ordering, -) -> Option>> { - let first_iterator = iterators.next()?; - let mut other_iterators = iterators.map(Iterator::peekable).collect::>(); - - Some(first_iterator.filter(move |target| { - other_iterators.iter_mut().all(|it| { - while let Some(element) = it.peek() { - match check_order(element, target) { - Ordering::Greater => return false, // We went too far - Ordering::Equal => return true, // Element is in both iters - Ordering::Less => { - // Keep searching - it.next(); - }, - } - } - false - }) - })) -} diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index 03b755e9e..b1ea3709d 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -1,4 +1,3 @@ -pub mod algorithm; pub mod bytes; pub mod content_disposition; pub mod debug; @@ -9,25 +8,30 @@ pub mod json; pub mod math; pub mod mutex_map; pub mod rand; +pub mod set; +pub mod stream; pub mod string; pub mod sys; mod tests; pub mod time; +pub use ::conduit_macros::implement; pub use ::ctor::{ctor, dtor}; -pub use algorithm::common_elements; -pub use bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8}; -pub use conduit_macros::implement; -pub use debug::slice_truncated as 
debug_slice_truncated; -pub use hash::calculate_hash; -pub use html::Escape as HtmlEscape; -pub use json::{deserialize_from_str, to_canonical_object}; -pub use math::clamp; -pub use mutex_map::{Guard as MutexMapGuard, MutexMap}; -pub use rand::string as random_string; -pub use string::{str_from_bytes, string_from_bytes}; -pub use sys::available_parallelism; -pub use time::now_millis as millis_since_unix_epoch; + +pub use self::{ + bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8}, + debug::slice_truncated as debug_slice_truncated, + hash::calculate_hash, + html::Escape as HtmlEscape, + json::{deserialize_from_str, to_canonical_object}, + math::clamp, + mutex_map::{Guard as MutexMapGuard, MutexMap}, + rand::string as random_string, + stream::{IterStream, ReadyExt, TryReadyExt}, + string::{str_from_bytes, string_from_bytes}, + sys::available_parallelism, + time::now_millis as millis_since_unix_epoch, +}; #[inline] pub fn exchange(state: &mut T, source: T) -> T { std::mem::replace(state, source) } diff --git a/src/core/utils/set.rs b/src/core/utils/set.rs new file mode 100644 index 000000000..563f9df5c --- /dev/null +++ b/src/core/utils/set.rs @@ -0,0 +1,47 @@ +use std::cmp::{Eq, Ord}; + +use crate::{is_equal_to, is_less_than}; + +/// Intersection of sets +/// +/// Outputs the set of elements common to all input sets. Inputs do not have to +/// be sorted. If inputs are sorted a more optimized function is available in +/// this suite and should be used. +pub fn intersection(mut input: Iters) -> impl Iterator + Send +where + Iters: Iterator + Clone + Send, + Iter: Iterator + Send, + Item: Eq + Send, +{ + input.next().into_iter().flat_map(move |first| { + let input = input.clone(); + first.filter(move |targ| { + input + .clone() + .all(|mut other| other.any(is_equal_to!(*targ))) + }) + }) +} + +/// Intersection of sets +/// +/// Outputs the set of elements common to all input sets. Inputs must be sorted. 
+pub fn intersection_sorted(mut input: Iters) -> impl Iterator + Send +where + Iters: Iterator + Clone + Send, + Iter: Iterator + Send, + Item: Eq + Ord + Send, +{ + input.next().into_iter().flat_map(move |first| { + let mut input = input.clone().collect::>(); + first.filter(move |targ| { + input.iter_mut().all(|it| { + it.by_ref() + .skip_while(is_less_than!(targ)) + .peekable() + .peek() + .is_some_and(is_equal_to!(targ)) + }) + }) + }) +} diff --git a/src/core/utils/stream/cloned.rs b/src/core/utils/stream/cloned.rs new file mode 100644 index 000000000..d6a0e6470 --- /dev/null +++ b/src/core/utils/stream/cloned.rs @@ -0,0 +1,20 @@ +use std::clone::Clone; + +use futures::{stream::Map, Stream, StreamExt}; + +pub trait Cloned<'a, T, S> +where + S: Stream, + T: Clone + 'a, +{ + fn cloned(self) -> Map T>; +} + +impl<'a, T, S> Cloned<'a, T, S> for S +where + S: Stream, + T: Clone + 'a, +{ + #[inline] + fn cloned(self) -> Map T> { self.map(Clone::clone) } +} diff --git a/src/core/utils/stream/expect.rs b/src/core/utils/stream/expect.rs new file mode 100644 index 000000000..3ab7181a8 --- /dev/null +++ b/src/core/utils/stream/expect.rs @@ -0,0 +1,17 @@ +use futures::{Stream, StreamExt, TryStream}; + +use crate::Result; + +pub trait TryExpect<'a, Item> { + fn expect_ok(self) -> impl Stream + Send + 'a; +} + +impl<'a, T, Item> TryExpect<'a, Item> for T +where + T: Stream> + TryStream + Send + 'a, +{ + #[inline] + fn expect_ok(self: T) -> impl Stream + Send + 'a { + self.map(|res| res.expect("stream expectation failure")) + } +} diff --git a/src/core/utils/stream/ignore.rs b/src/core/utils/stream/ignore.rs new file mode 100644 index 000000000..997aa4ba4 --- /dev/null +++ b/src/core/utils/stream/ignore.rs @@ -0,0 +1,21 @@ +use futures::{future::ready, Stream, StreamExt, TryStream}; + +use crate::{Error, Result}; + +pub trait TryIgnore<'a, Item> { + fn ignore_err(self) -> impl Stream + Send + 'a; + + fn ignore_ok(self) -> impl Stream + Send + 'a; +} + +impl<'a, T, Item> 
TryIgnore<'a, Item> for T +where + T: Stream> + TryStream + Send + 'a, + Item: Send + 'a, +{ + #[inline] + fn ignore_err(self: T) -> impl Stream + Send + 'a { self.filter_map(|res| ready(res.ok())) } + + #[inline] + fn ignore_ok(self: T) -> impl Stream + Send + 'a { self.filter_map(|res| ready(res.err())) } +} diff --git a/src/core/utils/stream/iter_stream.rs b/src/core/utils/stream/iter_stream.rs new file mode 100644 index 000000000..69edf64f5 --- /dev/null +++ b/src/core/utils/stream/iter_stream.rs @@ -0,0 +1,27 @@ +use futures::{ + stream, + stream::{Stream, TryStream}, + StreamExt, +}; + +pub trait IterStream { + /// Convert an Iterator into a Stream + fn stream(self) -> impl Stream::Item> + Send; + + /// Convert an Iterator into a TryStream + fn try_stream(self) -> impl TryStream::Item, Error = crate::Error> + Send; +} + +impl IterStream for I +where + I: IntoIterator + Send, + ::IntoIter: Send, +{ + #[inline] + fn stream(self) -> impl Stream::Item> + Send { stream::iter(self) } + + #[inline] + fn try_stream(self) -> impl TryStream::Item, Error = crate::Error> + Send { + self.stream().map(Ok) + } +} diff --git a/src/core/utils/stream/mod.rs b/src/core/utils/stream/mod.rs new file mode 100644 index 000000000..781bd5223 --- /dev/null +++ b/src/core/utils/stream/mod.rs @@ -0,0 +1,13 @@ +mod cloned; +mod expect; +mod ignore; +mod iter_stream; +mod ready; +mod try_ready; + +pub use cloned::Cloned; +pub use expect::TryExpect; +pub use ignore::TryIgnore; +pub use iter_stream::IterStream; +pub use ready::ReadyExt; +pub use try_ready::TryReadyExt; diff --git a/src/core/utils/stream/ready.rs b/src/core/utils/stream/ready.rs new file mode 100644 index 000000000..13f730a7d --- /dev/null +++ b/src/core/utils/stream/ready.rs @@ -0,0 +1,109 @@ +//! 
Synchronous combinator extensions to futures::Stream + +use futures::{ + future::{ready, Ready}, + stream::{Any, Filter, FilterMap, Fold, ForEach, SkipWhile, Stream, StreamExt, TakeWhile}, +}; + +/// Synchronous combinators to augment futures::StreamExt. Most Stream +/// combinators take asynchronous arguments, but often only simple predicates +/// are required to steer a Stream like an Iterator. This suite provides a +/// convenience to reduce boilerplate by de-cluttering non-async predicates. +/// +/// This interface is not necessarily complete; feel free to add as-needed. +pub trait ReadyExt +where + S: Stream + Send + ?Sized, + Self: Stream + Send + Sized, +{ + fn ready_any(self, f: F) -> Any, impl FnMut(S::Item) -> Ready> + where + F: Fn(S::Item) -> bool; + + fn ready_filter<'a, F>(self, f: F) -> Filter, impl FnMut(&S::Item) -> Ready + 'a> + where + F: Fn(&S::Item) -> bool + 'a; + + fn ready_filter_map(self, f: F) -> FilterMap>, impl FnMut(S::Item) -> Ready>> + where + F: Fn(S::Item) -> Option; + + fn ready_fold(self, init: T, f: F) -> Fold, T, impl FnMut(T, S::Item) -> Ready> + where + F: Fn(T, S::Item) -> T; + + fn ready_for_each(self, f: F) -> ForEach, impl FnMut(S::Item) -> Ready<()>> + where + F: FnMut(S::Item); + + fn ready_take_while<'a, F>(self, f: F) -> TakeWhile, impl FnMut(&S::Item) -> Ready + 'a> + where + F: Fn(&S::Item) -> bool + 'a; + + fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile, impl FnMut(&S::Item) -> Ready + 'a> + where + F: Fn(&S::Item) -> bool + 'a; +} + +impl ReadyExt for S +where + S: Stream + Send + ?Sized, + Self: Stream + Send + Sized, +{ + #[inline] + fn ready_any(self, f: F) -> Any, impl FnMut(S::Item) -> Ready> + where + F: Fn(S::Item) -> bool, + { + self.any(move |t| ready(f(t))) + } + + #[inline] + fn ready_filter<'a, F>(self, f: F) -> Filter, impl FnMut(&S::Item) -> Ready + 'a> + where + F: Fn(&S::Item) -> bool + 'a, + { + self.filter(move |t| ready(f(t))) + } + + #[inline] + fn ready_filter_map(self, f: F) -> 
FilterMap>, impl FnMut(S::Item) -> Ready>> + where + F: Fn(S::Item) -> Option, + { + self.filter_map(move |t| ready(f(t))) + } + + #[inline] + fn ready_fold(self, init: T, f: F) -> Fold, T, impl FnMut(T, S::Item) -> Ready> + where + F: Fn(T, S::Item) -> T, + { + self.fold(init, move |a, t| ready(f(a, t))) + } + + #[inline] + #[allow(clippy::unit_arg)] + fn ready_for_each(self, mut f: F) -> ForEach, impl FnMut(S::Item) -> Ready<()>> + where + F: FnMut(S::Item), + { + self.for_each(move |t| ready(f(t))) + } + + #[inline] + fn ready_take_while<'a, F>(self, f: F) -> TakeWhile, impl FnMut(&S::Item) -> Ready + 'a> + where + F: Fn(&S::Item) -> bool + 'a, + { + self.take_while(move |t| ready(f(t))) + } + + #[inline] + fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile, impl FnMut(&S::Item) -> Ready + 'a> + where + F: Fn(&S::Item) -> bool + 'a, + { + self.skip_while(move |t| ready(f(t))) + } +} diff --git a/src/core/utils/stream/try_ready.rs b/src/core/utils/stream/try_ready.rs new file mode 100644 index 000000000..ab37d9b30 --- /dev/null +++ b/src/core/utils/stream/try_ready.rs @@ -0,0 +1,35 @@ +//! Synchronous combinator extensions to futures::TryStream + +use futures::{ + future::{ready, Ready}, + stream::{AndThen, TryStream, TryStreamExt}, +}; + +use crate::Result; + +/// Synchronous combinators to augment futures::TryStreamExt. +/// +/// This interface is not necessarily complete; feel free to add as-needed. 
+pub trait TryReadyExt +where + S: TryStream> + Send + ?Sized, + Self: TryStream + Send + Sized, +{ + fn ready_and_then(self, f: F) -> AndThen>, impl FnMut(S::Ok) -> Ready>> + where + F: Fn(S::Ok) -> Result; +} + +impl TryReadyExt for S +where + S: TryStream> + Send + ?Sized, + Self: TryStream + Send + Sized, +{ + #[inline] + fn ready_and_then(self, f: F) -> AndThen>, impl FnMut(S::Ok) -> Ready>> + where + F: Fn(S::Ok) -> Result, + { + self.and_then(move |t| ready(f(t))) + } +} diff --git a/src/core/utils/tests.rs b/src/core/utils/tests.rs index 5880470a3..84d35936e 100644 --- a/src/core/utils/tests.rs +++ b/src/core/utils/tests.rs @@ -107,3 +107,133 @@ async fn mutex_map_contend() { tokio::try_join!(join_b, join_a).expect("joined"); assert!(map.is_empty(), "Must be empty"); } + +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_none() { + use utils::set::intersection; + + let a: [&str; 0] = []; + let b: [&str; 0] = []; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + + let a: [&str; 0] = []; + let b = ["abc", "def"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + let i = [b.iter(), a.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + let i = [a.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + + let a = ["foo", "bar", "baz"]; + let b = ["def", "hij", "klm", "nop"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); +} + +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_all() { + use utils::set::intersection; + + let a = ["foo"]; + let b = ["foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo"].iter())); + + let a = ["foo", "bar"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = 
intersection(i.into_iter()); + assert!(r.eq(["foo", "bar"].iter())); + let i = [b.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["bar", "foo"].iter())); + + let a = ["foo", "bar", "baz"]; + let b = ["baz", "foo", "bar"]; + let c = ["bar", "baz", "foo"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo", "bar", "baz"].iter())); +} + +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_some() { + use utils::set::intersection; + + let a = ["foo"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo"].iter())); + let i = [b.iter(), a.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo"].iter())); + + let a = ["abcdef", "foo", "hijkl", "abc"]; + let b = ["hij", "bar", "baz", "abc", "foo"]; + let c = ["abc", "xyz", "foo", "ghi"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo", "abc"].iter())); +} + +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_sorted_some() { + use utils::set::intersection_sorted; + + let a = ["bar"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar"].iter())); + let i = [b.iter(), a.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar"].iter())); + + let a = ["aaa", "ccc", "eee", "ggg"]; + let b = ["aaa", "bbb", "ccc", "ddd", "eee"]; + let c = ["bbb", "ccc", "eee", "fff"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["ccc", "eee"].iter())); +} + +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_sorted_all() { + use utils::set::intersection_sorted; + + let a = ["foo"]; + let b = ["foo"]; + let i = [a.iter(), b.iter()]; + let r = 
intersection_sorted(i.into_iter()); + assert!(r.eq(["foo"].iter())); + + let a = ["bar", "foo"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar", "foo"].iter())); + let i = [b.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar", "foo"].iter())); + + let a = ["bar", "baz", "foo"]; + let b = ["bar", "baz", "foo"]; + let c = ["bar", "baz", "foo"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar", "baz", "foo"].iter())); +} diff --git a/src/database/Cargo.toml b/src/database/Cargo.toml index 34d98416d..b5eb76126 100644 --- a/src/database/Cargo.toml +++ b/src/database/Cargo.toml @@ -37,8 +37,11 @@ zstd_compression = [ [dependencies] conduit-core.workspace = true const-str.workspace = true +futures.workspace = true log.workspace = true rust-rocksdb.workspace = true +serde.workspace = true +serde_json.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/src/database/database.rs b/src/database/database.rs index c357d50f2..ac6f62e90 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -37,7 +37,7 @@ impl Database { pub fn cork_and_sync(&self) -> Cork { Cork::new(&self.db, true, true) } #[inline] - pub fn iter_maps(&self) -> impl Iterator + '_ { self.map.iter() } + pub fn iter_maps(&self) -> impl Iterator + Send + '_ { self.map.iter() } } impl Index<&str> for Database { diff --git a/src/database/de.rs b/src/database/de.rs new file mode 100644 index 000000000..8ce25aa31 --- /dev/null +++ b/src/database/de.rs @@ -0,0 +1,261 @@ +use conduit::{checked, debug::DebugInspect, err, utils::string, Error, Result}; +use serde::{ + de, + de::{DeserializeSeed, Visitor}, + Deserialize, +}; + +pub(crate) fn from_slice<'a, T>(buf: &'a [u8]) -> Result +where + T: Deserialize<'a>, +{ + let mut deserializer = Deserializer { + buf, + pos: 0, + }; + + T::deserialize(&mut 
deserializer).debug_inspect(|_| { + deserializer + .finished() + .expect("deserialization failed to consume trailing bytes"); + }) +} + +pub(crate) struct Deserializer<'de> { + buf: &'de [u8], + pos: usize, +} + +/// Directive to ignore a record. This type can be used to skip deserialization +/// until the next separator is found. +#[derive(Debug, Deserialize)] +pub struct Ignore; + +impl<'de> Deserializer<'de> { + const SEP: u8 = b'\xFF'; + + fn finished(&self) -> Result<()> { + let pos = self.pos; + let len = self.buf.len(); + let parsed = &self.buf[0..pos]; + let unparsed = &self.buf[pos..]; + let remain = checked!(len - pos)?; + let trailing_sep = remain == 1 && unparsed[0] == Self::SEP; + (remain == 0 || trailing_sep) + .then_some(()) + .ok_or(err!(SerdeDe( + "{remain} trailing of {len} bytes not deserialized.\n{parsed:?}\n{unparsed:?}", + ))) + } + + #[inline] + fn record_next(&mut self) -> &'de [u8] { + self.buf[self.pos..] + .split(|b| *b == Deserializer::SEP) + .inspect(|record| self.inc_pos(record.len())) + .next() + .expect("remainder of buf even if SEP was not found") + } + + #[inline] + fn record_trail(&mut self) -> &'de [u8] { + let record = &self.buf[self.pos..]; + self.inc_pos(record.len()); + record + } + + #[inline] + fn record_start(&mut self) { + let started = self.pos != 0; + debug_assert!( + !started || self.buf[self.pos] == Self::SEP, + "Missing expected record separator at current position" + ); + + self.inc_pos(started.into()); + } + + #[inline] + fn inc_pos(&mut self, n: usize) { + self.pos = self.pos.saturating_add(n); + debug_assert!(self.pos <= self.buf.len(), "pos out of range"); + } +} + +impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { + type Error = Error; + + fn deserialize_map(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + unimplemented!("deserialize Map not implemented") + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(self) + } + + fn 
deserialize_tuple(self, _len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(self) + } + + fn deserialize_tuple_struct(self, _name: &'static str, _len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(self) + } + + fn deserialize_struct( + self, _name: &'static str, _fields: &'static [&'static str], _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + unimplemented!("deserialize Struct not implemented") + } + + fn deserialize_unit_struct(self, name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + match name { + "Ignore" => self.record_next(), + _ => unimplemented!("Unrecognized deserialization Directive {name:?}"), + }; + + visitor.visit_unit() + } + + fn deserialize_newtype_struct(self, _name: &'static str, _visitor: V) -> Result + where + V: Visitor<'de>, + { + unimplemented!("deserialize Newtype Struct not implemented") + } + + fn deserialize_enum( + self, _name: &'static str, _variants: &'static [&'static str], _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + unimplemented!("deserialize Enum not implemented") + } + + fn deserialize_option>(self, _visitor: V) -> Result { + unimplemented!("deserialize Option not implemented") + } + + fn deserialize_bool>(self, _visitor: V) -> Result { + unimplemented!("deserialize bool not implemented") + } + + fn deserialize_i8>(self, _visitor: V) -> Result { + unimplemented!("deserialize i8 not implemented") + } + + fn deserialize_i16>(self, _visitor: V) -> Result { + unimplemented!("deserialize i16 not implemented") + } + + fn deserialize_i32>(self, _visitor: V) -> Result { + unimplemented!("deserialize i32 not implemented") + } + + fn deserialize_i64>(self, visitor: V) -> Result { + let bytes: [u8; size_of::()] = self.buf[self.pos..].try_into()?; + self.pos = self.pos.saturating_add(size_of::()); + visitor.visit_i64(i64::from_be_bytes(bytes)) + } + + fn deserialize_u8>(self, _visitor: V) -> Result { + 
unimplemented!("deserialize u8 not implemented") + } + + fn deserialize_u16>(self, _visitor: V) -> Result { + unimplemented!("deserialize u16 not implemented") + } + + fn deserialize_u32>(self, _visitor: V) -> Result { + unimplemented!("deserialize u32 not implemented") + } + + fn deserialize_u64>(self, visitor: V) -> Result { + let bytes: [u8; size_of::()] = self.buf[self.pos..].try_into()?; + self.pos = self.pos.saturating_add(size_of::()); + visitor.visit_u64(u64::from_be_bytes(bytes)) + } + + fn deserialize_f32>(self, _visitor: V) -> Result { + unimplemented!("deserialize f32 not implemented") + } + + fn deserialize_f64>(self, _visitor: V) -> Result { + unimplemented!("deserialize f64 not implemented") + } + + fn deserialize_char>(self, _visitor: V) -> Result { + unimplemented!("deserialize char not implemented") + } + + fn deserialize_str>(self, visitor: V) -> Result { + let input = self.record_next(); + let out = string::str_from_bytes(input)?; + visitor.visit_borrowed_str(out) + } + + fn deserialize_string>(self, visitor: V) -> Result { + let input = self.record_next(); + let out = string::string_from_bytes(input)?; + visitor.visit_string(out) + } + + fn deserialize_bytes>(self, visitor: V) -> Result { + let input = self.record_trail(); + visitor.visit_borrowed_bytes(input) + } + + fn deserialize_byte_buf>(self, _visitor: V) -> Result { + unimplemented!("deserialize Byte Buf not implemented") + } + + fn deserialize_unit>(self, _visitor: V) -> Result { + unimplemented!("deserialize Unit Struct not implemented") + } + + fn deserialize_identifier>(self, _visitor: V) -> Result { + unimplemented!("deserialize Identifier not implemented") + } + + fn deserialize_ignored_any>(self, _visitor: V) -> Result { + unimplemented!("deserialize Ignored Any not implemented") + } + + fn deserialize_any>(self, _visitor: V) -> Result { + unimplemented!("deserialize any not implemented") + } +} + +impl<'a, 'de: 'a> de::SeqAccess<'de> for &'a mut Deserializer<'de> { + type Error = 
Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: DeserializeSeed<'de>, + { + if self.pos >= self.buf.len() { + return Ok(None); + } + + self.record_start(); + seed.deserialize(&mut **self).map(Some) + } +} diff --git a/src/database/deserialized.rs b/src/database/deserialized.rs new file mode 100644 index 000000000..7da112d5f --- /dev/null +++ b/src/database/deserialized.rs @@ -0,0 +1,34 @@ +use std::convert::identity; + +use conduit::Result; +use serde::Deserialize; + +pub trait Deserialized { + fn map_de(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>; + + fn map_json(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>; + + #[inline] + fn deserialized(self) -> Result + where + T: for<'de> Deserialize<'de>, + Self: Sized, + { + self.map_de(identity::) + } + + #[inline] + fn deserialized_json(self) -> Result + where + T: for<'de> Deserialize<'de>, + Self: Sized, + { + self.map_json(identity::) + } +} diff --git a/src/database/engine.rs b/src/database/engine.rs index 3850c1d3f..067232e67 100644 --- a/src/database/engine.rs +++ b/src/database/engine.rs @@ -106,7 +106,7 @@ impl Engine { })) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "trace")] pub(crate) fn open_cf(&self, name: &str) -> Result>> { let mut cfs = self.cfs.lock().expect("locked"); if !cfs.contains(name) { diff --git a/src/database/handle.rs b/src/database/handle.rs index 0b45a75f0..89d87137a 100644 --- a/src/database/handle.rs +++ b/src/database/handle.rs @@ -1,6 +1,10 @@ -use std::ops::Deref; +use std::{fmt, fmt::Debug, ops::Deref}; +use conduit::Result; use rocksdb::DBPinnableSlice; +use serde::{Deserialize, Serialize, Serializer}; + +use crate::{keyval::deserialize_val, Deserialized, Slice}; pub struct Handle<'a> { val: DBPinnableSlice<'a>, @@ -14,14 +18,91 @@ impl<'a> From> for Handle<'a> { } } +impl Debug for Handle<'_> { + fn fmt(&self, out: &mut fmt::Formatter<'_>) -> 
fmt::Result { + let val: &Slice = self; + let ptr = val.as_ptr(); + let len = val.len(); + write!(out, "Handle {{val: {{ptr: {ptr:?}, len: {len}}}}}") + } +} + +impl Serialize for Handle<'_> { + #[inline] + fn serialize(&self, serializer: S) -> Result { + let bytes: &Slice = self; + serializer.serialize_bytes(bytes) + } +} + impl Deref for Handle<'_> { - type Target = [u8]; + type Target = Slice; #[inline] fn deref(&self) -> &Self::Target { &self.val } } -impl AsRef<[u8]> for Handle<'_> { +impl AsRef for Handle<'_> { + #[inline] + fn as_ref(&self) -> &Slice { &self.val } +} + +impl Deserialized for Result> { + #[inline] + fn map_json(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + self?.map_json(f) + } + #[inline] - fn as_ref(&self) -> &[u8] { &self.val } + fn map_de(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + self?.map_de(f) + } +} + +impl<'a> Deserialized for Result<&'a Handle<'a>> { + #[inline] + fn map_json(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + self.and_then(|handle| handle.map_json(f)) + } + + #[inline] + fn map_de(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + self.and_then(|handle| handle.map_de(f)) + } +} + +impl<'a> Deserialized for &'a Handle<'a> { + fn map_json(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + serde_json::from_slice::(self.as_ref()) + .map_err(Into::into) + .map(f) + } + + fn map_de(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + deserialize_val(self.as_ref()).map(f) + } } diff --git a/src/database/iter.rs b/src/database/iter.rs deleted file mode 100644 index 4845e9773..000000000 --- a/src/database/iter.rs +++ /dev/null @@ -1,110 +0,0 @@ -use std::{iter::FusedIterator, sync::Arc}; - -use conduit::Result; -use rocksdb::{ColumnFamily, DBRawIteratorWithThreadMode, Direction, 
IteratorMode, ReadOptions}; - -use crate::{ - engine::Db, - result, - slice::{OwnedKeyVal, OwnedKeyValPair}, - Engine, -}; - -type Cursor<'cursor> = DBRawIteratorWithThreadMode<'cursor, Db>; - -struct State<'cursor> { - cursor: Cursor<'cursor>, - direction: Direction, - valid: bool, - init: bool, -} - -impl<'cursor> State<'cursor> { - pub(crate) fn new( - db: &'cursor Arc, cf: &'cursor Arc, opts: ReadOptions, mode: &IteratorMode<'_>, - ) -> Self { - let mut cursor = db.db.raw_iterator_cf_opt(&**cf, opts); - let direction = into_direction(mode); - let valid = seek_init(&mut cursor, mode); - Self { - cursor, - direction, - valid, - init: true, - } - } -} - -pub struct Iter<'cursor> { - state: State<'cursor>, -} - -impl<'cursor> Iter<'cursor> { - pub(crate) fn new( - db: &'cursor Arc, cf: &'cursor Arc, opts: ReadOptions, mode: &IteratorMode<'_>, - ) -> Self { - Self { - state: State::new(db, cf, opts, mode), - } - } -} - -impl Iterator for Iter<'_> { - type Item = OwnedKeyValPair; - - fn next(&mut self) -> Option { - if !self.state.init && self.state.valid { - seek_next(&mut self.state.cursor, self.state.direction); - } else if self.state.init { - self.state.init = false; - } - - self.state - .cursor - .item() - .map(OwnedKeyVal::from) - .map(OwnedKeyVal::to_tuple) - .or_else(|| { - when_invalid(&mut self.state).expect("iterator invalidated due to error"); - None - }) - } -} - -impl FusedIterator for Iter<'_> {} - -fn when_invalid(state: &mut State<'_>) -> Result<()> { - state.valid = false; - result(state.cursor.status()) -} - -fn seek_next(cursor: &mut Cursor<'_>, direction: Direction) { - match direction { - Direction::Forward => cursor.next(), - Direction::Reverse => cursor.prev(), - } -} - -fn seek_init(cursor: &mut Cursor<'_>, mode: &IteratorMode<'_>) -> bool { - use Direction::{Forward, Reverse}; - use IteratorMode::{End, From, Start}; - - match mode { - Start => cursor.seek_to_first(), - End => cursor.seek_to_last(), - From(key, Forward) => cursor.seek(key), - 
From(key, Reverse) => cursor.seek_for_prev(key), - }; - - cursor.valid() -} - -fn into_direction(mode: &IteratorMode<'_>) -> Direction { - use Direction::{Forward, Reverse}; - use IteratorMode::{End, From, Start}; - - match mode { - Start | From(_, Forward) => Forward, - End | From(_, Reverse) => Reverse, - } -} diff --git a/src/database/keyval.rs b/src/database/keyval.rs new file mode 100644 index 000000000..c9d25977d --- /dev/null +++ b/src/database/keyval.rs @@ -0,0 +1,83 @@ +use conduit::Result; +use serde::Deserialize; + +use crate::de; + +pub(crate) type OwnedKeyVal = (Vec, Vec); +pub(crate) type OwnedKey = Vec; +pub(crate) type OwnedVal = Vec; + +pub type KeyVal<'a, K = &'a Slice, V = &'a Slice> = (Key<'a, K>, Val<'a, V>); +pub type Key<'a, T = &'a Slice> = T; +pub type Val<'a, T = &'a Slice> = T; + +pub type Slice = [u8]; + +#[inline] +pub(crate) fn _expect_deserialize<'a, K, V>(kv: Result>) -> KeyVal<'a, K, V> +where + K: Deserialize<'a>, + V: Deserialize<'a>, +{ + result_deserialize(kv).expect("failed to deserialize result key/val") +} + +#[inline] +pub(crate) fn _expect_deserialize_key<'a, K>(key: Result>) -> Key<'a, K> +where + K: Deserialize<'a>, +{ + result_deserialize_key(key).expect("failed to deserialize result key") +} + +#[inline] +pub(crate) fn result_deserialize<'a, K, V>(kv: Result>) -> Result> +where + K: Deserialize<'a>, + V: Deserialize<'a>, +{ + deserialize(kv?) +} + +#[inline] +pub(crate) fn result_deserialize_key<'a, K>(key: Result>) -> Result> +where + K: Deserialize<'a>, +{ + deserialize_key(key?) 
+} + +#[inline] +pub(crate) fn deserialize<'a, K, V>(kv: KeyVal<'a>) -> Result> +where + K: Deserialize<'a>, + V: Deserialize<'a>, +{ + Ok((deserialize_key::(kv.0)?, deserialize_val::(kv.1)?)) +} + +#[inline] +pub(crate) fn deserialize_key<'a, K>(key: Key<'a>) -> Result> +where + K: Deserialize<'a>, +{ + de::from_slice::(key) +} + +#[inline] +pub(crate) fn deserialize_val<'a, V>(val: Val<'a>) -> Result> +where + V: Deserialize<'a>, +{ + de::from_slice::(val) +} + +#[inline] +#[must_use] +pub fn to_owned(kv: KeyVal<'_>) -> OwnedKeyVal { (kv.0.to_owned(), kv.1.to_owned()) } + +#[inline] +pub fn key(kv: KeyVal<'_, K, V>) -> Key<'_, K> { kv.0 } + +#[inline] +pub fn val(kv: KeyVal<'_, K, V>) -> Val<'_, V> { kv.1 } diff --git a/src/database/map.rs b/src/database/map.rs index ddae8c813..a3cf32d4e 100644 --- a/src/database/map.rs +++ b/src/database/map.rs @@ -1,15 +1,39 @@ -use std::{ffi::CStr, future::Future, mem::size_of, pin::Pin, sync::Arc}; - -use conduit::{utils, Result}; -use rocksdb::{ - AsColumnFamilyRef, ColumnFamily, Direction, IteratorMode, ReadOptions, WriteBatchWithTransaction, WriteOptions, +mod count; +mod keys; +mod keys_from; +mod keys_prefix; +mod rev_keys; +mod rev_keys_from; +mod rev_keys_prefix; +mod rev_stream; +mod rev_stream_from; +mod rev_stream_prefix; +mod stream; +mod stream_from; +mod stream_prefix; + +use std::{ + convert::AsRef, + ffi::CStr, + fmt, + fmt::{Debug, Display}, + future::Future, + io::Write, + pin::Pin, + sync::Arc, }; +use conduit::{err, Result}; +use futures::future; +use rocksdb::{AsColumnFamilyRef, ColumnFamily, ReadOptions, WriteBatchWithTransaction, WriteOptions}; +use serde::Serialize; + use crate::{ - or_else, result, - slice::{Byte, Key, KeyVal, OwnedKey, OwnedKeyValPair, OwnedVal, Val}, + keyval::{OwnedKey, OwnedVal}, + ser, + util::{map_err, or_else}, watchers::Watchers, - Engine, Handle, Iter, + Engine, Handle, }; pub struct Map { @@ -21,8 +45,6 @@ pub struct Map { read_options: ReadOptions, } -type 
OwnedKeyValPairIter<'a> = Box + Send + 'a>; - impl Map { pub(crate) fn open(db: &Arc, name: &str) -> Result> { Ok(Arc::new(Self { @@ -35,162 +57,158 @@ impl Map { })) } - pub fn get(&self, key: &Key) -> Result>> { - let read_options = &self.read_options; - let res = self.db.db.get_pinned_cf_opt(&self.cf(), key, read_options); - - Ok(result(res)?.map(Handle::from)) + #[tracing::instrument(skip(self), fields(%self), level = "trace")] + pub fn del(&self, key: &K) + where + K: Serialize + ?Sized + Debug, + { + let mut buf = Vec::::with_capacity(64); + self.bdel(key, &mut buf); } - pub fn multi_get(&self, keys: &[&Key]) -> Result>> { - // Optimization can be `true` if key vector is pre-sorted **by the column - // comparator**. - const SORTED: bool = false; - - let mut ret: Vec> = Vec::with_capacity(keys.len()); - let read_options = &self.read_options; - for res in self - .db - .db - .batched_multi_get_cf_opt(&self.cf(), keys, SORTED, read_options) - { - match res { - Ok(Some(res)) => ret.push(Some((*res).to_vec())), - Ok(None) => ret.push(None), - Err(e) => return or_else(e), - } - } - - Ok(ret) + #[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] + pub fn bdel(&self, key: &K, buf: &mut B) + where + K: Serialize + ?Sized + Debug, + B: Write + AsRef<[u8]>, + { + let key = ser::serialize(buf, key).expect("failed to serialize deletion key"); + self.remove(&key); } - pub fn insert(&self, key: &Key, value: &Val) -> Result<()> { + #[tracing::instrument(level = "trace")] + pub fn remove(&self, key: &K) + where + K: AsRef<[u8]> + ?Sized + Debug, + { let write_options = &self.write_options; self.db .db - .put_cf_opt(&self.cf(), key, value, write_options) - .or_else(or_else)?; + .delete_cf_opt(&self.cf(), key, write_options) + .or_else(or_else) + .expect("database remove error"); if !self.db.corked() { - self.db.flush()?; + self.db.flush().expect("database flush error"); } - - self.watchers.wake(key); - - Ok(()) } - pub fn insert_batch<'a, I>(&'a self, iter: 
I) -> Result<()> + #[tracing::instrument(skip(self, value), fields(%self), level = "trace")] + pub fn insert(&self, key: &K, value: &V) where - I: Iterator>, + K: AsRef<[u8]> + ?Sized + Debug, + V: AsRef<[u8]> + ?Sized, { - let mut batch = WriteBatchWithTransaction::::default(); - for KeyVal(key, value) in iter { - batch.put_cf(&self.cf(), key, value); - } - let write_options = &self.write_options; - let res = self.db.db.write_opt(batch, write_options); - - if !self.db.corked() { - self.db.flush()?; - } - - result(res) - } - - pub fn remove(&self, key: &Key) -> Result<()> { - let write_options = &self.write_options; - let res = self.db.db.delete_cf_opt(&self.cf(), key, write_options); + self.db + .db + .put_cf_opt(&self.cf(), key, value, write_options) + .or_else(or_else) + .expect("database insert error"); if !self.db.corked() { - self.db.flush()?; + self.db.flush().expect("database flush error"); } - result(res) + self.watchers.wake(key.as_ref()); } - pub fn remove_batch<'a, I>(&'a self, iter: I) -> Result<()> + #[tracing::instrument(skip(self), fields(%self), level = "trace")] + pub fn insert_batch<'a, I, K, V>(&'a self, iter: I) where - I: Iterator, + I: Iterator + Send + Debug, + K: AsRef<[u8]> + Sized + Debug + 'a, + V: AsRef<[u8]> + Sized + 'a, { let mut batch = WriteBatchWithTransaction::::default(); - for key in iter { - batch.delete_cf(&self.cf(), key); + for (key, val) in iter { + batch.put_cf(&self.cf(), key.as_ref(), val.as_ref()); } let write_options = &self.write_options; - let res = self.db.db.write_opt(batch, write_options); + self.db + .db + .write_opt(batch, write_options) + .or_else(or_else) + .expect("database insert batch error"); if !self.db.corked() { - self.db.flush()?; + self.db.flush().expect("database flush error"); } - - result(res) } - pub fn iter(&self) -> OwnedKeyValPairIter<'_> { - let mode = IteratorMode::Start; - let read_options = read_options_default(); - Box::new(Iter::new(&self.db, &self.cf, read_options, &mode)) - } - - pub 
fn iter_from(&self, from: &Key, reverse: bool) -> OwnedKeyValPairIter<'_> { - let direction = if reverse { - Direction::Reverse - } else { - Direction::Forward - }; - let mode = IteratorMode::From(from, direction); - let read_options = read_options_default(); - Box::new(Iter::new(&self.db, &self.cf, read_options, &mode)) + #[tracing::instrument(skip(self), fields(%self), level = "trace")] + pub fn qry(&self, key: &K) -> impl Future>> + Send + where + K: Serialize + ?Sized + Debug, + { + let mut buf = Vec::::with_capacity(64); + self.bqry(key, &mut buf) } - pub fn scan_prefix(&self, prefix: OwnedKey) -> OwnedKeyValPairIter<'_> { - let mode = IteratorMode::From(&prefix, Direction::Forward); - let read_options = read_options_default(); - Box::new(Iter::new(&self.db, &self.cf, read_options, &mode).take_while(move |(k, _)| k.starts_with(&prefix))) + #[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] + pub fn bqry(&self, key: &K, buf: &mut B) -> impl Future>> + Send + where + K: Serialize + ?Sized + Debug, + B: Write + AsRef<[u8]>, + { + let key = ser::serialize(buf, key).expect("failed to serialize query key"); + let val = self.get(key); + future::ready(val) } - pub fn increment(&self, key: &Key) -> Result<[Byte; size_of::()]> { - let old = self.get(key)?; - let new = utils::increment(old.as_deref()); - self.insert(key, &new)?; - - if !self.db.corked() { - self.db.flush()?; - } - - Ok(new) + #[tracing::instrument(skip(self), fields(%self), level = "trace")] + pub fn get(&self, key: &K) -> Result> + where + K: AsRef<[u8]> + ?Sized + Debug, + { + self.db + .db + .get_pinned_cf_opt(&self.cf(), key, &self.read_options) + .map_err(map_err)? 
+ .map(Handle::from) + .ok_or(err!(Request(NotFound("Not found in database")))) } - pub fn increment_batch<'a, I>(&'a self, iter: I) -> Result<()> + #[tracing::instrument(skip(self), fields(%self), level = "trace")] + pub fn multi_get<'a, I, K>(&self, keys: I) -> Vec> where - I: Iterator, + I: Iterator + ExactSizeIterator + Send + Debug, + K: AsRef<[u8]> + Sized + Debug + 'a, { - let mut batch = WriteBatchWithTransaction::::default(); - for key in iter { - let old = self.get(key)?; - let new = utils::increment(old.as_deref()); - batch.put_cf(&self.cf(), key, new); - } - - let write_options = &self.write_options; - let res = self.db.db.write_opt(batch, write_options); + // Optimization can be `true` if key vector is pre-sorted **by the column + // comparator**. + const SORTED: bool = false; - if !self.db.corked() { - self.db.flush()?; + let mut ret: Vec> = Vec::with_capacity(keys.len()); + let read_options = &self.read_options; + for res in self + .db + .db + .batched_multi_get_cf_opt(&self.cf(), keys, SORTED, read_options) + { + match res { + Ok(Some(res)) => ret.push(Some((*res).to_vec())), + Ok(None) => ret.push(None), + Err(e) => or_else(e).expect("database multiget error"), + } } - result(res) + ret } - pub fn watch_prefix<'a>(&'a self, prefix: &Key) -> Pin + Send + 'a>> { - self.watchers.watch(prefix) + #[inline] + pub fn watch_prefix<'a, K>(&'a self, prefix: &K) -> Pin + Send + 'a>> + where + K: AsRef<[u8]> + ?Sized + Debug, + { + self.watchers.watch(prefix.as_ref()) } + #[inline] pub fn property_integer(&self, name: &CStr) -> Result { self.db.property_integer(&self.cf(), name) } + #[inline] pub fn property(&self, name: &str) -> Result { self.db.property(&self.cf(), name) } #[inline] @@ -199,12 +217,12 @@ impl Map { fn cf(&self) -> impl AsColumnFamilyRef + '_ { &*self.cf } } -impl<'a> IntoIterator for &'a Map { - type IntoIter = Box + Send + 'a>; - type Item = OwnedKeyValPair; +impl Debug for Map { + fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result 
{ write!(out, "Map {{name: {0}}}", self.name) } +} - #[inline] - fn into_iter(self) -> Self::IntoIter { self.iter() } +impl Display for Map { + fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result { write!(out, "{0}", self.name) } } fn open(db: &Arc, name: &str) -> Result> { diff --git a/src/database/map/count.rs b/src/database/map/count.rs new file mode 100644 index 000000000..4356b71f5 --- /dev/null +++ b/src/database/map/count.rs @@ -0,0 +1,36 @@ +use std::{fmt::Debug, future::Future}; + +use conduit::implement; +use futures::stream::StreamExt; +use serde::Serialize; + +use crate::de::Ignore; + +/// Count the total number of entries in the map. +#[implement(super::Map)] +#[inline] +pub fn count(&self) -> impl Future + Send + '_ { self.keys::().count() } + +/// Count the number of entries in the map starting from a lower-bound. +/// +/// - From is a structured key +#[implement(super::Map)] +#[inline] +pub fn count_from<'a, P>(&'a self, from: &P) -> impl Future + Send + 'a +where + P: Serialize + ?Sized + Debug + 'a, +{ + self.keys_from::(from).count() +} + +/// Count the number of entries in the map matching a prefix. 
+/// +/// - Prefix is structured key +#[implement(super::Map)] +#[inline] +pub fn count_prefix<'a, P>(&'a self, prefix: &P) -> impl Future + Send + 'a +where + P: Serialize + ?Sized + Debug + 'a, +{ + self.keys_prefix::(prefix).count() +} diff --git a/src/database/map/keys.rs b/src/database/map/keys.rs new file mode 100644 index 000000000..2396494c4 --- /dev/null +++ b/src/database/map/keys.rs @@ -0,0 +1,21 @@ +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::Key, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys<'a, K>(&'a self) -> impl Stream>> + Send +where + K: Deserialize<'a> + Send, +{ + self.raw_keys().map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_keys(&self) -> impl Stream>> + Send { + let opts = super::read_options_default(); + stream::Keys::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/keys_from.rs b/src/database/map/keys_from.rs new file mode 100644 index 000000000..1993750ab --- /dev/null +++ b/src/database/map/keys_from.rs @@ -0,0 +1,49 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_from<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.keys_raw_from(from) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_raw_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.raw_keys_from(&key) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_from_raw<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, +{ + self.raw_keys_from(from) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_keys_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::Keys::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/keys_prefix.rs b/src/database/map/keys_prefix.rs new file mode 100644 index 000000000..d6c0927b9 --- /dev/null +++ b/src/database/map/keys_prefix.rs @@ -0,0 +1,54 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_prefix<'a, K, P>(&'a self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.keys_raw_prefix(prefix) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_raw_prefix

    (&self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.raw_keys_from(&key) + .try_take_while(move |k: &Key<'_>| future::ok(k.starts_with(&key))) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_prefix_raw<'a, K, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, +{ + self.raw_keys_prefix(prefix) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_keys_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.raw_keys_from(prefix) + .try_take_while(|k: &Key<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/map/rev_keys.rs b/src/database/map/rev_keys.rs new file mode 100644 index 000000000..449ccfff3 --- /dev/null +++ b/src/database/map/rev_keys.rs @@ -0,0 +1,21 @@ +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::Key, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys<'a, K>(&'a self) -> impl Stream>> + Send +where + K: Deserialize<'a> + Send, +{ + self.rev_raw_keys().map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_keys(&self) -> impl Stream>> + Send { + let opts = super::read_options_default(); + stream::KeysRev::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/rev_keys_from.rs b/src/database/map/rev_keys_from.rs new file mode 100644 index 000000000..e012e60af --- /dev/null +++ b/src/database/map/rev_keys_from.rs @@ -0,0 
+1,49 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_from<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.rev_keys_raw_from(from) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_raw_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.rev_raw_keys_from(&key) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_from_raw<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, +{ + self.rev_raw_keys_from(from) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_keys_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::KeysRev::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/rev_keys_prefix.rs b/src/database/map/rev_keys_prefix.rs new file mode 100644 index 000000000..162c4f9b8 --- /dev/null +++ b/src/database/map/rev_keys_prefix.rs @@ -0,0 +1,54 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_prefix<'a, K, P>(&'a self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.rev_keys_raw_prefix(prefix) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_raw_prefix

    (&self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.rev_raw_keys_from(&key) + .try_take_while(move |k: &Key<'_>| future::ok(k.starts_with(&key))) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_prefix_raw<'a, K, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, +{ + self.rev_raw_keys_prefix(prefix) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_keys_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.rev_raw_keys_from(prefix) + .try_take_while(|k: &Key<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/map/rev_stream.rs b/src/database/map/rev_stream.rs new file mode 100644 index 000000000..de22fd5ce --- /dev/null +++ b/src/database/map/rev_stream.rs @@ -0,0 +1,29 @@ +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::KeyVal, stream}; + +/// Iterate key-value entries in the map from the end. +/// +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream<'a, K, V>(&'a self) -> impl Stream>> + Send +where + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.rev_raw_stream() + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map from the end. 
+/// +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_stream(&self) -> impl Stream>> + Send { + let opts = super::read_options_default(); + stream::ItemsRev::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/rev_stream_from.rs b/src/database/map/rev_stream_from.rs new file mode 100644 index 000000000..650cf038c --- /dev/null +++ b/src/database/map/rev_stream_from.rs @@ -0,0 +1,68 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser, stream}; + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_from<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.rev_stream_raw_from(&key) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_raw_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.rev_raw_stream_from(&key) +} + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_from_raw<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.rev_raw_stream_from(from) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_stream_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::ItemsRev::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/rev_stream_prefix.rs b/src/database/map/rev_stream_prefix.rs new file mode 100644 index 000000000..9ef89e9cb --- /dev/null +++ b/src/database/map/rev_stream_prefix.rs @@ -0,0 +1,74 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser}; + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_prefix<'a, K, V, P>(&'a self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.rev_stream_raw_prefix(prefix) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_raw_prefix

    (&self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.rev_raw_stream_from(&key) + .try_take_while(move |(k, _): &KeyVal<'_>| future::ok(k.starts_with(&key))) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_prefix_raw<'a, K, V, P>( + &'a self, prefix: &'a P, +) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, + V: Deserialize<'a> + Send + 'a, +{ + self.rev_raw_stream_prefix(prefix) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_stream_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.rev_raw_stream_from(prefix) + .try_take_while(|(k, _): &KeyVal<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/map/stream.rs b/src/database/map/stream.rs new file mode 100644 index 000000000..dfbea0729 --- /dev/null +++ b/src/database/map/stream.rs @@ -0,0 +1,28 @@ +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::KeyVal, stream}; + +/// Iterate key-value entries in the map from the beginning. 
+/// +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream<'a, K, V>(&'a self) -> impl Stream>> + Send +where + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.raw_stream().map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map from the beginning. +/// +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_stream(&self) -> impl Stream>> + Send { + let opts = super::read_options_default(); + stream::Items::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/stream_from.rs b/src/database/map/stream_from.rs new file mode 100644 index 000000000..153d5bb61 --- /dev/null +++ b/src/database/map/stream_from.rs @@ -0,0 +1,68 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser, stream}; + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_from<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.stream_raw_from(&key) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_raw_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.raw_stream_from(&key) +} + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_from_raw<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.raw_stream_from(from) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_stream_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::Items::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/stream_prefix.rs b/src/database/map/stream_prefix.rs new file mode 100644 index 000000000..56154a8b3 --- /dev/null +++ b/src/database/map/stream_prefix.rs @@ -0,0 +1,74 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser}; + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_prefix<'a, K, V, P>(&'a self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.stream_raw_prefix(prefix) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_raw_prefix

    (&self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.raw_stream_from(&key) + .try_take_while(move |(k, _): &KeyVal<'_>| future::ok(k.starts_with(&key))) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_prefix_raw<'a, K, V, P>( + &'a self, prefix: &'a P, +) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, + V: Deserialize<'a> + Send + 'a, +{ + self.raw_stream_prefix(prefix) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_stream_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.raw_stream_from(prefix) + .try_take_while(|(k, _): &KeyVal<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 6446624ca..e66abf682 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,25 +1,35 @@ mod cork; mod database; +mod de; +mod deserialized; mod engine; mod handle; -mod iter; +pub mod keyval; mod map; pub mod maps; mod opts; -mod slice; +mod ser; +mod stream; mod util; mod watchers; +pub(crate) use self::{ + engine::Engine, + util::{or_else, result}, +}; + extern crate conduit_core as conduit; extern crate rust_rocksdb as rocksdb; -pub use database::Database; -pub(crate) use engine::Engine; -pub use handle::Handle; -pub use iter::Iter; -pub use map::Map; -pub use slice::{Key, KeyVal, OwnedKey, OwnedKeyVal, OwnedVal, Val}; -pub(crate) 
use util::{or_else, result}; +pub use self::{ + database::Database, + de::Ignore, + deserialized::Deserialized, + handle::Handle, + keyval::{KeyVal, Slice}, + map::Map, + ser::{Interfix, Separator}, +}; conduit::mod_ctor! {} conduit::mod_dtor! {} diff --git a/src/database/ser.rs b/src/database/ser.rs new file mode 100644 index 000000000..bd4bbd9ad --- /dev/null +++ b/src/database/ser.rs @@ -0,0 +1,315 @@ +use std::io::Write; + +use conduit::{err, result::DebugInspect, utils::exchange, Error, Result}; +use serde::{ser, Serialize}; + +#[inline] +pub(crate) fn serialize_to_vec(val: &T) -> Result> +where + T: Serialize + ?Sized, +{ + let mut buf = Vec::with_capacity(64); + serialize(&mut buf, val)?; + + Ok(buf) +} + +#[inline] +pub(crate) fn serialize<'a, W, T>(out: &'a mut W, val: &'a T) -> Result<&'a [u8]> +where + W: Write + AsRef<[u8]>, + T: Serialize + ?Sized, +{ + let mut serializer = Serializer { + out, + depth: 0, + sep: false, + fin: false, + }; + + val.serialize(&mut serializer) + .map_err(|error| err!(SerdeSer("{error}"))) + .debug_inspect(|()| { + debug_assert_eq!(serializer.depth, 0, "Serialization completed at non-zero recursion level"); + })?; + + Ok((*out).as_ref()) +} + +pub(crate) struct Serializer<'a, W: Write> { + out: &'a mut W, + depth: u32, + sep: bool, + fin: bool, +} + +/// Directive to force separator serialization specifically for prefix keying +/// use. This is a quirk of the database schema and prefix iterations. +#[derive(Debug, Serialize)] +pub struct Interfix; + +/// Directive to force separator serialization. Separators are usually +/// serialized automatically. 
+#[derive(Debug, Serialize)] +pub struct Separator; + +impl Serializer<'_, W> { + const SEP: &'static [u8] = b"\xFF"; + + fn sequence_start(&mut self) { + debug_assert!(!self.is_finalized(), "Sequence start with finalization set"); + debug_assert!(!self.sep, "Sequence start with separator set"); + if cfg!(debug_assertions) { + self.depth = self.depth.saturating_add(1); + } + } + + fn sequence_end(&mut self) { + self.sep = false; + if cfg!(debug_assertions) { + self.depth = self.depth.saturating_sub(1); + } + } + + fn record_start(&mut self) -> Result<()> { + debug_assert!(!self.is_finalized(), "Starting a record after serialization finalized"); + exchange(&mut self.sep, true) + .then(|| self.separator()) + .unwrap_or(Ok(())) + } + + fn separator(&mut self) -> Result<()> { + debug_assert!(!self.is_finalized(), "Writing a separator after serialization finalized"); + self.out.write_all(Self::SEP).map_err(Into::into) + } + + fn set_finalized(&mut self) { + debug_assert!(!self.is_finalized(), "Finalization already set"); + if cfg!(debug_assertions) { + self.fin = true; + } + } + + fn is_finalized(&self) -> bool { self.fin } +} + +impl ser::Serializer for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + type SerializeMap = Self; + type SerializeSeq = Self; + type SerializeStruct = Self; + type SerializeStructVariant = Self; + type SerializeTuple = Self; + type SerializeTupleStruct = Self; + type SerializeTupleVariant = Self; + + fn serialize_map(self, _len: Option) -> Result { + unimplemented!("serialize Map not implemented") + } + + fn serialize_seq(self, _len: Option) -> Result { + self.sequence_start(); + self.record_start()?; + Ok(self) + } + + fn serialize_tuple(self, _len: usize) -> Result { + self.sequence_start(); + Ok(self) + } + + fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { + self.sequence_start(); + Ok(self) + } + + fn serialize_tuple_variant( + self, _name: &'static str, _idx: u32, _var: &'static str, _len: 
usize, + ) -> Result { + self.sequence_start(); + Ok(self) + } + + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + self.sequence_start(); + Ok(self) + } + + fn serialize_struct_variant( + self, _name: &'static str, _idx: u32, _var: &'static str, _len: usize, + ) -> Result { + self.sequence_start(); + Ok(self) + } + + fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result { + unimplemented!("serialize New Type Struct not implemented") + } + + fn serialize_newtype_variant( + self, _name: &'static str, _idx: u32, _var: &'static str, _value: &T, + ) -> Result { + unimplemented!("serialize New Type Variant not implemented") + } + + fn serialize_unit_struct(self, name: &'static str) -> Result { + match name { + "Interfix" => { + self.set_finalized(); + }, + "Separator" => { + self.separator()?; + }, + _ => unimplemented!("Unrecognized serialization directive: {name:?}"), + }; + + Ok(()) + } + + fn serialize_unit_variant(self, _name: &'static str, _idx: u32, _var: &'static str) -> Result { + unimplemented!("serialize Unit Variant not implemented") + } + + fn serialize_some(self, val: &T) -> Result { val.serialize(self) } + + fn serialize_none(self) -> Result { Ok(()) } + + fn serialize_char(self, v: char) -> Result { + let mut buf: [u8; 4] = [0; 4]; + self.serialize_str(v.encode_utf8(&mut buf)) + } + + fn serialize_str(self, v: &str) -> Result { self.serialize_bytes(v.as_bytes()) } + + fn serialize_bytes(self, v: &[u8]) -> Result { self.out.write_all(v).map_err(Error::Io) } + + fn serialize_f64(self, _v: f64) -> Result { unimplemented!("serialize f64 not implemented") } + + fn serialize_f32(self, _v: f32) -> Result { unimplemented!("serialize f32 not implemented") } + + fn serialize_i64(self, v: i64) -> Result { self.out.write_all(&v.to_be_bytes()).map_err(Error::Io) } + + fn serialize_i32(self, _v: i32) -> Result { unimplemented!("serialize i32 not implemented") } + + fn serialize_i16(self, _v: i16) -> Result { 
unimplemented!("serialize i16 not implemented") } + + fn serialize_i8(self, _v: i8) -> Result { unimplemented!("serialize i8 not implemented") } + + fn serialize_u64(self, v: u64) -> Result { self.out.write_all(&v.to_be_bytes()).map_err(Error::Io) } + + fn serialize_u32(self, _v: u32) -> Result { unimplemented!("serialize u32 not implemented") } + + fn serialize_u16(self, _v: u16) -> Result { unimplemented!("serialize u16 not implemented") } + + fn serialize_u8(self, v: u8) -> Result { self.out.write_all(&[v]).map_err(Error::Io) } + + fn serialize_bool(self, _v: bool) -> Result { unimplemented!("serialize bool not implemented") } + + fn serialize_unit(self) -> Result { unimplemented!("serialize unit not implemented") } +} + +impl ser::SerializeMap for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_key(&mut self, _key: &T) -> Result { + unimplemented!("serialize Map Key not implemented") + } + + fn serialize_value(&mut self, _val: &T) -> Result { + unimplemented!("serialize Map Val not implemented") + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeSeq for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_element(&mut self, val: &T) -> Result { val.serialize(&mut **self) } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeStruct for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, _key: &'static str, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeStructVariant for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, _key: &'static str, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeTuple for 
&mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_element(&mut self, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeTupleStruct for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeTupleVariant for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} diff --git a/src/database/slice.rs b/src/database/slice.rs deleted file mode 100644 index 448d969d9..000000000 --- a/src/database/slice.rs +++ /dev/null @@ -1,57 +0,0 @@ -pub struct OwnedKeyVal(pub OwnedKey, pub OwnedVal); -pub(crate) type OwnedKeyValPair = (OwnedKey, OwnedVal); -pub type OwnedVal = Vec; -pub type OwnedKey = Vec; - -pub struct KeyVal<'item>(pub &'item Key, pub &'item Val); -pub(crate) type KeyValPair<'item> = (&'item Key, &'item Val); -pub type Val = [Byte]; -pub type Key = [Byte]; - -pub(crate) type Byte = u8; - -impl OwnedKeyVal { - #[must_use] - pub fn as_slice(&self) -> KeyVal<'_> { KeyVal(&self.0, &self.1) } - - #[must_use] - pub fn to_tuple(self) -> OwnedKeyValPair { (self.0, self.1) } -} - -impl From for OwnedKeyVal { - fn from((key, val): OwnedKeyValPair) -> Self { Self(key, val) } -} - -impl From<&KeyVal<'_>> for OwnedKeyVal { - #[inline] - fn from(slice: &KeyVal<'_>) -> Self { slice.to_owned() } -} - -impl From> for OwnedKeyVal { - fn from((key, val): KeyValPair<'_>) -> Self { Self(Vec::from(key), Vec::from(val)) } -} - -impl From for OwnedKeyValPair { - fn from(val: OwnedKeyVal) -> Self { val.to_tuple() } -} - -impl KeyVal<'_> { - #[inline] - 
#[must_use] - pub fn to_owned(&self) -> OwnedKeyVal { OwnedKeyVal::from(self) } - - #[must_use] - pub fn as_tuple(&self) -> KeyValPair<'_> { (self.0, self.1) } -} - -impl<'a> From<&'a OwnedKeyVal> for KeyVal<'a> { - fn from(owned: &'a OwnedKeyVal) -> Self { owned.as_slice() } -} - -impl<'a> From<&'a OwnedKeyValPair> for KeyVal<'a> { - fn from((key, val): &'a OwnedKeyValPair) -> Self { KeyVal(key.as_slice(), val.as_slice()) } -} - -impl<'a> From> for KeyVal<'a> { - fn from((key, val): KeyValPair<'a>) -> Self { KeyVal(key, val) } -} diff --git a/src/database/stream.rs b/src/database/stream.rs new file mode 100644 index 000000000..d9b74215d --- /dev/null +++ b/src/database/stream.rs @@ -0,0 +1,122 @@ +mod items; +mod items_rev; +mod keys; +mod keys_rev; + +use std::sync::Arc; + +use conduit::{utils::exchange, Error, Result}; +use rocksdb::{ColumnFamily, DBRawIteratorWithThreadMode, ReadOptions}; + +pub(crate) use self::{items::Items, items_rev::ItemsRev, keys::Keys, keys_rev::KeysRev}; +use crate::{ + engine::Db, + keyval::{Key, KeyVal, Val}, + util::map_err, + Engine, Slice, +}; + +struct State<'a> { + inner: Inner<'a>, + seek: bool, + init: bool, +} + +trait Cursor<'a, T> { + fn state(&self) -> &State<'a>; + + fn fetch(&self) -> Option; + + fn seek(&mut self); + + fn get(&self) -> Option> { + self.fetch() + .map(Ok) + .or_else(|| self.state().status().map(Err)) + } + + fn seek_and_get(&mut self) -> Option> { + self.seek(); + self.get() + } +} + +type Inner<'a> = DBRawIteratorWithThreadMode<'a, Db>; +type From<'a> = Option>; + +impl<'a> State<'a> { + fn new(db: &'a Arc, cf: &'a Arc, opts: ReadOptions) -> Self { + Self { + inner: db.db.raw_iterator_cf_opt(&**cf, opts), + init: true, + seek: false, + } + } + + fn init_fwd(mut self, from: From<'_>) -> Self { + if let Some(key) = from { + self.inner.seek(key); + self.seek = true; + } + + self + } + + fn init_rev(mut self, from: From<'_>) -> Self { + if let Some(key) = from { + self.inner.seek_for_prev(key); + self.seek = 
true; + } + + self + } + + fn seek_fwd(&mut self) { + if !exchange(&mut self.init, false) { + self.inner.next(); + } else if !self.seek { + self.inner.seek_to_first(); + } + } + + fn seek_rev(&mut self) { + if !exchange(&mut self.init, false) { + self.inner.prev(); + } else if !self.seek { + self.inner.seek_to_last(); + } + } + + fn fetch_key(&self) -> Option> { self.inner.key().map(Key::from) } + + fn _fetch_val(&self) -> Option> { self.inner.value().map(Val::from) } + + fn fetch(&self) -> Option> { self.inner.item().map(KeyVal::from) } + + fn status(&self) -> Option { self.inner.status().map_err(map_err).err() } + + fn valid(&self) -> bool { self.inner.valid() } +} + +fn keyval_longevity<'a, 'b: 'a>(item: KeyVal<'a>) -> KeyVal<'b> { + (slice_longevity::<'a, 'b>(item.0), slice_longevity::<'a, 'b>(item.1)) +} + +fn slice_longevity<'a, 'b: 'a>(item: &'a Slice) -> &'b Slice { + // SAFETY: The lifetime of the data returned by the rocksdb cursor is only valid + // between each movement of the cursor. It is hereby unsafely extended to match + // the lifetime of the cursor itself. This is due to the limitation of the + // Stream trait where the Item is incapable of conveying a lifetime; this is due + // to GAT's being unstable during its development. This unsafety can be removed + // as soon as this limitation is addressed by an upcoming version. + // + // We have done our best to mitigate the implications of this in conjunction + // with the deserialization API such that borrows being held across movements of + // the cursor do not happen accidentally. The compiler will still error when + // values herein produced try to leave a closure passed to a StreamExt API. But + // escapes can happen if you explicitly and intentionally attempt it, and there + // will be no compiler error or warning. This is primarily the case with + // calling collect() without a preceding map(ToOwned::to_owned). 
A collection + // of references here is illegal, but this will not be enforced by the compiler. + unsafe { std::mem::transmute(item) } +} diff --git a/src/database/stream/items.rs b/src/database/stream/items.rs new file mode 100644 index 000000000..31d5e9e8d --- /dev/null +++ b/src/database/stream/items.rs @@ -0,0 +1,44 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{keyval_longevity, Cursor, From, State}; +use crate::{keyval::KeyVal, Engine}; + +pub(crate) struct Items<'a> { + state: State<'a>, +} + +impl<'a> Items<'a> { + pub(crate) fn new(db: &'a Arc, cf: &'a Arc, opts: ReadOptions, from: From<'_>) -> Self { + Self { + state: State::new(db, cf, opts).init_fwd(from), + } + } +} + +impl<'a> Cursor<'a, KeyVal<'a>> for Items<'a> { + fn state(&self) -> &State<'a> { &self.state } + + fn fetch(&self) -> Option> { self.state.fetch().map(keyval_longevity) } + + fn seek(&mut self) { self.state.seek_fwd(); } +} + +impl<'a> Stream for Items<'a> { + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for Items<'_> { + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/stream/items_rev.rs b/src/database/stream/items_rev.rs new file mode 100644 index 000000000..ab57a2506 --- /dev/null +++ b/src/database/stream/items_rev.rs @@ -0,0 +1,44 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{keyval_longevity, Cursor, From, State}; +use crate::{keyval::KeyVal, Engine}; + +pub(crate) struct ItemsRev<'a> { + state: State<'a>, +} + +impl<'a> ItemsRev<'a> { + pub(crate) fn new(db: &'a Arc, cf: &'a Arc, opts: ReadOptions, from: From<'_>) -> 
Self { + Self { + state: State::new(db, cf, opts).init_rev(from), + } + } +} + +impl<'a> Cursor<'a, KeyVal<'a>> for ItemsRev<'a> { + fn state(&self) -> &State<'a> { &self.state } + + fn fetch(&self) -> Option> { self.state.fetch().map(keyval_longevity) } + + fn seek(&mut self) { self.state.seek_rev(); } +} + +impl<'a> Stream for ItemsRev<'a> { + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for ItemsRev<'_> { + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/stream/keys.rs b/src/database/stream/keys.rs new file mode 100644 index 000000000..1c5d12e30 --- /dev/null +++ b/src/database/stream/keys.rs @@ -0,0 +1,44 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{slice_longevity, Cursor, From, State}; +use crate::{keyval::Key, Engine}; + +pub(crate) struct Keys<'a> { + state: State<'a>, +} + +impl<'a> Keys<'a> { + pub(crate) fn new(db: &'a Arc, cf: &'a Arc, opts: ReadOptions, from: From<'_>) -> Self { + Self { + state: State::new(db, cf, opts).init_fwd(from), + } + } +} + +impl<'a> Cursor<'a, Key<'a>> for Keys<'a> { + fn state(&self) -> &State<'a> { &self.state } + + fn fetch(&self) -> Option> { self.state.fetch_key().map(slice_longevity) } + + fn seek(&mut self) { self.state.seek_fwd(); } +} + +impl<'a> Stream for Keys<'a> { + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for Keys<'_> { + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/stream/keys_rev.rs b/src/database/stream/keys_rev.rs new file mode 100644 index 000000000..267074837 --- /dev/null +++ b/src/database/stream/keys_rev.rs 
@@ -0,0 +1,44 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{slice_longevity, Cursor, From, State}; +use crate::{keyval::Key, Engine}; + +pub(crate) struct KeysRev<'a> { + state: State<'a>, +} + +impl<'a> KeysRev<'a> { + pub(crate) fn new(db: &'a Arc, cf: &'a Arc, opts: ReadOptions, from: From<'_>) -> Self { + Self { + state: State::new(db, cf, opts).init_rev(from), + } + } +} + +impl<'a> Cursor<'a, Key<'a>> for KeysRev<'a> { + fn state(&self) -> &State<'a> { &self.state } + + fn fetch(&self) -> Option> { self.state.fetch_key().map(slice_longevity) } + + fn seek(&mut self) { self.state.seek_rev(); } +} + +impl<'a> Stream for KeysRev<'a> { + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for KeysRev<'_> { + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/util.rs b/src/database/util.rs index f0ccbcbee..d36e183f4 100644 --- a/src/database/util.rs +++ b/src/database/util.rs @@ -1,4 +1,16 @@ use conduit::{err, Result}; +use rocksdb::{Direction, IteratorMode}; + +#[inline] +pub(crate) fn _into_direction(mode: &IteratorMode<'_>) -> Direction { + use Direction::{Forward, Reverse}; + use IteratorMode::{End, From, Start}; + + match mode { + Start | From(_, Forward) => Forward, + End | From(_, Reverse) => Reverse, + } +} #[inline] pub(crate) fn result(r: std::result::Result) -> Result { diff --git a/src/service/Cargo.toml b/src/service/Cargo.toml index cfed5a0e3..737a70399 100644 --- a/src/service/Cargo.toml +++ b/src/service/Cargo.toml @@ -46,7 +46,7 @@ bytes.workspace = true conduit-core.workspace = true conduit-database.workspace = true const-str.workspace = true -futures-util.workspace = true +futures.workspace = true hickory-resolver.workspace = true 
http.workspace = true image.workspace = true diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs deleted file mode 100644 index 53a0e9533..000000000 --- a/src/service/account_data/data.rs +++ /dev/null @@ -1,152 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use conduit::{Error, Result}; -use database::Map; -use ruma::{ - api::client::error::ErrorKind, - events::{AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent, RoomAccountDataEventType}, - serde::Raw, - RoomId, UserId, -}; - -use crate::{globals, Dep}; - -pub(super) struct Data { - roomuserdataid_accountdata: Arc, - roomusertype_roomuserdataid: Arc, - services: Services, -} - -struct Services { - globals: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - roomuserdataid_accountdata: db["roomuserdataid_accountdata"].clone(), - roomusertype_roomuserdataid: db["roomusertype_roomuserdataid"].clone(), - services: Services { - globals: args.depend::("globals"), - }, - } - } - - /// Places one event in the account data of the user and removes the - /// previous entry. 
- pub(super) fn update( - &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: &RoomAccountDataEventType, - data: &serde_json::Value, - ) -> Result<()> { - let mut prefix = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(user_id.as_bytes()); - prefix.push(0xFF); - - let mut roomuserdataid = prefix.clone(); - roomuserdataid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); - roomuserdataid.push(0xFF); - roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); - - let mut key = prefix; - key.extend_from_slice(event_type.to_string().as_bytes()); - - if data.get("type").is_none() || data.get("content").is_none() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Account data doesn't have all required fields.", - )); - } - - self.roomuserdataid_accountdata.insert( - &roomuserdataid, - &serde_json::to_vec(&data).expect("to_vec always works on json values"), - )?; - - let prev = self.roomusertype_roomuserdataid.get(&key)?; - - self.roomusertype_roomuserdataid - .insert(&key, &roomuserdataid)?; - - // Remove old entry - if let Some(prev) = prev { - self.roomuserdataid_accountdata.remove(&prev)?; - } - - Ok(()) - } - - /// Searches the account data for a specific kind. - pub(super) fn get( - &self, room_id: Option<&RoomId>, user_id: &UserId, kind: &RoomAccountDataEventType, - ) -> Result>> { - let mut key = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(kind.to_string().as_bytes()); - - self.roomusertype_roomuserdataid - .get(&key)? - .and_then(|roomuserdataid| { - self.roomuserdataid_accountdata - .get(&roomuserdataid) - .transpose() - }) - .transpose()? 
- .map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize"))) - .transpose() - } - - /// Returns all changes to the account data that happened after `since`. - pub(super) fn changes_since( - &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, - ) -> Result> { - let mut userdata = HashMap::new(); - - let mut prefix = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(user_id.as_bytes()); - prefix.push(0xFF); - - // Skip the data that's exactly at since, because we sent that last time - let mut first_possible = prefix.clone(); - first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); - - for r in self - .roomuserdataid_accountdata - .iter_from(&first_possible, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(|(k, v)| { - Ok::<_, Error>(( - k, - match room_id { - None => serde_json::from_slice::>(&v) - .map(AnyRawAccountDataEvent::Global) - .map_err(|_| Error::bad_database("Database contains invalid account data."))?, - Some(_) => serde_json::from_slice::>(&v) - .map(AnyRawAccountDataEvent::Room) - .map_err(|_| Error::bad_database("Database contains invalid account data."))?, - }, - )) - }) { - let (kind, data) = r?; - userdata.insert(kind, data); - } - - Ok(userdata.into_values().collect()) - } -} diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index eaa536417..b4eb143d4 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -1,52 +1,158 @@ -mod data; +use std::{collections::HashMap, sync::Arc}; -use std::sync::Arc; - -use conduit::Result; -use data::Data; +use conduit::{ + implement, + utils::{stream::TryIgnore, ReadyExt}, + Err, Error, Result, +}; +use database::{Deserialized, Map}; +use futures::{StreamExt, TryFutureExt}; use ruma::{ - events::{AnyRawAccountDataEvent, RoomAccountDataEventType}, + events::{AnyGlobalAccountDataEvent, 
AnyRawAccountDataEvent, AnyRoomAccountDataEvent, RoomAccountDataEventType}, + serde::Raw, RoomId, UserId, }; +use serde_json::value::RawValue; + +use crate::{globals, Dep}; pub struct Service { + services: Services, db: Data, } +struct Data { + roomuserdataid_accountdata: Arc, + roomusertype_roomuserdataid: Arc, +} + +struct Services { + globals: Dep, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(&args), + services: Services { + globals: args.depend::("globals"), + }, + db: Data { + roomuserdataid_accountdata: args.db["roomuserdataid_accountdata"].clone(), + roomusertype_roomuserdataid: args.db["roomusertype_roomuserdataid"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Places one event in the account data of the user and removes the - /// previous entry. - #[allow(clippy::needless_pass_by_value)] - pub fn update( - &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, - data: &serde_json::Value, - ) -> Result<()> { - self.db.update(room_id, user_id, &event_type, data) - } +/// Places one event in the account data of the user and removes the +/// previous entry. +#[allow(clippy::needless_pass_by_value)] +#[implement(Service)] +pub async fn update( + &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, data: &serde_json::Value, +) -> Result<()> { + let event_type = event_type.to_string(); + let count = self.services.globals.next_count()?; + + let mut prefix = room_id + .map(ToString::to_string) + .unwrap_or_default() + .as_bytes() + .to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(user_id.as_bytes()); + prefix.push(0xFF); - /// Searches the account data for a specific kind. 
- #[allow(clippy::needless_pass_by_value)] - pub fn get( - &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, - ) -> Result>> { - self.db.get(room_id, user_id, &event_type) + let mut roomuserdataid = prefix.clone(); + roomuserdataid.extend_from_slice(&count.to_be_bytes()); + roomuserdataid.push(0xFF); + roomuserdataid.extend_from_slice(event_type.as_bytes()); + + let mut key = prefix; + key.extend_from_slice(event_type.as_bytes()); + + if data.get("type").is_none() || data.get("content").is_none() { + return Err!(Request(InvalidParam("Account data doesn't have all required fields."))); } - /// Returns all changes to the account data that happened after `since`. - #[tracing::instrument(skip_all, name = "since", level = "debug")] - pub fn changes_since( - &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, - ) -> Result> { - self.db.changes_since(room_id, user_id, since) + self.db.roomuserdataid_accountdata.insert( + &roomuserdataid, + &serde_json::to_vec(&data).expect("to_vec always works on json values"), + ); + + let prev_key = (room_id, user_id, &event_type); + let prev = self.db.roomusertype_roomuserdataid.qry(&prev_key).await; + + self.db + .roomusertype_roomuserdataid + .insert(&key, &roomuserdataid); + + // Remove old entry + if let Ok(prev) = prev { + self.db.roomuserdataid_accountdata.remove(&prev); } + + Ok(()) +} + +/// Searches the account data for a specific kind. +#[implement(Service)] +pub async fn get( + &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType, +) -> Result> { + let key = (room_id, user_id, kind.to_string()); + self.db + .roomusertype_roomuserdataid + .qry(&key) + .and_then(|roomuserdataid| self.db.roomuserdataid_accountdata.qry(&roomuserdataid)) + .await + .deserialized_json() +} + +/// Returns all changes to the account data that happened after `since`. 
+#[implement(Service)] +pub async fn changes_since( + &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, +) -> Result> { + let mut userdata = HashMap::new(); + + let mut prefix = room_id + .map(ToString::to_string) + .unwrap_or_default() + .as_bytes() + .to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(user_id.as_bytes()); + prefix.push(0xFF); + + // Skip the data that's exactly at since, because we sent that last time + let mut first_possible = prefix.clone(); + first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); + + self.db + .roomuserdataid_accountdata + .raw_stream_from(&first_possible) + .ignore_err() + .ready_take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(k, v)| { + let v = match room_id { + None => serde_json::from_slice::>(v) + .map(AnyRawAccountDataEvent::Global) + .map_err(|_| Error::bad_database("Database contains invalid account data."))?, + Some(_) => serde_json::from_slice::>(v) + .map(AnyRawAccountDataEvent::Room) + .map_err(|_| Error::bad_database("Database contains invalid account data."))?, + }; + + Ok((k.to_owned(), v)) + }) + .ignore_err() + .ready_for_each(|(kind, data)| { + userdata.insert(kind, data); + }) + .await; + + Ok(userdata.into_values().collect()) } diff --git a/src/service/admin/console.rs b/src/service/admin/console.rs index 55bae3658..0f5016e15 100644 --- a/src/service/admin/console.rs +++ b/src/service/admin/console.rs @@ -5,7 +5,7 @@ use std::{ }; use conduit::{debug, defer, error, log, Server}; -use futures_util::future::{AbortHandle, Abortable}; +use futures::future::{AbortHandle, Abortable}; use ruma::events::room::message::RoomMessageEventContent; use rustyline_async::{Readline, ReadlineError, ReadlineEvent}; use termimad::MadSkin; diff --git a/src/service/admin/create.rs b/src/service/admin/create.rs index 4e2b831c5..7b090aa0b 100644 --- a/src/service/admin/create.rs +++ b/src/service/admin/create.rs @@ -30,7 +30,7 @@ use crate::Services; pub async fn 
create_admin_room(services: &Services) -> Result<()> { let room_id = RoomId::new(services.globals.server_name()); - let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id)?; + let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id); let state_lock = services.rooms.state.mutex.lock(&room_id).await; diff --git a/src/service/admin/grant.rs b/src/service/admin/grant.rs index b4589ebc8..4b3ebb887 100644 --- a/src/service/admin/grant.rs +++ b/src/service/admin/grant.rs @@ -17,108 +17,108 @@ use serde_json::value::to_raw_value; use crate::pdu::PduBuilder; -impl super::Service { - /// Invite the user to the conduit admin room. - /// - /// In conduit, this is equivalent to granting admin privileges. - pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> { - let Some(room_id) = self.get_admin_room()? else { - return Ok(()); - }; +/// Invite the user to the conduit admin room. +/// +/// In conduit, this is equivalent to granting admin privileges. +#[implement(super::Service)] +pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> { + let Ok(room_id) = self.get_admin_room().await else { + return Ok(()); + }; - let state_lock = self.services.state.mutex.lock(&room_id).await; + let state_lock = self.services.state.mutex.lock(&room_id).await; - // Use the server user to grant the new admin's power level - let server_user = &self.services.globals.server_user; + // Use the server user to grant the new admin's power level + let server_user = &self.services.globals.server_user; - // Invite and join the real user - self.services - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Invite, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created 
it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - server_user, - &room_id, - &state_lock, - ) - .await?; - self.services - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - user_id, - &room_id, - &state_lock, - ) - .await?; + // Invite and join the real user + self.services + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Invite, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + timestamp: None, + }, + server_user, + &room_id, + &state_lock, + ) + .await?; + self.services + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Join, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + timestamp: None, + }, + user_id, + &room_id, + &state_lock, + ) + .await?; - // Set power level - let users = BTreeMap::from_iter([(server_user.clone(), 
100.into()), (user_id.to_owned(), 100.into())]); + // Set power level + let users = BTreeMap::from_iter([(server_user.clone(), 100.into()), (user_id.to_owned(), 100.into())]); - self.services - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&RoomPowerLevelsEventContent { - users, - ..Default::default() - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, - server_user, - &room_id, - &state_lock, - ) - .await?; + self.services + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomPowerLevels, + content: to_raw_value(&RoomPowerLevelsEventContent { + users, + ..Default::default() + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(String::new()), + redacts: None, + timestamp: None, + }, + server_user, + &room_id, + &state_lock, + ) + .await?; - // Set room tag - let room_tag = &self.services.server.config.admin_room_tag; - if !room_tag.is_empty() { - if let Err(e) = self.set_room_tag(&room_id, user_id, room_tag) { - error!(?room_id, ?user_id, ?room_tag, ?e, "Failed to set tag for admin grant"); - } + // Set room tag + let room_tag = &self.services.server.config.admin_room_tag; + if !room_tag.is_empty() { + if let Err(e) = self.set_room_tag(&room_id, user_id, room_tag).await { + error!(?room_id, ?user_id, ?room_tag, ?e, "Failed to set tag for admin grant"); } + } - // Send welcome message - self.services.timeline.build_and_append_pdu( + // Send welcome message + self.services.timeline.build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&RoomMessageEventContent::text_markdown( @@ -135,19 +135,18 @@ impl super::Service { &state_lock, ).await?; - Ok(()) - } + Ok(()) } #[implement(super::Service)] -fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> Result<()> { 
+async fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> Result<()> { let mut event = self .services .account_data - .get(Some(room_id), user_id, RoomAccountDataEventType::Tag)? - .map(|event| serde_json::from_str(event.get())) - .and_then(Result::ok) - .unwrap_or_else(|| TagEvent { + .get(Some(room_id), user_id, RoomAccountDataEventType::Tag) + .await + .and_then(|event| serde_json::from_str(event.get()).map_err(Into::into)) + .unwrap_or_else(|_| TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, @@ -158,12 +157,15 @@ fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> Result< .tags .insert(tag.to_owned().into(), TagInfo::new()); - self.services.account_data.update( - Some(room_id), - user_id, - RoomAccountDataEventType::Tag, - &serde_json::to_value(event)?, - )?; + self.services + .account_data + .update( + Some(room_id), + user_id, + RoomAccountDataEventType::Tag, + &serde_json::to_value(event)?, + ) + .await?; Ok(()) } diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 3274249e6..12eacc8fa 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -12,6 +12,7 @@ use std::{ use async_trait::async_trait; use conduit::{debug, err, error, error::default_log, pdu::PduBuilder, Error, PduEvent, Result, Server}; pub use create::create_admin_room; +use futures::{FutureExt, TryFutureExt}; use loole::{Receiver, Sender}; use ruma::{ events::{ @@ -142,17 +143,18 @@ impl Service { /// admin room as the admin user. pub async fn send_text(&self, body: &str) { self.send_message(RoomMessageEventContent::text_markdown(body)) - .await; + .await + .ok(); } /// Sends a message to the admin room as the admin user (see send_text() for /// convenience). 
- pub async fn send_message(&self, message_content: RoomMessageEventContent) { - if let Ok(Some(room_id)) = self.get_admin_room() { - let user_id = &self.services.globals.server_user; - self.respond_to_room(message_content, &room_id, user_id) - .await; - } + pub async fn send_message(&self, message_content: RoomMessageEventContent) -> Result<()> { + let user_id = &self.services.globals.server_user; + let room_id = self.get_admin_room().await?; + self.respond_to_room(message_content, &room_id, user_id) + .boxed() + .await } /// Posts a command to the command processor queue and returns. Processing @@ -193,8 +195,12 @@ impl Service { async fn handle_command(&self, command: CommandInput) { match self.process_command(command).await { - Ok(Some(output)) | Err(output) => self.handle_response(output).await, Ok(None) => debug!("Command successful with no response"), + Ok(Some(output)) | Err(output) => self + .handle_response(output) + .boxed() + .await + .unwrap_or_else(default_log), } } @@ -218,71 +224,67 @@ impl Service { } /// Checks whether a given user is an admin of this server - pub async fn user_is_admin(&self, user_id: &UserId) -> Result { - if let Ok(Some(admin_room)) = self.get_admin_room() { - self.services.state_cache.is_joined(user_id, &admin_room) - } else { - Ok(false) - } + pub async fn user_is_admin(&self, user_id: &UserId) -> bool { + let Ok(admin_room) = self.get_admin_room().await else { + return false; + }; + + self.services + .state_cache + .is_joined(user_id, &admin_room) + .await } /// Gets the room ID of the admin room /// /// Errors are propagated from the database, and will have None if there is /// no admin room - pub fn get_admin_room(&self) -> Result> { - if let Some(room_id) = self + pub async fn get_admin_room(&self) -> Result { + let room_id = self .services .alias - .resolve_local_alias(&self.services.globals.admin_alias)? - { - if self - .services - .state_cache - .is_joined(&self.services.globals.server_user, &room_id)? 
- { - return Ok(Some(room_id)); - } - } + .resolve_local_alias(&self.services.globals.admin_alias) + .await?; - Ok(None) + self.services + .state_cache + .is_joined(&self.services.globals.server_user, &room_id) + .await + .then_some(room_id) + .ok_or_else(|| err!(Request(NotFound("Admin user not joined to admin room")))) } - async fn handle_response(&self, content: RoomMessageEventContent) { + async fn handle_response(&self, content: RoomMessageEventContent) -> Result<()> { let Some(Relation::Reply { in_reply_to, }) = content.relates_to.as_ref() else { - return; + return Ok(()); }; - let Ok(Some(pdu)) = self.services.timeline.get_pdu(&in_reply_to.event_id) else { + let Ok(pdu) = self.services.timeline.get_pdu(&in_reply_to.event_id).await else { error!( event_id = ?in_reply_to.event_id, "Missing admin command in_reply_to event" ); - return; + return Ok(()); }; - let response_sender = if self.is_admin_room(&pdu.room_id) { + let response_sender = if self.is_admin_room(&pdu.room_id).await { &self.services.globals.server_user } else { &pdu.sender }; self.respond_to_room(content, &pdu.room_id, response_sender) - .await; + .await } - async fn respond_to_room(&self, content: RoomMessageEventContent, room_id: &RoomId, user_id: &UserId) { - assert!( - self.user_is_admin(user_id) - .await - .expect("checked user is admin"), - "sender is not admin" - ); + async fn respond_to_room( + &self, content: RoomMessageEventContent, room_id: &RoomId, user_id: &UserId, + ) -> Result<()> { + assert!(self.user_is_admin(user_id).await, "sender is not admin"); - let state_lock = self.services.state.mutex.lock(room_id).await; let response_pdu = PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&content).expect("event is valid, we just created it"), @@ -292,6 +294,7 @@ impl Service { timestamp: None, }; + let state_lock = self.services.state.mutex.lock(room_id).await; if let Err(e) = self .services .timeline @@ -302,6 +305,8 @@ impl Service { .await 
.unwrap_or_else(default_log); } + + Ok(()) } async fn handle_response_error( @@ -355,12 +360,12 @@ impl Service { } // Prevent unescaped !admin from being used outside of the admin room - if is_public_prefix && !self.is_admin_room(&pdu.room_id) { + if is_public_prefix && !self.is_admin_room(&pdu.room_id).await { return false; } // Only senders who are admin can proceed - if !self.user_is_admin(&pdu.sender).await.unwrap_or(false) { + if !self.user_is_admin(&pdu.sender).await { return false; } @@ -368,7 +373,7 @@ impl Service { // the administrator can execute commands as conduit let emergency_password_set = self.services.globals.emergency_password().is_some(); let from_server = pdu.sender == *server_user && !emergency_password_set; - if from_server && self.is_admin_room(&pdu.room_id) { + if from_server && self.is_admin_room(&pdu.room_id).await { return false; } @@ -377,12 +382,11 @@ impl Service { } #[must_use] - pub fn is_admin_room(&self, room_id: &RoomId) -> bool { - if let Ok(Some(admin_room_id)) = self.get_admin_room() { - admin_room_id == room_id - } else { - false - } + pub async fn is_admin_room(&self, room_id_: &RoomId) -> bool { + self.get_admin_room() + .map_ok(|room_id| room_id == room_id_) + .await + .unwrap_or(false) } /// Sets the self-reference to crate::Services which will provide context to diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index 40e641a1e..d5fa5476f 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -1,7 +1,8 @@ use std::sync::Arc; -use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use conduit::{err, utils::stream::TryIgnore, Result}; +use database::{Database, Deserialized, Map}; +use futures::Stream; use ruma::api::appservice::Registration; pub struct Data { @@ -19,7 +20,7 @@ impl Data { pub(super) fn register_appservice(&self, yaml: &Registration) -> Result { let id = yaml.id.as_str(); self.id_appserviceregistrations - .insert(id.as_bytes(), 
serde_yaml::to_string(&yaml).unwrap().as_bytes())?; + .insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes()); Ok(id.to_owned()) } @@ -31,24 +32,19 @@ impl Data { /// * `service_name` - the name you send to register the service previously pub(super) fn unregister_appservice(&self, service_name: &str) -> Result<()> { self.id_appserviceregistrations - .remove(service_name.as_bytes())?; + .remove(service_name.as_bytes()); Ok(()) } - pub fn get_registration(&self, id: &str) -> Result> { + pub async fn get_registration(&self, id: &str) -> Result { self.id_appserviceregistrations - .get(id.as_bytes())? - .map(|bytes| { - serde_yaml::from_slice(&bytes) - .map_err(|_| Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")) - }) - .transpose() + .qry(id) + .await + .deserialized_json() + .map_err(|e| err!(Database("Invalid appservice {id:?} registration: {e:?}"))) } - pub(super) fn iter_ids<'a>(&'a self) -> Result> + 'a>> { - Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| { - utils::string_from_bytes(&id) - .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) - }))) + pub(super) fn iter_ids(&self) -> impl Stream + Send + '_ { + self.id_appserviceregistrations.keys().ignore_err() } } diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index c0752d565..7e2dc7387 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -2,9 +2,10 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; +use async_trait::async_trait; use conduit::{err, Result}; use data::Data; -use futures_util::Future; +use futures::{Future, StreamExt, TryStreamExt}; use regex::RegexSet; use ruma::{ api::appservice::{Namespace, Registration}, @@ -126,13 +127,22 @@ struct Services { sending: Dep, } +#[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { - let mut registration_info = BTreeMap::new(); - let db = Data::new(args.db); + 
Ok(Arc::new(Self { + db: Data::new(args.db), + services: Services { + sending: args.depend::("sending"), + }, + registration_info: RwLock::new(BTreeMap::new()), + })) + } + + async fn worker(self: Arc) -> Result<()> { // Inserting registrations into cache - for appservice in iter_ids(&db)? { - registration_info.insert( + for appservice in iter_ids(&self.db).await? { + self.registration_info.write().await.insert( appservice.0, appservice .1 @@ -141,13 +151,7 @@ impl crate::Service for Service { ); } - Ok(Arc::new(Self { - db, - services: Services { - sending: args.depend::("sending"), - }, - registration_info: RwLock::new(registration_info), - })) + Ok(()) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } @@ -155,7 +159,7 @@ impl crate::Service for Service { impl Service { #[inline] - pub fn all(&self) -> Result> { iter_ids(&self.db) } + pub async fn all(&self) -> Result> { iter_ids(&self.db).await } /// Registers an appservice and returns the ID to the caller pub async fn register_appservice(&self, yaml: Registration) -> Result { @@ -188,7 +192,8 @@ impl Service { // sending to the URL self.services .sending - .cleanup_events(service_name.to_owned())?; + .cleanup_events(service_name.to_owned()) + .await; Ok(()) } @@ -251,15 +256,9 @@ impl Service { } } -fn iter_ids(db: &Data) -> Result> { - db.iter_ids()? - .filter_map(Result::ok) - .map(move |id| { - Ok(( - id.clone(), - db.get_registration(&id)? 
- .expect("iter_ids only returns appservices that exist"), - )) - }) - .collect() +async fn iter_ids(db: &Data) -> Result> { + db.iter_ids() + .then(|id| async move { Ok((id.clone(), db.get_registration(&id).await?)) }) + .try_collect() + .await } diff --git a/src/service/emergency/mod.rs b/src/service/emergency/mod.rs index 1bb0843d4..98020bc29 100644 --- a/src/service/emergency/mod.rs +++ b/src/service/emergency/mod.rs @@ -33,6 +33,7 @@ impl crate::Service for Service { async fn worker(self: Arc) -> Result<()> { self.set_emergency_access() + .await .inspect_err(|e| error!("Could not set the configured emergency password for the conduit user: {e}"))?; Ok(()) @@ -44,7 +45,7 @@ impl crate::Service for Service { impl Service { /// Sets the emergency password and push rules for the @conduit account in /// case emergency password is set - fn set_emergency_access(&self) -> Result { + async fn set_emergency_access(&self) -> Result { let conduit_user = &self.services.globals.server_user; self.services @@ -56,17 +57,20 @@ impl Service { None => (Ruleset::new(), false), }; - self.services.account_data.update( - None, - conduit_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(&GlobalAccountDataEvent { - content: PushRulesEventContent { - global: ruleset, - }, - }) - .expect("to json value always works"), - )?; + self.services + .account_data + .update( + None, + conduit_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(&GlobalAccountDataEvent { + content: PushRulesEventContent { + global: ruleset, + }, + }) + .expect("to json value always works"), + ) + .await?; if pwd_set { warn!( @@ -75,7 +79,7 @@ impl Service { ); } else { // logs out any users still in the server service account and removes sessions - self.services.users.deactivate_account(conduit_user)?; + self.services.users.deactivate_account(conduit_user).await?; } Ok(pwd_set) diff --git a/src/service/globals/data.rs 
b/src/service/globals/data.rs index 5b5d9f09d..3286e40c5 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -4,8 +4,8 @@ use std::{ }; use conduit::{trace, utils, Error, Result, Server}; -use database::{Database, Map}; -use futures_util::{stream::FuturesUnordered, StreamExt}; +use database::{Database, Deserialized, Map}; +use futures::{pin_mut, stream::FuturesUnordered, FutureExt, StreamExt}; use ruma::{ api::federation::discovery::{ServerSigningKeys, VerifyKey}, signatures::Ed25519KeyPair, @@ -83,7 +83,7 @@ impl Data { .checked_add(1) .expect("counter must not overflow u64"); - self.global.insert(COUNTER, &counter.to_be_bytes())?; + self.global.insert(COUNTER, &counter.to_be_bytes()); Ok(*counter) } @@ -102,7 +102,7 @@ impl Data { fn stored_count(global: &Arc) -> Result { global - .get(COUNTER)? + .get(COUNTER) .as_deref() .map_or(Ok(0_u64), utils::u64_from_bytes) } @@ -133,36 +133,18 @@ impl Data { futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); // Events for rooms we are in - for room_id in self - .services - .state_cache - .rooms_joined(user_id) - .filter_map(Result::ok) - { - let short_roomid = self - .services - .short - .get_shortroomid(&room_id) - .ok() - .flatten() - .expect("room exists") - .to_be_bytes() - .to_vec(); + let rooms_joined = self.services.state_cache.rooms_joined(user_id); + + pin_mut!(rooms_joined); + while let Some(room_id) = rooms_joined.next().await { + let Ok(short_roomid) = self.services.short.get_shortroomid(room_id).await else { + continue; + }; let roomid_bytes = room_id.as_bytes().to_vec(); let mut roomid_prefix = roomid_bytes.clone(); roomid_prefix.push(0xFF); - // PDUs - futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); - - // EDUs - futures.push(Box::pin(async move { - let _result = self.services.typing.wait_for_update(&room_id).await; - })); - - futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); - // Key changes 
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); @@ -174,6 +156,19 @@ impl Data { self.roomusertype_roomuserdataid .watch_prefix(&roomuser_prefix), ); + + // PDUs + let short_roomid = short_roomid.to_be_bytes().to_vec(); + futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); + + // EDUs + let typing_room_id = room_id.to_owned(); + let typing_wait_for_update = async move { + self.services.typing.wait_for_update(&typing_room_id).await; + }; + + futures.push(typing_wait_for_update.boxed()); + futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); } let mut globaluserdata_prefix = vec![0xFF]; @@ -190,12 +185,14 @@ impl Data { // One time keys futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); - futures.push(Box::pin(async move { + // Server shutdown + let server_shutdown = async move { while self.services.server.running() { - let _result = self.services.server.signal.subscribe().recv().await; + self.services.server.signal.subscribe().recv().await.ok(); } - })); + }; + futures.push(server_shutdown.boxed()); if !self.services.server.running() { return Ok(()); } @@ -209,10 +206,10 @@ impl Data { } pub fn load_keypair(&self) -> Result { - let keypair_bytes = self.global.get(b"keypair")?.map_or_else( - || { + let keypair_bytes = self.global.get(b"keypair").map_or_else( + |_| { let keypair = utils::generate_keypair(); - self.global.insert(b"keypair", &keypair)?; + self.global.insert(b"keypair", &keypair); Ok::<_, Error>(keypair) }, |val| Ok(val.to_vec()), @@ -241,7 +238,10 @@ impl Data { } #[inline] - pub fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") } + pub fn remove_keypair(&self) -> Result<()> { + self.global.remove(b"keypair"); + Ok(()) + } /// TODO: the key valid until timestamp (`valid_until_ts`) is only honored /// in room version > 4 @@ -250,15 +250,15 @@ impl Data { /// /// This doesn't actually check that the keys provided are newer than the /// old set. 
- pub fn add_signing_key( + pub async fn add_signing_key( &self, origin: &ServerName, new_keys: ServerSigningKeys, - ) -> Result> { + ) -> BTreeMap { // Not atomic, but this is not critical - let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; + let signingkeys = self.server_signingkeys.qry(origin).await; let mut keys = signingkeys - .and_then(|keys| serde_json::from_slice(&keys).ok()) - .unwrap_or_else(|| { + .and_then(|keys| serde_json::from_slice(&keys).map_err(Into::into)) + .unwrap_or_else(|_| { // Just insert "now", it doesn't matter ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) }); @@ -275,7 +275,7 @@ impl Data { self.server_signingkeys.insert( origin.as_bytes(), &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), - )?; + ); let mut tree = keys.verify_keys; tree.extend( @@ -284,45 +284,38 @@ impl Data { .map(|old| (old.0, VerifyKey::new(old.1.key))), ); - Ok(tree) + tree } /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found /// for the server. - pub fn verify_keys_for(&self, origin: &ServerName) -> Result> { - let signingkeys = self - .signing_keys_for(origin)? - .map_or_else(BTreeMap::new, |keys: ServerSigningKeys| { + pub async fn verify_keys_for(&self, origin: &ServerName) -> Result> { + self.signing_keys_for(origin).await.map_or_else( + |_| Ok(BTreeMap::new()), + |keys: ServerSigningKeys| { let mut tree = keys.verify_keys; tree.extend( keys.old_verify_keys .into_iter() .map(|old| (old.0, VerifyKey::new(old.1.key))), ); - tree - }); - - Ok(signingkeys) + Ok(tree) + }, + ) } - pub fn signing_keys_for(&self, origin: &ServerName) -> Result> { - let signingkeys = self - .server_signingkeys - .get(origin.as_bytes())? 
- .and_then(|bytes| serde_json::from_slice(&bytes).ok()); - - Ok(signingkeys) + pub async fn signing_keys_for(&self, origin: &ServerName) -> Result { + self.server_signingkeys + .qry(origin) + .await + .deserialized_json() } - pub fn database_version(&self) -> Result { - self.global.get(b"version")?.map_or(Ok(0), |version| { - utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid.")) - }) - } + pub async fn database_version(&self) -> u64 { self.global.qry("version").await.deserialized().unwrap_or(0) } #[inline] pub fn bump_database_version(&self, new_version: u64) -> Result<()> { - self.global.insert(b"version", &new_version.to_be_bytes())?; + self.global.insert(b"version", &new_version.to_be_bytes()); Ok(()) } diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs index 66917520b..c7a732309 100644 --- a/src/service/globals/migrations.rs +++ b/src/service/globals/migrations.rs @@ -1,17 +1,15 @@ -use std::{ - collections::{HashMap, HashSet}, - fs::{self}, - io::Write, - mem::size_of, - sync::Arc, +use conduit::{ + debug_info, debug_warn, error, info, + result::NotFound, + utils::{stream::TryIgnore, IterStream, ReadyExt}, + warn, Err, Error, Result, }; - -use conduit::{debug, debug_info, debug_warn, error, info, utils, warn, Error, Result}; +use futures::{FutureExt, StreamExt}; use itertools::Itertools; use ruma::{ events::{push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType}, push::Ruleset, - EventId, OwnedRoomId, RoomId, UserId, + UserId, }; use crate::{media, Services}; @@ -33,12 +31,14 @@ pub(crate) const DATABASE_VERSION: u64 = 13; pub(crate) const CONDUIT_DATABASE_VERSION: u64 = 16; pub(crate) async fn migrations(services: &Services) -> Result<()> { + let users_count = services.users.count().await; + // Matrix resource ownership is based on the server name; changing it // requires recreating the database from scratch. - if services.users.count()? 
> 0 { + if users_count > 0 { let conduit_user = &services.globals.server_user; - if !services.users.exists(conduit_user)? { + if !services.users.exists(conduit_user).await { error!("The {} server user does not exist, and the database is not new.", conduit_user); return Err(Error::bad_database( "Cannot reuse an existing database after changing the server name, please delete the old one first.", @@ -46,7 +46,7 @@ pub(crate) async fn migrations(services: &Services) -> Result<()> { } } - if services.users.count()? > 0 { + if users_count > 0 { migrate(services).await } else { fresh(services).await @@ -62,9 +62,9 @@ async fn fresh(services: &Services) -> Result<()> { .db .bump_database_version(DATABASE_VERSION)?; - db["global"].insert(b"feat_sha256_media", &[])?; - db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[])?; - db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[])?; + db["global"].insert(b"feat_sha256_media", &[]); + db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[]); + db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[]); // Create the admin room and server user on first run crate::admin::create_admin_room(services).await?; @@ -82,566 +82,132 @@ async fn migrate(services: &Services) -> Result<()> { let db = &services.db; let config = &services.server.config; - if services.globals.db.database_version()? < 1 { - db_lt_1(services).await?; - } - - if services.globals.db.database_version()? < 2 { - db_lt_2(services).await?; + if services.globals.db.database_version().await < 11 { + return Err!(Database( + "Database schema version {} is no longer supported", + services.globals.db.database_version().await + )); } - if services.globals.db.database_version()? < 3 { - db_lt_3(services).await?; - } - - if services.globals.db.database_version()? < 4 { - db_lt_4(services).await?; - } - - if services.globals.db.database_version()? 
< 5 { - db_lt_5(services).await?; - } - - if services.globals.db.database_version()? < 6 { - db_lt_6(services).await?; - } - - if services.globals.db.database_version()? < 7 { - db_lt_7(services).await?; - } - - if services.globals.db.database_version()? < 8 { - db_lt_8(services).await?; - } - - if services.globals.db.database_version()? < 9 { - db_lt_9(services).await?; - } - - if services.globals.db.database_version()? < 10 { - db_lt_10(services).await?; - } - - if services.globals.db.database_version()? < 11 { - db_lt_11(services).await?; - } - - if services.globals.db.database_version()? < 12 { + if services.globals.db.database_version().await < 12 { db_lt_12(services).await?; } // This migration can be reused as-is anytime the server-default rules are // updated. - if services.globals.db.database_version()? < 13 { + if services.globals.db.database_version().await < 13 { db_lt_13(services).await?; } - if db["global"].get(b"feat_sha256_media")?.is_none() { + if db["global"].qry("feat_sha256_media").await.is_not_found() { media::migrations::migrate_sha256_media(services).await?; } else if config.media_startup_check { media::migrations::checkup_sha256_media(services).await?; } if db["global"] - .get(b"fix_bad_double_separator_in_state_cache")? - .is_none() + .qry("fix_bad_double_separator_in_state_cache") + .await + .is_not_found() { fix_bad_double_separator_in_state_cache(services).await?; } if db["global"] - .get(b"retroactively_fix_bad_data_from_roomuserid_joined")? 
- .is_none() + .qry("retroactively_fix_bad_data_from_roomuserid_joined") + .await + .is_not_found() { retroactively_fix_bad_data_from_roomuserid_joined(services).await?; } - let version_match = services.globals.db.database_version().unwrap() == DATABASE_VERSION - || services.globals.db.database_version().unwrap() == CONDUIT_DATABASE_VERSION; + let version_match = services.globals.db.database_version().await == DATABASE_VERSION + || services.globals.db.database_version().await == CONDUIT_DATABASE_VERSION; assert!( version_match, "Failed asserting local database version {} is equal to known latest conduwuit database version {}", - services.globals.db.database_version().unwrap(), + services.globals.db.database_version().await, DATABASE_VERSION, ); { let patterns = services.globals.forbidden_usernames(); if !patterns.is_empty() { - for user_id in services + services .users - .iter() - .filter_map(Result::ok) - .filter(|user| !services.users.is_deactivated(user).unwrap_or(true)) - .filter(|user| user.server_name() == config.server_name) - { - let matches = patterns.matches(user_id.localpart()); - if matches.matched_any() { - warn!( - "User {} matches the following forbidden username patterns: {}", - user_id.to_string(), - matches - .into_iter() - .map(|x| &patterns.patterns()[x]) - .join(", ") - ); - } - } - } - } - - { - let patterns = services.globals.forbidden_alias_names(); - if !patterns.is_empty() { - for address in services.rooms.metadata.iter_ids() { - let room_id = address?; - let room_aliases = services.rooms.alias.local_aliases_for_room(&room_id); - for room_alias_result in room_aliases { - let room_alias = room_alias_result?; - let matches = patterns.matches(room_alias.alias()); + .stream() + .filter(|user_id| services.users.is_active_local(user_id)) + .ready_for_each(|user_id| { + let matches = patterns.matches(user_id.localpart()); if matches.matched_any() { warn!( - "Room with alias {} ({}) matches the following forbidden room name patterns: {}", - 
room_alias, - &room_id, + "User {} matches the following forbidden username patterns: {}", + user_id.to_string(), matches .into_iter() .map(|x| &patterns.patterns()[x]) .join(", ") ); } - } - } - } - } - - info!( - "Loaded {} database with schema version {DATABASE_VERSION}", - config.database_backend, - ); - - Ok(()) -} - -async fn db_lt_1(services: &Services) -> Result<()> { - let db = &services.db; - - let roomserverids = &db["roomserverids"]; - let serverroomids = &db["serverroomids"]; - for (roomserverid, _) in roomserverids.iter() { - let mut parts = roomserverid.split(|&b| b == 0xFF); - let room_id = parts.next().expect("split always returns one element"); - let Some(servername) = parts.next() else { - error!("Migration: Invalid roomserverid in db."); - continue; - }; - let mut serverroomid = servername.to_vec(); - serverroomid.push(0xFF); - serverroomid.extend_from_slice(room_id); - - serverroomids.insert(&serverroomid, &[])?; - } - - services.globals.db.bump_database_version(1)?; - info!("Migration: 0 -> 1 finished"); - Ok(()) -} - -async fn db_lt_2(services: &Services) -> Result<()> { - let db = &services.db; - - // We accidentally inserted hashed versions of "" into the db instead of just "" - let userid_password = &db["roomserverids"]; - for (userid, password) in userid_password.iter() { - let empty_pass = utils::hash::password("").expect("our own password to be properly hashed"); - let password = std::str::from_utf8(&password).expect("password is valid utf-8"); - let empty_hashed_password = utils::hash::verify_password(password, &empty_pass).is_ok(); - if empty_hashed_password { - userid_password.insert(&userid, b"")?; + }) + .await; } } - services.globals.db.bump_database_version(2)?; - info!("Migration: 1 -> 2 finished"); - Ok(()) -} - -async fn db_lt_3(services: &Services) -> Result<()> { - let db = &services.db; - - // Move media to filesystem - let mediaid_file = &db["mediaid_file"]; - for (key, content) in mediaid_file.iter() { - if 
content.is_empty() { - continue; - } - - #[allow(deprecated)] - let path = services.media.get_media_file(&key); - let mut file = fs::File::create(path)?; - file.write_all(&content)?; - mediaid_file.insert(&key, &[])?; - } - - services.globals.db.bump_database_version(3)?; - info!("Migration: 2 -> 3 finished"); - Ok(()) -} - -async fn db_lt_4(services: &Services) -> Result<()> { - let config = &services.server.config; - - // Add federated users to services as deactivated - for our_user in services.users.iter() { - let our_user = our_user?; - if services.users.is_deactivated(&our_user)? { - continue; - } - for room in services.rooms.state_cache.rooms_joined(&our_user) { - for user in services.rooms.state_cache.room_members(&room?) { - let user = user?; - if user.server_name() != config.server_name { - info!(?user, "Migration: creating user"); - services.users.create(&user, None)?; - } - } - } - } - - services.globals.db.bump_database_version(4)?; - info!("Migration: 3 -> 4 finished"); - Ok(()) -} - -async fn db_lt_5(services: &Services) -> Result<()> { - let db = &services.db; - - // Upgrade user data store - let roomuserdataid_accountdata = &db["roomuserdataid_accountdata"]; - let roomusertype_roomuserdataid = &db["roomusertype_roomuserdataid"]; - for (roomuserdataid, _) in roomuserdataid_accountdata.iter() { - let mut parts = roomuserdataid.split(|&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let user_id = parts.next().unwrap(); - let event_type = roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap(); - - let mut key = room_id.to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id); - key.push(0xFF); - key.extend_from_slice(event_type); - - roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?; - } - - services.globals.db.bump_database_version(5)?; - info!("Migration: 4 -> 5 finished"); - Ok(()) -} - -async fn db_lt_6(services: &Services) -> Result<()> { - let db = &services.db; - - // Set room member count - let roomid_shortstatehash = 
&db["roomid_shortstatehash"]; - for (roomid, _) in roomid_shortstatehash.iter() { - let string = utils::string_from_bytes(&roomid).unwrap(); - let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); - services.rooms.state_cache.update_joined_count(room_id)?; - } - - services.globals.db.bump_database_version(6)?; - info!("Migration: 5 -> 6 finished"); - Ok(()) -} - -async fn db_lt_7(services: &Services) -> Result<()> { - let db = &services.db; - - // Upgrade state store - let mut last_roomstates: HashMap = HashMap::new(); - let mut current_sstatehash: Option = None; - let mut current_room = None; - let mut current_state = HashSet::new(); - - let handle_state = |current_sstatehash: u64, - current_room: &RoomId, - current_state: HashSet<_>, - last_roomstates: &mut HashMap<_, _>| { - let last_roomsstatehash = last_roomstates.get(current_room); - - let states_parents = last_roomsstatehash.map_or_else( - || Ok(Vec::new()), - |&last_roomsstatehash| { + { + let patterns = services.globals.forbidden_alias_names(); + if !patterns.is_empty() { + for room_id in services + .rooms + .metadata + .iter_ids() + .map(ToOwned::to_owned) + .collect::>() + .await + { services .rooms - .state_compressor - .load_shortstatehash_info(last_roomsstatehash) - }, - )?; - - let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew = current_state - .difference(&parent_stateinfo.1) - .copied() - .collect::>(); - - let statediffremoved = parent_stateinfo - .1 - .difference(¤t_state) - .copied() - .collect::>(); - - (statediffnew, statediffremoved) - } else { - (current_state, HashSet::new()) - }; - - services.rooms.state_compressor.save_state_from_diff( - current_sstatehash, - Arc::new(statediffnew), - Arc::new(statediffremoved), - 2, // every state change is 2 event changes on average - states_parents, - )?; - - /* - let mut tmp = services.rooms.load_shortstatehash_info(¤t_sstatehash)?; - let state = tmp.pop().unwrap(); - println!( - 
"{}\t{}{:?}: {:?} + {:?} - {:?}", - current_room, - " ".repeat(tmp.len()), - utils::u64_from_bytes(¤t_sstatehash).unwrap(), - tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), - state - .2 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) - .collect::>(), - state - .3 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) - .collect::>() - ); - */ - - Ok::<_, Error>(()) - }; - - let stateid_shorteventid = &db["stateid_shorteventid"]; - let shorteventid_eventid = &db["shorteventid_eventid"]; - for (k, seventid) in stateid_shorteventid.iter() { - let sstatehash = utils::u64_from_bytes(&k[0..size_of::()]).expect("number of bytes is correct"); - let sstatekey = k[size_of::()..].to_vec(); - if Some(sstatehash) != current_sstatehash { - if let Some(current_sstatehash) = current_sstatehash { - handle_state( - current_sstatehash, - current_room.as_deref().unwrap(), - current_state, - &mut last_roomstates, - )?; - last_roomstates.insert(current_room.clone().unwrap(), current_sstatehash); - } - current_state = HashSet::new(); - current_sstatehash = Some(sstatehash); - - let event_id = shorteventid_eventid.get(&seventid).unwrap().unwrap(); - let string = utils::string_from_bytes(&event_id).unwrap(); - let event_id = <&EventId>::try_from(string.as_str()).unwrap(); - let pdu = services.rooms.timeline.get_pdu(event_id).unwrap().unwrap(); - - if Some(&pdu.room_id) != current_room.as_ref() { - current_room = Some(pdu.room_id.clone()); + .alias + .local_aliases_for_room(&room_id) + .ready_for_each(|room_alias| { + let matches = patterns.matches(room_alias.alias()); + if matches.matched_any() { + warn!( + "Room with alias {} ({}) matches the following forbidden room name patterns: {}", + room_alias, + &room_id, + matches + .into_iter() + .map(|x| &patterns.patterns()[x]) + .join(", ") + ); + } + }) + .await; } } - - let mut val = sstatekey; - val.extend_from_slice(&seventid); - current_state.insert(val.try_into().expect("size is 
correct")); - } - - if let Some(current_sstatehash) = current_sstatehash { - handle_state( - current_sstatehash, - current_room.as_deref().unwrap(), - current_state, - &mut last_roomstates, - )?; - } - - services.globals.db.bump_database_version(7)?; - info!("Migration: 6 -> 7 finished"); - Ok(()) -} - -async fn db_lt_8(services: &Services) -> Result<()> { - let db = &services.db; - - let roomid_shortstatehash = &db["roomid_shortstatehash"]; - let roomid_shortroomid = &db["roomid_shortroomid"]; - let pduid_pdu = &db["pduid_pdu"]; - let eventid_pduid = &db["eventid_pduid"]; - - // Generate short room ids for all rooms - for (room_id, _) in roomid_shortstatehash.iter() { - let shortroomid = services.globals.next_count()?.to_be_bytes(); - roomid_shortroomid.insert(&room_id, &shortroomid)?; - info!("Migration: 8"); - } - // Update pduids db layout - let batch = pduid_pdu - .iter() - .filter_map(|(key, v)| { - if !key.starts_with(b"!") { - return None; - } - let mut parts = key.splitn(2, |&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let count = parts.next().unwrap(); - - let short_room_id = roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - - let mut new_key = short_room_id.to_vec(); - new_key.extend_from_slice(count); - - Some(database::OwnedKeyVal(new_key, v)) - }) - .collect::>(); - - pduid_pdu.insert_batch(batch.iter().map(database::KeyVal::from))?; - - let batch2 = eventid_pduid - .iter() - .filter_map(|(k, value)| { - if !value.starts_with(b"!") { - return None; - } - let mut parts = value.splitn(2, |&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let count = parts.next().unwrap(); - - let short_room_id = roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - - let mut new_value = short_room_id.to_vec(); - new_value.extend_from_slice(count); - - Some(database::OwnedKeyVal(k, new_value)) - }) - .collect::>(); - - 
eventid_pduid.insert_batch(batch2.iter().map(database::KeyVal::from))?; - - services.globals.db.bump_database_version(8)?; - info!("Migration: 7 -> 8 finished"); - Ok(()) -} - -async fn db_lt_9(services: &Services) -> Result<()> { - let db = &services.db; - - let tokenids = &db["tokenids"]; - let roomid_shortroomid = &db["roomid_shortroomid"]; - - // Update tokenids db layout - let mut iter = tokenids - .iter() - .filter_map(|(key, _)| { - if !key.starts_with(b"!") { - return None; - } - let mut parts = key.splitn(4, |&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let word = parts.next().unwrap(); - let _pdu_id_room = parts.next().unwrap(); - let pdu_id_count = parts.next().unwrap(); - - let short_room_id = roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - let mut new_key = short_room_id.to_vec(); - new_key.extend_from_slice(word); - new_key.push(0xFF); - new_key.extend_from_slice(pdu_id_count); - Some(database::OwnedKeyVal(new_key, Vec::::new())) - }) - .peekable(); - - while iter.peek().is_some() { - let batch = iter.by_ref().take(1000).collect::>(); - tokenids.insert_batch(batch.iter().map(database::KeyVal::from))?; - debug!("Inserted smaller batch"); } - info!("Deleting starts"); - - let batch2: Vec<_> = tokenids - .iter() - .filter_map(|(key, _)| { - if key.starts_with(b"!") { - Some(key) - } else { - None - } - }) - .collect(); - - for key in batch2 { - tokenids.remove(&key)?; - } - - services.globals.db.bump_database_version(9)?; - info!("Migration: 8 -> 9 finished"); - Ok(()) -} - -async fn db_lt_10(services: &Services) -> Result<()> { - let db = &services.db; - - let statekey_shortstatekey = &db["statekey_shortstatekey"]; - let shortstatekey_statekey = &db["shortstatekey_statekey"]; - - // Add other direction for shortstatekeys - for (statekey, shortstatekey) in statekey_shortstatekey.iter() { - shortstatekey_statekey.insert(&shortstatekey, &statekey)?; - } - - // Force E2EE device list updates so we can send 
them over federation - for user_id in services.users.iter().filter_map(Result::ok) { - services.users.mark_device_key_update(&user_id)?; - } - - services.globals.db.bump_database_version(10)?; - info!("Migration: 9 -> 10 finished"); - Ok(()) -} - -#[allow(unreachable_code)] -async fn db_lt_11(services: &Services) -> Result<()> { - error!("Dropping a column to clear data is not implemented yet."); - //let userdevicesessionid_uiaarequest = &db["userdevicesessionid_uiaarequest"]; - //userdevicesessionid_uiaarequest.clear()?; + info!( + "Loaded {} database with schema version {DATABASE_VERSION}", + config.database_backend, + ); - services.globals.db.bump_database_version(11)?; - info!("Migration: 10 -> 11 finished"); Ok(()) } async fn db_lt_12(services: &Services) -> Result<()> { let config = &services.server.config; - for username in services.users.list_local_users()? { - let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) { + for username in &services + .users + .list_local_users() + .map(UserId::to_owned) + .collect::>() + .await + { + let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) { Ok(u) => u, Err(e) => { warn!("Invalid username {username}: {e}"); @@ -652,7 +218,7 @@ async fn db_lt_12(services: &Services) -> Result<()> { let raw_rules_list = services .account_data .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) - .unwrap() + .await .expect("Username is invalid"); let mut account_data = serde_json::from_str::(raw_rules_list.get()).unwrap(); @@ -694,12 +260,15 @@ async fn db_lt_12(services: &Services) -> Result<()> { } } - services.account_data.update( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + 
&serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; } services.globals.db.bump_database_version(12)?; @@ -710,8 +279,14 @@ async fn db_lt_12(services: &Services) -> Result<()> { async fn db_lt_13(services: &Services) -> Result<()> { let config = &services.server.config; - for username in services.users.list_local_users()? { - let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) { + for username in &services + .users + .list_local_users() + .map(UserId::to_owned) + .collect::>() + .await + { + let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) { Ok(u) => u, Err(e) => { warn!("Invalid username {username}: {e}"); @@ -722,7 +297,7 @@ async fn db_lt_13(services: &Services) -> Result<()> { let raw_rules_list = services .account_data .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) - .unwrap() + .await .expect("Username is invalid"); let mut account_data = serde_json::from_str::(raw_rules_list.get()).unwrap(); @@ -733,12 +308,15 @@ async fn db_lt_13(services: &Services) -> Result<()> { .global .update_with_server_default(user_default_rules); - services.account_data.update( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; } services.globals.db.bump_database_version(13)?; @@ -754,32 +332,37 @@ async fn fix_bad_double_separator_in_state_cache(services: &Services) -> Result< let _cork = db.cork_and_sync(); let mut iter_count: usize = 0; - for (mut key, value) in roomuserid_joined.iter() { - iter_count = iter_count.saturating_add(1); - debug_info!(%iter_count); - let first_sep_index = key - .iter() - .position(|&i| i == 0xFF) - 
.expect("found 0xFF delim"); + roomuserid_joined + .raw_stream() + .ignore_err() + .ready_for_each(|(key, value)| { + let mut key = key.to_vec(); + iter_count = iter_count.saturating_add(1); + debug_info!(%iter_count); + let first_sep_index = key + .iter() + .position(|&i| i == 0xFF) + .expect("found 0xFF delim"); - if key - .iter() - .get(first_sep_index..=first_sep_index.saturating_add(1)) - .copied() - .collect_vec() - == vec![0xFF, 0xFF] - { - debug_warn!("Found bad key: {key:?}"); - roomuserid_joined.remove(&key)?; + if key + .iter() + .get(first_sep_index..=first_sep_index.saturating_add(1)) + .copied() + .collect_vec() + == vec![0xFF, 0xFF] + { + debug_warn!("Found bad key: {key:?}"); + roomuserid_joined.remove(&key); - key.remove(first_sep_index); - debug_warn!("Fixed key: {key:?}"); - roomuserid_joined.insert(&key, &value)?; - } - } + key.remove(first_sep_index); + debug_warn!("Fixed key: {key:?}"); + roomuserid_joined.insert(&key, value); + } + }) + .await; db.db.cleanup()?; - db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[])?; + db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[]); info!("Finished fixing"); Ok(()) @@ -795,69 +378,71 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) .rooms .metadata .iter_ids() - .filter_map(Result::ok) - .collect_vec(); + .map(ToOwned::to_owned) + .collect::>() + .await; - for room_id in room_ids.clone() { + for room_id in &room_ids { debug_info!("Fixing room {room_id}"); let users_in_room = services .rooms .state_cache - .room_members(&room_id) - .filter_map(Result::ok) - .collect_vec(); + .room_members(room_id) + .collect::>() + .await; let joined_members = users_in_room .iter() + .stream() .filter(|user_id| { services .rooms .state_accessor - .get_member(&room_id, user_id) - .unwrap_or(None) - .map_or(false, |membership| membership.membership == MembershipState::Join) + .get_member(room_id, user_id) + .map(|member| member.map_or(false, |member| 
member.membership == MembershipState::Join)) }) - .collect_vec(); + .collect::>() + .await; let non_joined_members = users_in_room .iter() + .stream() .filter(|user_id| { services .rooms .state_accessor - .get_member(&room_id, user_id) - .unwrap_or(None) - .map_or(false, |membership| { - membership.membership == MembershipState::Leave || membership.membership == MembershipState::Ban - }) + .get_member(room_id, user_id) + .map(|member| member.map_or(false, |member| member.membership == MembershipState::Join)) }) - .collect_vec(); + .collect::>() + .await; for user_id in joined_members { debug_info!("User is joined, marking as joined"); - services - .rooms - .state_cache - .mark_as_joined(user_id, &room_id)?; + services.rooms.state_cache.mark_as_joined(user_id, room_id); } for user_id in non_joined_members { debug_info!("User is left or banned, marking as left"); - services.rooms.state_cache.mark_as_left(user_id, &room_id)?; + services.rooms.state_cache.mark_as_left(user_id, room_id); } } - for room_id in room_ids { + for room_id in &room_ids { debug_info!( "Updating joined count for room {room_id} to fix servers in room after correcting membership states" ); - services.rooms.state_cache.update_joined_count(&room_id)?; + services + .rooms + .state_cache + .update_joined_count(room_id) + .await; } db.db.cleanup()?; - db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[])?; + db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[]); info!("Finished fixing"); Ok(()) diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 87f8f4925..f777901f6 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -288,8 +288,8 @@ impl Service { /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found /// for the server. 
- pub fn verify_keys_for(&self, origin: &ServerName) -> Result> { - let mut keys = self.db.verify_keys_for(origin)?; + pub async fn verify_keys_for(&self, origin: &ServerName) -> Result> { + let mut keys = self.db.verify_keys_for(origin).await?; if origin == self.server_name() { keys.insert( format!("ed25519:{}", self.keypair().version()) @@ -304,8 +304,8 @@ impl Service { Ok(keys) } - pub fn signing_keys_for(&self, origin: &ServerName) -> Result> { - self.db.signing_keys_for(origin) + pub async fn signing_keys_for(&self, origin: &ServerName) -> Result { + self.db.signing_keys_for(origin).await } pub fn well_known_client(&self) -> &Option { &self.config.well_known.client } diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs deleted file mode 100644 index 30ac593b1..000000000 --- a/src/service/key_backups/data.rs +++ /dev/null @@ -1,346 +0,0 @@ -use std::{collections::BTreeMap, sync::Arc}; - -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{ - api::client::{ - backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, - error::ErrorKind, - }, - serde::Raw, - OwnedRoomId, RoomId, UserId, -}; - -use crate::{globals, Dep}; - -pub(super) struct Data { - backupid_algorithm: Arc, - backupid_etag: Arc, - backupkeyid_backup: Arc, - services: Services, -} - -struct Services { - globals: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - backupid_algorithm: db["backupid_algorithm"].clone(), - backupid_etag: db["backupid_etag"].clone(), - backupkeyid_backup: db["backupkeyid_backup"].clone(), - services: Services { - globals: args.depend::("globals"), - }, - } - } - - pub(super) fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { - let version = self.services.globals.next_count()?.to_string(); - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.backupid_algorithm.insert( - &key, - 
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), - )?; - self.backupid_etag - .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; - Ok(version) - } - - pub(super) fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.backupid_algorithm.remove(&key)?; - self.backupid_etag.remove(&key)?; - - key.push(0xFF); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } - - pub(super) fn update_backup( - &self, user_id: &UserId, version: &str, backup_metadata: &Raw, - ) -> Result { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - if self.backupid_algorithm.get(&key)?.is_none() { - return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); - } - - self.backupid_algorithm - .insert(&key, backup_metadata.json().get().as_bytes())?; - self.backupid_etag - .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; - Ok(version.to_owned()) - } - - pub(super) fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.backupid_algorithm - .iter_from(&last_possible_key, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .next() - .map(|(key, _)| { - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) - }) - .transpose() - } - - pub(super) fn get_latest_backup(&self, user_id: &UserId) -> Result)>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - let mut 
last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.backupid_algorithm - .iter_from(&last_possible_key, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .next() - .map(|(key, value)| { - let version = utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; - - Ok(( - version, - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?, - )) - }) - .transpose() - } - - pub(super) fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.backupid_algorithm - .get(&key)? - .map_or(Ok(None), |bytes| { - serde_json::from_slice(&bytes) - .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid.")) - }) - } - - pub(super) fn add_key( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - if self.backupid_algorithm.get(&key)?.is_none() { - return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); - } - - self.backupid_etag - .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; - - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); - - self.backupkeyid_backup - .insert(&key, key_data.json().get().as_bytes())?; - - Ok(()) - } - - pub(super) fn count_keys(&self, user_id: &UserId, version: &str) -> Result { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(version.as_bytes()); - - Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) - } - 
- pub(super) fn get_etag(&self, user_id: &UserId, version: &str) -> Result { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - Ok(utils::u64_from_bytes( - &self - .backupid_etag - .get(&key)? - .ok_or_else(|| Error::bad_database("Backup has no etag."))?, - ) - .map_err(|_| Error::bad_database("etag in backupid_etag invalid."))? - .to_string()) - } - - pub(super) fn get_all(&self, user_id: &UserId, version: &str) -> Result> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(version.as_bytes()); - prefix.push(0xFF); - - let mut rooms = BTreeMap::::new(); - - for result in self - .backupkeyid_backup - .scan_prefix(prefix) - .map(|(key, value)| { - let mut parts = key.rsplit(|&b| b == 0xFF); - - let session_id = utils::string_from_bytes( - parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; - - let room_id = RoomId::parse( - utils::string_from_bytes( - parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?; - - let key_data = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; - - Ok::<_, Error>((room_id, session_id, key_data)) - }) { - let (room_id, session_id, key_data) = result?; - rooms - .entry(room_id) - .or_insert_with(|| RoomKeyBackup { - sessions: BTreeMap::new(), - }) - .sessions - .insert(session_id, key_data); - } - - Ok(rooms) - } - - pub(super) fn get_room( - &self, user_id: &UserId, version: &str, room_id: &RoomId, - ) -> Result>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - 
prefix.extend_from_slice(version.as_bytes()); - prefix.push(0xFF); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xFF); - - Ok(self - .backupkeyid_backup - .scan_prefix(prefix) - .map(|(key, value)| { - let mut parts = key.rsplit(|&b| b == 0xFF); - - let session_id = utils::string_from_bytes( - parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; - - let key_data = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; - - Ok::<_, Error>((session_id, key_data)) - }) - .filter_map(Result::ok) - .collect()) - } - - pub(super) fn get_session( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, - ) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); - - self.backupkeyid_backup - .get(&key)? 
- .map(|value| { - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")) - }) - .transpose() - } - - pub(super) fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } - - pub(super) fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } - - pub(super) fn delete_room_key( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } -} diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 65d3c065e..12712e793 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -1,93 +1,319 @@ -mod data; - use std::{collections::BTreeMap, sync::Arc}; -use conduit::Result; -use data::Data; +use conduit::{ + err, implement, utils, + utils::stream::{ReadyExt, TryIgnore}, + Err, Error, Result, +}; +use database::{Deserialized, Ignore, Interfix, Map}; +use futures::StreamExt; use ruma::{ 
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, serde::Raw, OwnedRoomId, RoomId, UserId, }; +use crate::{globals, Dep}; + pub struct Service { db: Data, + services: Services, +} + +struct Data { + backupid_algorithm: Arc, + backupid_etag: Arc, + backupkeyid_backup: Arc, +} + +struct Services { + globals: Dep, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + backupid_algorithm: args.db["backupid_algorithm"].clone(), + backupid_etag: args.db["backupid_etag"].clone(), + backupkeyid_backup: args.db["backupkeyid_backup"].clone(), + }, + services: Services { + globals: args.depend::("globals"), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { - self.db.create_backup(user_id, backup_metadata) - } +#[implement(Service)] +pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { + let version = self.services.globals.next_count()?.to_string(); - pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { - self.db.delete_backup(user_id, version) - } + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); - pub fn update_backup( - &self, user_id: &UserId, version: &str, backup_metadata: &Raw, - ) -> Result { - self.db.update_backup(user_id, version, backup_metadata) - } + self.db.backupid_algorithm.insert( + &key, + &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), + ); - pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { - self.db.get_latest_backup_version(user_id) - } + self.db + .backupid_etag + .insert(&key, &self.services.globals.next_count()?.to_be_bytes()); - pub fn get_latest_backup(&self, user_id: &UserId) -> Result)>> { - self.db.get_latest_backup(user_id) - } + Ok(version) +} 
- pub fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { - self.db.get_backup(user_id, version) - } +#[implement(Service)] +pub async fn delete_backup(&self, user_id: &UserId, version: &str) { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); - pub fn add_key( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, - ) -> Result<()> { - self.db - .add_key(user_id, version, room_id, session_id, key_data) + self.db.backupid_algorithm.remove(&key); + self.db.backupid_etag.remove(&key); + + let key = (user_id, version, Interfix); + self.db + .backupkeyid_backup + .keys_raw_prefix(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; +} + +#[implement(Service)] +pub async fn update_backup( + &self, user_id: &UserId, version: &str, backup_metadata: &Raw, +) -> Result { + let key = (user_id, version); + if self.db.backupid_algorithm.qry(&key).await.is_err() { + return Err!(Request(NotFound("Tried to update nonexistent backup."))); } - pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result { self.db.count_keys(user_id, version) } + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); - pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result { self.db.get_etag(user_id, version) } + self.db + .backupid_algorithm + .insert(&key, backup_metadata.json().get().as_bytes()); + self.db + .backupid_etag + .insert(&key, &self.services.globals.next_count()?.to_be_bytes()); - pub fn get_all(&self, user_id: &UserId, version: &str) -> Result> { - self.db.get_all(user_id, version) - } + Ok(version.to_owned()) +} - pub fn get_room( - &self, user_id: &UserId, version: &str, room_id: &RoomId, - ) -> Result>> { - self.db.get_room(user_id, version, room_id) - } +#[implement(Service)] +pub async fn get_latest_backup_version(&self, user_id: &UserId) 
-> Result { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - pub fn get_session( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, - ) -> Result>> { - self.db.get_session(user_id, version, room_id, session_id) - } + self.db + .backupid_algorithm + .rev_raw_keys_from(&last_possible_key) + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix)) + .next() + .await + .ok_or_else(|| err!(Request(NotFound("No backup versions found")))) + .and_then(|key| { + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) + }) +} - pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { - self.db.delete_all_keys(user_id, version) - } +#[implement(Service)] +pub async fn get_latest_backup(&self, user_id: &UserId) -> Result<(String, Raw)> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - pub fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { - self.db.delete_room_keys(user_id, version, room_id) - } + self.db + .backupid_algorithm + .rev_raw_stream_from(&last_possible_key) + .ignore_err() + .ready_take_while(move |(key, _)| key.starts_with(&prefix)) + .next() + .await + .ok_or_else(|| err!(Request(NotFound("No backup found")))) + .and_then(|(key, val)| { + let version = utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; + + let algorithm = serde_json::from_slice(val) + .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is 
invalid."))?; - pub fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> { - self.db - .delete_room_key(user_id, version, room_id, session_id) + Ok((version, algorithm)) + }) +} + +#[implement(Service)] +pub async fn get_backup(&self, user_id: &UserId, version: &str) -> Result> { + let key = (user_id, version); + self.db + .backupid_algorithm + .qry(&key) + .await + .deserialized_json() +} + +#[implement(Service)] +pub async fn add_key( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, +) -> Result<()> { + let key = (user_id, version); + if self.db.backupid_algorithm.qry(&key).await.is_err() { + return Err!(Request(NotFound("Tried to update nonexistent backup."))); } + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + + self.db + .backupid_etag + .insert(&key, &self.services.globals.next_count()?.to_be_bytes()); + + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(session_id.as_bytes()); + + self.db + .backupkeyid_backup + .insert(&key, key_data.json().get().as_bytes()); + + Ok(()) +} + +#[implement(Service)] +pub async fn count_keys(&self, user_id: &UserId, version: &str) -> usize { + let prefix = (user_id, version); + self.db + .backupkeyid_backup + .keys_raw_prefix(&prefix) + .count() + .await +} + +#[implement(Service)] +pub async fn get_etag(&self, user_id: &UserId, version: &str) -> String { + let key = (user_id, version); + self.db + .backupid_etag + .qry(&key) + .await + .deserialized::() + .as_ref() + .map(ToString::to_string) + .expect("Backup has no etag.") +} + +#[implement(Service)] +pub async fn get_all(&self, user_id: &UserId, version: &str) -> BTreeMap { + type KeyVal<'a> = ((Ignore, Ignore, &'a RoomId, &'a str), &'a [u8]); + + let mut rooms = BTreeMap::::new(); + let default = || RoomKeyBackup { + sessions: BTreeMap::new(), + }; + + let 
prefix = (user_id, version, Interfix); + self.db + .backupkeyid_backup + .stream_prefix(&prefix) + .ignore_err() + .ready_for_each(|((_, _, room_id, session_id), value): KeyVal<'_>| { + let key_data = serde_json::from_slice(value).expect("Invalid KeyBackupData JSON"); + rooms + .entry(room_id.into()) + .or_insert_with(default) + .sessions + .insert(session_id.into(), key_data); + }) + .await; + + rooms +} + +#[implement(Service)] +pub async fn get_room( + &self, user_id: &UserId, version: &str, room_id: &RoomId, +) -> BTreeMap> { + type KeyVal<'a> = ((Ignore, Ignore, Ignore, &'a str), &'a [u8]); + + let prefix = (user_id, version, room_id, Interfix); + self.db + .backupkeyid_backup + .stream_prefix(&prefix) + .ignore_err() + .map(|((.., session_id), value): KeyVal<'_>| { + let session_id = session_id.to_owned(); + let key_backup_data = serde_json::from_slice(value).expect("Invalid KeyBackupData JSON"); + (session_id, key_backup_data) + }) + .collect() + .await +} + +#[implement(Service)] +pub async fn get_session( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, +) -> Result> { + let key = (user_id, version, room_id, session_id); + + self.db + .backupkeyid_backup + .qry(&key) + .await + .deserialized_json() +} + +#[implement(Service)] +pub async fn delete_all_keys(&self, user_id: &UserId, version: &str) { + let key = (user_id, version, Interfix); + self.db + .backupkeyid_backup + .keys_raw_prefix(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; +} + +#[implement(Service)] +pub async fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) { + let key = (user_id, version, room_id, Interfix); + self.db + .backupkeyid_backup + .keys_raw_prefix(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; +} + +#[implement(Service)] +pub async fn delete_room_key(&self, user_id: &UserId, version: &str, 
room_id: &RoomId, session_id: &str) { + let key = (user_id, version, room_id, session_id); + self.db + .backupkeyid_backup + .keys_raw_prefix(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; } diff --git a/src/service/manager.rs b/src/service/manager.rs index 42260bb30..21e0ed7c2 100644 --- a/src/service/manager.rs +++ b/src/service/manager.rs @@ -1,7 +1,7 @@ use std::{panic::AssertUnwindSafe, sync::Arc, time::Duration}; use conduit::{debug, debug_warn, error, trace, utils::time, warn, Err, Error, Result, Server}; -use futures_util::FutureExt; +use futures::FutureExt; use tokio::{ sync::{Mutex, MutexGuard}, task::{JoinHandle, JoinSet}, diff --git a/src/service/media/data.rs b/src/service/media/data.rs index e5d6d20b1..29d562cc3 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -2,10 +2,11 @@ use std::sync::Arc; use conduit::{ debug, debug_info, trace, - utils::{str_from_bytes, string_from_bytes}, + utils::{str_from_bytes, stream::TryIgnore, string_from_bytes, ReadyExt}, Err, Error, Result, }; use database::{Database, Map}; +use futures::StreamExt; use ruma::{api::client::error::ErrorKind, http_headers::ContentDisposition, Mxc, OwnedMxcUri, UserId}; use super::{preview::UrlPreviewData, thumbnail::Dim}; @@ -59,7 +60,7 @@ impl Data { .unwrap_or_default(), ); - self.mediaid_file.insert(&key, &[])?; + self.mediaid_file.insert(&key, &[]); if let Some(user) = user { let mut key: Vec = Vec::new(); @@ -68,13 +69,13 @@ impl Data { key.extend_from_slice(b"/"); key.extend_from_slice(mxc.media_id.as_bytes()); let user = user.as_bytes().to_vec(); - self.mediaid_user.insert(&key, &user)?; + self.mediaid_user.insert(&key, &user); } Ok(key) } - pub(super) fn delete_file_mxc(&self, mxc: &Mxc<'_>) -> Result<()> { + pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) { debug!("MXC URI: {mxc}"); let mut prefix: Vec = Vec::new(); @@ -85,25 +86,31 @@ impl Data { prefix.push(0xFF); trace!("MXC 
db prefix: {prefix:?}"); - for (key, _) in self.mediaid_file.scan_prefix(prefix.clone()) { - debug!("Deleting key: {:?}", key); - self.mediaid_file.remove(&key)?; - } - - for (key, value) in self.mediaid_user.scan_prefix(prefix.clone()) { - if key.starts_with(&prefix) { - let user = str_from_bytes(&value).unwrap_or_default(); - - debug_info!("Deleting key \"{key:?}\" which was uploaded by user {user}"); - self.mediaid_user.remove(&key)?; - } - } + self.mediaid_file + .raw_keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| { + debug!("Deleting key: {:?}", key); + self.mediaid_file.remove(key); + }) + .await; - Ok(()) + self.mediaid_user + .raw_stream_prefix(&prefix) + .ignore_err() + .ready_for_each(|(key, val)| { + if key.starts_with(&prefix) { + let user = str_from_bytes(val).unwrap_or_default(); + debug_info!("Deleting key {key:?} which was uploaded by user {user}"); + + self.mediaid_user.remove(key); + } + }) + .await; } /// Searches for all files with the given MXC - pub(super) fn search_mxc_metadata_prefix(&self, mxc: &Mxc<'_>) -> Result>> { + pub(super) async fn search_mxc_metadata_prefix(&self, mxc: &Mxc<'_>) -> Result>> { debug!("MXC URI: {mxc}"); let mut prefix: Vec = Vec::new(); @@ -115,9 +122,10 @@ impl Data { let keys: Vec> = self .mediaid_file - .scan_prefix(prefix) - .map(|(key, _)| key) - .collect(); + .keys_prefix_raw(&prefix) + .ignore_err() + .collect() + .await; if keys.is_empty() { return Err!(Database("Failed to find any keys in database for `{mxc}`",)); @@ -128,7 +136,7 @@ impl Data { Ok(keys) } - pub(super) fn search_file_metadata(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result { + pub(super) async fn search_file_metadata(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result { let mut prefix: Vec = Vec::new(); prefix.extend_from_slice(b"mxc://"); prefix.extend_from_slice(mxc.server_name.as_bytes()); @@ -139,10 +147,13 @@ impl Data { prefix.extend_from_slice(&dim.height.to_be_bytes()); prefix.push(0xFF); - let (key, _) = self + let key = self 
.mediaid_file - .scan_prefix(prefix) + .raw_keys_prefix(&prefix) + .ignore_err() + .map(ToOwned::to_owned) .next() + .await .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Media not found"))?; let mut parts = key.rsplit(|&b| b == 0xFF); @@ -177,28 +188,31 @@ impl Data { } /// Gets all the MXCs associated with a user - pub(super) fn get_all_user_mxcs(&self, user_id: &UserId) -> Vec { - let user_id = user_id.as_bytes().to_vec(); - + pub(super) async fn get_all_user_mxcs(&self, user_id: &UserId) -> Vec { self.mediaid_user - .iter() - .filter_map(|(key, user)| { - if *user == user_id { - let mxc_s = string_from_bytes(&key).ok()?; - Some(OwnedMxcUri::from(mxc_s)) - } else { - None - } - }) + .stream() + .ignore_err() + .ready_filter_map(|(key, user): (&str, &UserId)| (user == user_id).then(|| key.into())) .collect() + .await } /// Gets all the media keys in our database (this includes all the metadata /// associated with it such as width, height, content-type, etc) - pub(crate) fn get_all_media_keys(&self) -> Vec> { self.mediaid_file.iter().map(|(key, _)| key).collect() } + pub(crate) async fn get_all_media_keys(&self) -> Vec> { + self.mediaid_file + .raw_keys() + .ignore_err() + .map(<[u8]>::to_vec) + .collect() + .await + } #[inline] - pub(super) fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) } + pub(super) fn remove_url_preview(&self, url: &str) -> Result<()> { + self.url_previews.remove(url.as_bytes()); + Ok(()) + } pub(super) fn set_url_preview( &self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration, @@ -233,11 +247,13 @@ impl Data { value.push(0xFF); value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes()); - self.url_previews.insert(url.as_bytes(), &value) + self.url_previews.insert(url.as_bytes(), &value); + + Ok(()) } - pub(super) fn get_url_preview(&self, url: &str) -> Option { - let values = self.url_previews.get(url.as_bytes()).ok()??; + pub(super) async fn 
get_url_preview(&self, url: &str) -> Result { + let values = self.url_previews.qry(url).await?; let mut values = values.split(|&b| b == 0xFF); @@ -291,7 +307,7 @@ impl Data { x => x, }; - Some(UrlPreviewData { + Ok(UrlPreviewData { title, description, image, diff --git a/src/service/media/migrations.rs b/src/service/media/migrations.rs index 9968d25b7..2d1b39f9f 100644 --- a/src/service/media/migrations.rs +++ b/src/service/media/migrations.rs @@ -7,7 +7,11 @@ use std::{ time::Instant, }; -use conduit::{debug, debug_info, debug_warn, error, info, warn, Config, Result}; +use conduit::{ + debug, debug_info, debug_warn, error, info, + utils::{stream::TryIgnore, ReadyExt}, + warn, Config, Result, +}; use crate::{globals, Services}; @@ -23,12 +27,17 @@ pub(crate) async fn migrate_sha256_media(services: &Services) -> Result<()> { // Move old media files to new names let mut changes = Vec::<(PathBuf, PathBuf)>::new(); - for (key, _) in mediaid_file.iter() { - let old = services.media.get_media_file_b64(&key); - let new = services.media.get_media_file_sha256(&key); - debug!(?key, ?old, ?new, num = changes.len(), "change"); - changes.push((old, new)); - } + mediaid_file + .raw_keys() + .ignore_err() + .ready_for_each(|key| { + let old = services.media.get_media_file_b64(key); + let new = services.media.get_media_file_sha256(key); + debug!(?key, ?old, ?new, num = changes.len(), "change"); + changes.push((old, new)); + }) + .await; + // move the file to the new location for (old_path, path) in changes { if old_path.exists() { @@ -41,11 +50,11 @@ pub(crate) async fn migrate_sha256_media(services: &Services) -> Result<()> { // Apply fix from when sha256_media was backward-incompat and bumped the schema // version from 13 to 14. For users satisfying these conditions we can go back. - if services.globals.db.database_version()? 
== 14 && globals::migrations::DATABASE_VERSION == 13 { + if services.globals.db.database_version().await == 14 && globals::migrations::DATABASE_VERSION == 13 { services.globals.db.bump_database_version(13)?; } - db["global"].insert(b"feat_sha256_media", &[])?; + db["global"].insert(b"feat_sha256_media", &[]); info!("Finished applying sha256_media"); Ok(()) } @@ -71,7 +80,7 @@ pub(crate) async fn checkup_sha256_media(services: &Services) -> Result<()> { .filter_map(|ent| ent.map_or(None, |ent| Some(ent.path().into_os_string()))) .collect(); - for key in media.db.get_all_media_keys() { + for key in media.db.get_all_media_keys().await { let new_path = media.get_media_file_sha256(&key).into_os_string(); let old_path = media.get_media_file_b64(&key).into_os_string(); if let Err(e) = handle_media_check(&dbs, config, &files, &key, &new_path, &old_path).await { @@ -112,8 +121,8 @@ async fn handle_media_check( "Media is missing at all paths. Removing from database..." ); - mediaid_file.remove(key)?; - mediaid_user.remove(key)?; + mediaid_file.remove(key); + mediaid_user.remove(key); } if config.media_compat_file_link && !old_exists && new_exists { diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index d3765a176..c0b15726f 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -97,7 +97,7 @@ impl Service { /// Deletes a file in the database and from the media directory via an MXC pub async fn delete(&self, mxc: &Mxc<'_>) -> Result<()> { - if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc) { + if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc).await { for key in keys { trace!(?mxc, "MXC Key: {key:?}"); debug_info!(?mxc, "Deleting from filesystem"); @@ -107,7 +107,7 @@ impl Service { } debug_info!(?mxc, "Deleting from database"); - _ = self.db.delete_file_mxc(mxc); + self.db.delete_file_mxc(mxc).await; } Ok(()) @@ -120,7 +120,7 @@ impl Service { /// /// currently, this is only practical for local users pub async fn 
delete_from_user(&self, user: &UserId) -> Result { - let mxcs = self.db.get_all_user_mxcs(user); + let mxcs = self.db.get_all_user_mxcs(user).await; let mut deletion_count: usize = 0; for mxc in mxcs { @@ -150,7 +150,7 @@ impl Service { content_disposition, content_type, key, - }) = self.db.search_file_metadata(mxc, &Dim::default()) + }) = self.db.search_file_metadata(mxc, &Dim::default()).await { let mut content = Vec::new(); let path = self.get_media_file(&key); @@ -170,7 +170,7 @@ impl Service { /// Gets all the MXC URIs in our media database pub async fn get_all_mxcs(&self) -> Result> { - let all_keys = self.db.get_all_media_keys(); + let all_keys = self.db.get_all_media_keys().await; let mut mxcs = Vec::with_capacity(all_keys.len()); @@ -209,7 +209,7 @@ impl Service { pub async fn delete_all_remote_media_at_after_time( &self, time: SystemTime, before: bool, after: bool, yes_i_want_to_delete_local_media: bool, ) -> Result { - let all_keys = self.db.get_all_media_keys(); + let all_keys = self.db.get_all_media_keys().await; let mut remote_mxcs = Vec::with_capacity(all_keys.len()); for key in all_keys { @@ -343,9 +343,10 @@ impl Service { } #[inline] - pub fn get_metadata(&self, mxc: &Mxc<'_>) -> Option { + pub async fn get_metadata(&self, mxc: &Mxc<'_>) -> Option { self.db .search_file_metadata(mxc, &Dim::default()) + .await .map(|metadata| FileMeta { content_disposition: metadata.content_disposition, content_type: metadata.content_type, diff --git a/src/service/media/preview.rs b/src/service/media/preview.rs index 5704075e5..6b1473838 100644 --- a/src/service/media/preview.rs +++ b/src/service/media/preview.rs @@ -71,16 +71,16 @@ pub async fn download_image(&self, url: &str) -> Result { #[implement(Service)] pub async fn get_url_preview(&self, url: &str) -> Result { - if let Some(preview) = self.db.get_url_preview(url) { + if let Ok(preview) = self.db.get_url_preview(url).await { return Ok(preview); } // ensure that only one request is made per URL let 
_request_lock = self.url_preview_mutex.lock(url).await; - match self.db.get_url_preview(url) { - Some(preview) => Ok(preview), - None => self.request_url_preview(url).await, + match self.db.get_url_preview(url).await { + Ok(preview) => Ok(preview), + Err(_) => self.request_url_preview(url).await, } } diff --git a/src/service/media/thumbnail.rs b/src/service/media/thumbnail.rs index 630f7b3b1..04ec03039 100644 --- a/src/service/media/thumbnail.rs +++ b/src/service/media/thumbnail.rs @@ -54,9 +54,9 @@ impl super::Service { // 0, 0 because that's the original file let dim = dim.normalized(); - if let Ok(metadata) = self.db.search_file_metadata(mxc, &dim) { + if let Ok(metadata) = self.db.search_file_metadata(mxc, &dim).await { self.get_thumbnail_saved(metadata).await - } else if let Ok(metadata) = self.db.search_file_metadata(mxc, &Dim::default()) { + } else if let Ok(metadata) = self.db.search_file_metadata(mxc, &Dim::default()).await { self.get_thumbnail_generate(mxc, &dim, metadata).await } else { Ok(None) diff --git a/src/service/mod.rs b/src/service/mod.rs index f588a5420..cb8bfcd95 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -19,6 +19,7 @@ pub mod resolver; pub mod rooms; pub mod sending; pub mod server_keys; +pub mod sync; pub mod transaction_ids; pub mod uiaa; pub mod updates; diff --git a/src/service/presence/data.rs b/src/service/presence/data.rs index ec036b3d6..0c3f3d31d 100644 --- a/src/service/presence/data.rs +++ b/src/service/presence/data.rs @@ -1,7 +1,12 @@ use std::sync::Arc; -use conduit::{debug_warn, utils, Error, Result}; -use database::Map; +use conduit::{ + debug_warn, utils, + utils::{stream::TryIgnore, ReadyExt}, + Result, +}; +use database::{Deserialized, Map}; +use futures::Stream; use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; use super::Presence; @@ -31,39 +36,35 @@ impl Data { } } - pub fn get_presence(&self, user_id: &UserId) -> Result> { - if let Some(count_bytes) = 
self.userid_presenceid.get(user_id.as_bytes())? { - let count = utils::u64_from_bytes(&count_bytes) - .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; - - let key = presenceid_key(count, user_id); - self.presenceid_presence - .get(&key)? - .map(|presence_bytes| -> Result<(u64, PresenceEvent)> { - Ok(( - count, - Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id, &self.services.users)?, - )) - }) - .transpose() - } else { - Ok(None) - } + pub async fn get_presence(&self, user_id: &UserId) -> Result<(u64, PresenceEvent)> { + let count = self + .userid_presenceid + .qry(user_id) + .await + .deserialized::()?; + + let key = presenceid_key(count, user_id); + let bytes = self.presenceid_presence.qry(&key).await?; + let event = Presence::from_json_bytes(&bytes)? + .to_presence_event(user_id, &self.services.users) + .await; + + Ok((count, event)) } - pub(super) fn set_presence( + pub(super) async fn set_presence( &self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option, last_active_ago: Option, status_msg: Option, ) -> Result<()> { - let last_presence = self.get_presence(user_id)?; + let last_presence = self.get_presence(user_id).await; let state_changed = match last_presence { - None => true, - Some(ref presence) => presence.1.content.presence != *presence_state, + Err(_) => true, + Ok(ref presence) => presence.1.content.presence != *presence_state, }; let status_msg_changed = match last_presence { - None => true, - Some(ref last_presence) => { + Err(_) => true, + Ok(ref last_presence) => { let old_msg = last_presence .1 .content @@ -79,8 +80,8 @@ impl Data { let now = utils::millis_since_unix_epoch(); let last_last_active_ts = match last_presence { - None => 0, - Some((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()), + Err(_) => 0, + Ok((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()), }; let 
last_active_ts = match last_active_ago { @@ -90,12 +91,7 @@ impl Data { // TODO: tighten for state flicker? if !status_msg_changed && !state_changed && last_active_ts < last_last_active_ts { - debug_warn!( - "presence spam {:?} last_active_ts:{:?} < {:?}", - user_id, - last_active_ts, - last_last_active_ts - ); + debug_warn!("presence spam {user_id:?} last_active_ts:{last_active_ts:?} < {last_last_active_ts:?}",); return Ok(()); } @@ -115,41 +111,42 @@ impl Data { let key = presenceid_key(count, user_id); self.presenceid_presence - .insert(&key, &presence.to_json_bytes()?)?; + .insert(&key, &presence.to_json_bytes()?); self.userid_presenceid - .insert(user_id.as_bytes(), &count.to_be_bytes())?; + .insert(user_id.as_bytes(), &count.to_be_bytes()); - if let Some((last_count, _)) = last_presence { + if let Ok((last_count, _)) = last_presence { let key = presenceid_key(last_count, user_id); - self.presenceid_presence.remove(&key)?; + self.presenceid_presence.remove(&key); } Ok(()) } - pub(super) fn remove_presence(&self, user_id: &UserId) -> Result<()> { - if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? 
{ - let count = utils::u64_from_bytes(&count_bytes) - .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; - let key = presenceid_key(count, user_id); - self.presenceid_presence.remove(&key)?; - self.userid_presenceid.remove(user_id.as_bytes())?; - } + pub(super) async fn remove_presence(&self, user_id: &UserId) { + let Ok(count) = self + .userid_presenceid + .qry(user_id) + .await + .deserialized::() + else { + return; + }; - Ok(()) + let key = presenceid_key(count, user_id); + self.presenceid_presence.remove(&key); + self.userid_presenceid.remove(user_id.as_bytes()); } - pub fn presence_since<'a>(&'a self, since: u64) -> Box)> + 'a> { - Box::new( - self.presenceid_presence - .iter() - .flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, Vec)> { - let (count, user_id) = presenceid_parse(&key)?; - Ok((user_id.to_owned(), count, presence_bytes)) - }) - .filter(move |(_, count, _)| *count > since), - ) + pub fn presence_since(&self, since: u64) -> impl Stream)> + Send + '_ { + self.presenceid_presence + .raw_stream() + .ignore_err() + .ready_filter_map(move |(key, presence_bytes)| { + let (count, user_id) = presenceid_parse(key).expect("invalid presenceid_parse"); + (count > since).then(|| (user_id.to_owned(), count, presence_bytes.to_vec())) + }) } } @@ -162,7 +159,7 @@ fn presenceid_key(count: u64, user_id: &UserId) -> Vec { fn presenceid_parse(key: &[u8]) -> Result<(u64, &UserId)> { let (count, user_id) = key.split_at(8); let user_id = user_id_from_bytes(user_id)?; - let count = utils::u64_from_bytes(count).unwrap(); + let count = utils::u64_from_u8(count); Ok((count, user_id)) } diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index a54a6d7c5..3b5c4caf4 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -4,8 +4,8 @@ mod presence; use std::{sync::Arc, time::Duration}; use async_trait::async_trait; -use conduit::{checked, debug, error, Error, Result, Server}; -use 
futures_util::{stream::FuturesUnordered, StreamExt}; +use conduit::{checked, debug, error, result::LogErr, Error, Result, Server}; +use futures::{stream::FuturesUnordered, Stream, StreamExt, TryFutureExt}; use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; use tokio::{sync::Mutex, time::sleep}; @@ -58,7 +58,9 @@ impl crate::Service for Service { loop { debug_assert!(!receiver.is_closed(), "channel error"); tokio::select! { - Some(user_id) = presence_timers.next() => self.process_presence_timer(&user_id)?, + Some(user_id) = presence_timers.next() => { + self.process_presence_timer(&user_id).await.log_err().ok(); + }, event = receiver.recv_async() => match event { Err(_e) => return Ok(()), Ok((user_id, timeout)) => { @@ -82,28 +84,27 @@ impl crate::Service for Service { impl Service { /// Returns the latest presence event for the given user. #[inline] - pub fn get_presence(&self, user_id: &UserId) -> Result> { - if let Some((_, presence)) = self.db.get_presence(user_id)? { - Ok(Some(presence)) - } else { - Ok(None) - } + pub async fn get_presence(&self, user_id: &UserId) -> Result { + self.db + .get_presence(user_id) + .map_ok(|(_, presence)| presence) + .await } /// Pings the presence of the given user in the given room, setting the /// specified state. 
- pub fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> { + pub async fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> { const REFRESH_TIMEOUT: u64 = 60 * 25 * 1000; - let last_presence = self.db.get_presence(user_id)?; + let last_presence = self.db.get_presence(user_id).await; let state_changed = match last_presence { - None => true, - Some((_, ref presence)) => presence.content.presence != *new_state, + Err(_) => true, + Ok((_, ref presence)) => presence.content.presence != *new_state, }; let last_last_active_ago = match last_presence { - None => 0_u64, - Some((_, ref presence)) => presence.content.last_active_ago.unwrap_or_default().into(), + Err(_) => 0_u64, + Ok((_, ref presence)) => presence.content.last_active_ago.unwrap_or_default().into(), }; if !state_changed && last_last_active_ago < REFRESH_TIMEOUT { @@ -111,17 +112,18 @@ impl Service { } let status_msg = match last_presence { - Some((_, ref presence)) => presence.content.status_msg.clone(), - None => Some(String::new()), + Ok((_, ref presence)) => presence.content.status_msg.clone(), + Err(_) => Some(String::new()), }; let last_active_ago = UInt::new(0); let currently_active = *new_state == PresenceState::Online; self.set_presence(user_id, new_state, Some(currently_active), last_active_ago, status_msg) + .await } /// Adds a presence event which will be saved until a new event replaces it. 
- pub fn set_presence( + pub async fn set_presence( &self, user_id: &UserId, state: &PresenceState, currently_active: Option, last_active_ago: Option, status_msg: Option, ) -> Result<()> { @@ -131,7 +133,8 @@ impl Service { }; self.db - .set_presence(user_id, presence_state, currently_active, last_active_ago, status_msg)?; + .set_presence(user_id, presence_state, currently_active, last_active_ago, status_msg) + .await?; if self.timeout_remote_users || self.services.globals.user_is_local(user_id) { let timeout = match presence_state { @@ -154,28 +157,33 @@ impl Service { /// /// TODO: Why is this not used? #[allow(dead_code)] - pub fn remove_presence(&self, user_id: &UserId) -> Result<()> { self.db.remove_presence(user_id) } + pub async fn remove_presence(&self, user_id: &UserId) { self.db.remove_presence(user_id).await } /// Returns the most recent presence updates that happened after the event /// with id `since`. #[inline] - pub fn presence_since(&self, since: u64) -> Box)> + '_> { + pub fn presence_since(&self, since: u64) -> impl Stream)> + Send + '_ { self.db.presence_since(since) } - pub fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result { + #[inline] + pub async fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result { let presence = Presence::from_json_bytes(bytes)?; - presence.to_presence_event(user_id, &self.services.users) + let event = presence + .to_presence_event(user_id, &self.services.users) + .await; + + Ok(event) } - fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { + async fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { let mut presence_state = PresenceState::Offline; let mut last_active_ago = None; let mut status_msg = None; - let presence_event = self.get_presence(user_id)?; + let presence_event = self.get_presence(user_id).await; - if let Some(presence_event) = presence_event { + if let Ok(presence_event) = presence_event { presence_state = 
presence_event.content.presence; last_active_ago = presence_event.content.last_active_ago; status_msg = presence_event.content.status_msg; @@ -192,7 +200,8 @@ impl Service { ); if let Some(new_state) = new_state { - self.set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg)?; + self.set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg) + .await?; } Ok(()) diff --git a/src/service/presence/presence.rs b/src/service/presence/presence.rs index 570008f29..0d5c226bf 100644 --- a/src/service/presence/presence.rs +++ b/src/service/presence/presence.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use conduit::{utils, Error, Result}; use ruma::{ events::presence::{PresenceEvent, PresenceEventContent}, @@ -42,7 +40,7 @@ impl Presence { } /// Creates a PresenceEvent from available data. - pub(super) fn to_presence_event(&self, user_id: &UserId, users: &Arc) -> Result { + pub(super) async fn to_presence_event(&self, user_id: &UserId, users: &users::Service) -> PresenceEvent { let now = utils::millis_since_unix_epoch(); let last_active_ago = if self.currently_active { None @@ -50,16 +48,16 @@ impl Presence { Some(UInt::new_saturating(now.saturating_sub(self.last_active_ts))) }; - Ok(PresenceEvent { + PresenceEvent { sender: user_id.to_owned(), content: PresenceEventContent { presence: self.state.clone(), status_msg: self.status_msg.clone(), currently_active: Some(self.currently_active), last_active_ago, - displayname: users.displayname(user_id)?, - avatar_url: users.avatar_url(user_id)?, + displayname: users.displayname(user_id).await.ok(), + avatar_url: users.avatar_url(user_id).await.ok(), }, - }) + } } } diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs deleted file mode 100644 index f97343341..000000000 --- a/src/service/pusher/data.rs +++ /dev/null @@ -1,77 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, Error, Result}; -use database::{Database, Map}; -use ruma::{ - api::client::push::{set_pusher, Pusher}, - 
UserId, -}; - -pub(super) struct Data { - senderkey_pusher: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - senderkey_pusher: db["senderkey_pusher"].clone(), - } - } - - pub(super) fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) -> Result<()> { - match pusher { - set_pusher::v3::PusherAction::Post(data) => { - let mut key = sender.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(data.pusher.ids.pushkey.as_bytes()); - self.senderkey_pusher - .insert(&key, &serde_json::to_vec(pusher).expect("Pusher is valid JSON value"))?; - Ok(()) - }, - set_pusher::v3::PusherAction::Delete(ids) => { - let mut key = sender.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(ids.pushkey.as_bytes()); - self.senderkey_pusher.remove(&key).map_err(Into::into) - }, - } - } - - pub(super) fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { - let mut senderkey = sender.as_bytes().to_vec(); - senderkey.push(0xFF); - senderkey.extend_from_slice(pushkey.as_bytes()); - - self.senderkey_pusher - .get(&senderkey)? 
- .map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db."))) - .transpose() - } - - pub(super) fn get_pushers(&self, sender: &UserId) -> Result> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xFF); - - self.senderkey_pusher - .scan_prefix(prefix) - .map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db."))) - .collect() - } - - pub(super) fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box> + 'a> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| { - let mut parts = k.splitn(2, |&b| b == 0xFF); - let _senderkey = parts.next(); - let push_key = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?; - let push_key_string = utils::string_from_bytes(push_key) - .map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?; - - Ok(push_key_string) - })) - } -} diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index de87264c9..44ff1945c 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -1,9 +1,13 @@ -mod data; - use std::{fmt::Debug, mem, sync::Arc}; use bytes::BytesMut; -use conduit::{debug_error, err, trace, utils::string_from_bytes, warn, Err, PduEvent, Result}; +use conduit::{ + debug_error, err, trace, + utils::{stream::TryIgnore, string_from_bytes}, + Err, PduEvent, Result, +}; +use database::{Deserialized, Ignore, Interfix, Map}; +use futures::{Stream, StreamExt}; use ipaddress::IPAddress; use ruma::{ api::{ @@ -22,12 +26,11 @@ use ruma::{ uint, RoomId, UInt, UserId, }; -use self::data::Data; use crate::{client, globals, rooms, users, Dep}; pub struct Service { - services: Services, db: Data, + services: Services, } struct Services { @@ -38,9 +41,16 @@ struct Services { users: Dep, } +struct Data { + senderkey_pusher: Arc, +} + impl crate::Service for Service { fn build(args: 
crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + db: Data { + senderkey_pusher: args.db["senderkey_pusher"].clone(), + }, services: Services { globals: args.depend::("globals"), client: args.depend::("client"), @@ -48,7 +58,6 @@ impl crate::Service for Service { state_cache: args.depend::("rooms::state_cache"), users: args.depend::("users"), }, - db: Data::new(args.db), })) } @@ -56,19 +65,52 @@ impl crate::Service for Service { } impl Service { - pub fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) -> Result<()> { - self.db.set_pusher(sender, pusher) + pub fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) { + match pusher { + set_pusher::v3::PusherAction::Post(data) => { + let mut key = sender.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(data.pusher.ids.pushkey.as_bytes()); + self.db + .senderkey_pusher + .insert(&key, &serde_json::to_vec(pusher).expect("Pusher is valid JSON value")); + }, + set_pusher::v3::PusherAction::Delete(ids) => { + let mut key = sender.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(ids.pushkey.as_bytes()); + self.db.senderkey_pusher.remove(&key); + }, + } } - pub fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { - self.db.get_pusher(sender, pushkey) + pub async fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result { + let senderkey = (sender, pushkey); + self.db + .senderkey_pusher + .qry(&senderkey) + .await + .deserialized_json() } - pub fn get_pushers(&self, sender: &UserId) -> Result> { self.db.get_pushers(sender) } + pub async fn get_pushers(&self, sender: &UserId) -> Vec { + let prefix = (sender, Interfix); + self.db + .senderkey_pusher + .stream_prefix(&prefix) + .ignore_err() + .map(|(_, val): (Ignore, &[u8])| serde_json::from_slice(val).expect("Invalid Pusher in db.")) + .collect() + .await + } - #[must_use] - pub fn get_pushkeys(&self, sender: &UserId) -> Box> + '_> { - self.db.get_pushkeys(sender) + pub fn 
get_pushkeys<'a>(&'a self, sender: &'a UserId) -> impl Stream + Send + 'a { + let prefix = (sender, Interfix); + self.db + .senderkey_pusher + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, pushkey): (Ignore, &str)| pushkey) } #[tracing::instrument(skip(self, dest, request))] @@ -161,15 +203,18 @@ impl Service { let power_levels: RoomPowerLevelsEventContent = self .services .state_accessor - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "") + .await + .and_then(|ev| { serde_json::from_str(ev.content.get()) - .map_err(|e| err!(Database("invalid m.room.power_levels event: {e:?}"))) + .map_err(|e| err!(Database(error!("invalid m.room.power_levels event: {e:?}")))) }) - .transpose()? .unwrap_or_default(); - for action in self.get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id)? { + for action in self + .get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id) + .await? + { let n = match action { Action::Notify => true, Action::SetTweak(tweak) => { @@ -197,7 +242,7 @@ impl Service { } #[tracing::instrument(skip(self, user, ruleset, pdu), level = "debug")] - pub fn get_actions<'a>( + pub async fn get_actions<'a>( &self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent, pdu: &Raw, room_id: &RoomId, ) -> Result<&'a [Action]> { @@ -207,21 +252,27 @@ impl Service { notifications: power_levels.notifications.clone(), }; + let room_joined_count = self + .services + .state_cache + .room_joined_count(room_id) + .await + .unwrap_or(1) + .try_into() + .unwrap_or_else(|_| uint!(0)); + + let user_display_name = self + .services + .users + .displayname(user) + .await + .unwrap_or_else(|_| user.localpart().to_owned()); + let ctx = PushConditionRoomCtx { room_id: room_id.to_owned(), - member_count: UInt::try_from( - self.services - .state_cache - .room_joined_count(room_id)? 
- .unwrap_or(1), - ) - .unwrap_or_else(|_| uint!(0)), + member_count: room_joined_count, user_id: user.to_owned(), - user_display_name: self - .services - .users - .displayname(user)? - .unwrap_or_else(|| user.localpart().to_owned()), + user_display_name, power_levels: Some(power_levels), }; @@ -278,9 +329,14 @@ impl Service { notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); } - notifi.sender_display_name = self.services.users.displayname(&event.sender)?; + notifi.sender_display_name = self.services.users.displayname(&event.sender).await.ok(); - notifi.room_name = self.services.state_accessor.get_name(&event.room_id)?; + notifi.room_name = self + .services + .state_accessor + .get_name(&event.room_id) + .await + .ok(); self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) .await?; diff --git a/src/service/resolver/actual.rs b/src/service/resolver/actual.rs index 07d9a0fae..ea4b1100f 100644 --- a/src/service/resolver/actual.rs +++ b/src/service/resolver/actual.rs @@ -193,7 +193,7 @@ impl super::Service { .send() .await; - trace!("response: {:?}", response); + trace!("response: {response:?}"); if let Err(e) = &response { debug!("error: {e:?}"); return Ok(None); @@ -206,7 +206,7 @@ impl super::Service { } let text = response.text().await?; - trace!("response text: {:?}", text); + trace!("response text: {text:?}"); if text.len() >= 12288 { debug_warn!("response contains junk"); return Ok(None); @@ -225,7 +225,7 @@ impl super::Service { return Ok(None); } - debug_info!("{:?} found at {:?}", dest, m_server); + debug_info!("{dest:?} found at {m_server:?}"); Ok(Some(m_server.to_owned())) } diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs deleted file mode 100644 index efd2b5b76..000000000 --- a/src/service/rooms/alias/data.rs +++ /dev/null @@ -1,125 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{api::client::error::ErrorKind, 
OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId}; - -use crate::{globals, Dep}; - -pub(super) struct Data { - alias_userid: Arc, - alias_roomid: Arc, - aliasid_alias: Arc, - services: Services, -} - -struct Services { - globals: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - alias_userid: db["alias_userid"].clone(), - alias_roomid: db["alias_roomid"].clone(), - aliasid_alias: db["aliasid_alias"].clone(), - services: Services { - globals: args.depend::("globals"), - }, - } - } - - pub(super) fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { - // Comes first as we don't want a stuck alias - self.alias_userid - .insert(alias.alias().as_bytes(), user_id.as_bytes())?; - - self.alias_roomid - .insert(alias.alias().as_bytes(), room_id.as_bytes())?; - - let mut aliasid = room_id.as_bytes().to_vec(); - aliasid.push(0xFF); - aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); - self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; - - Ok(()) - } - - pub(super) fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { - if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? { - let mut prefix = room_id.to_vec(); - prefix.push(0xFF); - - for (key, _) in self.aliasid_alias.scan_prefix(prefix) { - self.aliasid_alias.remove(&key)?; - } - - self.alias_roomid.remove(alias.alias().as_bytes())?; - - self.alias_userid.remove(alias.alias().as_bytes())?; - } else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist or is invalid.")); - } - - Ok(()) - } - - pub(super) fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { - self.alias_roomid - .get(alias.alias().as_bytes())? 
- .map(|bytes| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid.")) - }) - .transpose() - } - - pub(super) fn who_created_alias(&self, alias: &RoomAliasId) -> Result> { - self.alias_userid - .get(alias.alias().as_bytes())? - .map(|bytes| { - UserId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("User ID in alias_userid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in alias_roomid is invalid.")) - }) - .transpose() - } - - pub(super) fn local_aliases_for_room<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a + Send> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? - .try_into() - .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias.")) - })) - } - - pub(super) fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { - Box::new( - self.alias_roomid - .iter() - .map(|(room_alias_bytes, room_id_bytes)| { - let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?; - - let room_id = utils::string_from_bytes(&room_id_bytes) - .map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))? 
- .try_into() - .map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?; - - Ok((room_id, room_alias_localpart)) - }), - ) - } -} diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index f2e01ab54..6b81a221a 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -1,19 +1,23 @@ -mod data; mod remote; use std::sync::Arc; -use conduit::{err, Error, Result}; +use conduit::{ + err, + utils::{stream::TryIgnore, ReadyExt}, + Err, Error, Result, +}; +use database::{Deserialized, Ignore, Interfix, Map}; +use futures::{Stream, StreamExt}; use ruma::{ api::client::error::ErrorKind, events::{ room::power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, StateEventType, }, - OwnedRoomAliasId, OwnedRoomId, OwnedServerName, RoomAliasId, RoomId, RoomOrAliasId, UserId, + OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, RoomId, RoomOrAliasId, UserId, }; -use self::data::Data; use crate::{admin, appservice, appservice::RegistrationInfo, globals, rooms, sending, Dep}; pub struct Service { @@ -21,6 +25,12 @@ pub struct Service { services: Services, } +struct Data { + alias_userid: Arc, + alias_roomid: Arc, + aliasid_alias: Arc, +} + struct Services { admin: Dep, appservice: Dep, @@ -32,7 +42,11 @@ struct Services { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + alias_userid: args.db["alias_userid"].clone(), + alias_roomid: args.db["alias_roomid"].clone(), + aliasid_alias: args.db["aliasid_alias"].clone(), + }, services: Services { admin: args.depend::("admin"), appservice: args.depend::("appservice"), @@ -50,25 +64,52 @@ impl Service { #[tracing::instrument(skip(self))] pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { if alias == self.services.globals.admin_alias && user_id != self.services.globals.server_user { - Err(Error::BadRequest( + return Err(Error::BadRequest( 
ErrorKind::forbidden(), "Only the server user can set this alias", - )) - } else { - self.db.set_alias(alias, room_id, user_id) + )); } + + // Comes first as we don't want a stuck alias + self.db + .alias_userid + .insert(alias.alias().as_bytes(), user_id.as_bytes()); + + self.db + .alias_roomid + .insert(alias.alias().as_bytes(), room_id.as_bytes()); + + let mut aliasid = room_id.as_bytes().to_vec(); + aliasid.push(0xFF); + aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); + self.db.aliasid_alias.insert(&aliasid, alias.as_bytes()); + + Ok(()) } #[tracing::instrument(skip(self))] pub async fn remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> Result<()> { - if self.user_can_remove_alias(alias, user_id).await? { - self.db.remove_alias(alias) - } else { - Err(Error::BadRequest( - ErrorKind::forbidden(), - "User is not permitted to remove this alias.", - )) + if !self.user_can_remove_alias(alias, user_id).await? { + return Err!(Request(Forbidden("User is not permitted to remove this alias."))); } + + let alias = alias.alias(); + let Ok(room_id) = self.db.alias_roomid.qry(&alias).await else { + return Err!(Request(NotFound("Alias does not exist or is invalid."))); + }; + + let prefix = (&room_id, Interfix); + self.db + .aliasid_alias + .keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key: &[u8]| self.db.aliasid_alias.remove(&key)) + .await; + + self.db.alias_roomid.remove(alias.as_bytes()); + self.db.alias_userid.remove(alias.as_bytes()); + + Ok(()) } pub async fn resolve(&self, room: &RoomOrAliasId) -> Result { @@ -97,9 +138,9 @@ impl Service { return self.remote_resolve(room_alias, servers).await; } - let room_id: Option = match self.resolve_local_alias(room_alias)? 
{ - Some(r) => Some(r), - None => self.resolve_appservice_alias(room_alias).await?, + let room_id: Option = match self.resolve_local_alias(room_alias).await { + Ok(r) => Some(r), + Err(_) => self.resolve_appservice_alias(room_alias).await?, }; room_id.map_or_else( @@ -109,46 +150,54 @@ impl Service { } #[tracing::instrument(skip(self), level = "debug")] - pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { - self.db.resolve_local_alias(alias) + pub async fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result { + self.db.alias_roomid.qry(alias.alias()).await.deserialized() } #[tracing::instrument(skip(self), level = "debug")] - pub fn local_aliases_for_room<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a + Send> { - self.db.local_aliases_for_room(room_id) + pub fn local_aliases_for_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .aliasid_alias + .stream_prefix(&prefix) + .ignore_err() + .map(|((Ignore, Ignore), alias): ((Ignore, Ignore), &RoomAliasId)| alias) } #[tracing::instrument(skip(self), level = "debug")] - pub fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { - self.db.all_local_aliases() + pub fn all_local_aliases<'a>(&'a self) -> impl Stream + Send + 'a { + self.db + .alias_roomid + .stream() + .ignore_err() + .map(|(alias_localpart, room_id): (&str, &RoomId)| (room_id, alias_localpart)) } async fn user_can_remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> Result { - let Some(room_id) = self.resolve_local_alias(alias)? else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Alias not found.")); - }; + let room_id = self + .resolve_local_alias(alias) + .await + .map_err(|_| err!(Request(NotFound("Alias not found."))))?; let server_user = &self.services.globals.server_user; // The creator of an alias can remove it if self - .db - .who_created_alias(alias)? 
- .is_some_and(|user| user == user_id) + .who_created_alias(alias).await + .is_ok_and(|user| user == user_id) // Server admins can remove any local alias - || self.services.admin.user_is_admin(user_id).await? + || self.services.admin.user_is_admin(user_id).await // Always allow the server service account to remove the alias, since there may not be an admin room || server_user == user_id { Ok(true) // Checking whether the user is able to change canonical aliases of the // room - } else if let Some(event) = - self.services - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")? + } else if let Ok(event) = self + .services + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "") + .await { serde_json::from_str(event.content.get()) .map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels")) @@ -157,10 +206,11 @@ impl Service { }) // If there is no power levels event, only the room creator can change // canonical aliases - } else if let Some(event) = - self.services - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomCreate, "")? + } else if let Ok(event) = self + .services + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomCreate, "") + .await { Ok(event.sender == user_id) } else { @@ -168,6 +218,10 @@ impl Service { } } + async fn who_created_alias(&self, alias: &RoomAliasId) -> Result { + self.db.alias_userid.qry(alias.alias()).await.deserialized() + } + async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result> { use ruma::api::appservice::query::query_room_alias; @@ -185,10 +239,11 @@ impl Service { .await, Ok(Some(_opt_result)) ) { - return Ok(Some( - self.resolve_local_alias(room_alias)? 
- .ok_or_else(|| err!(Request(NotFound("Room does not exist."))))?, - )); + return self + .resolve_local_alias(room_alias) + .await + .map_err(|_| err!(Request(NotFound("Room does not exist.")))) + .map(Some); } } diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index 6e7c78359..3d00374e7 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -24,7 +24,7 @@ impl Data { } } - pub(super) fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { + pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { // Check RAM cache if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { return Ok(Some(Arc::clone(result))); @@ -33,17 +33,14 @@ impl Data { // We only save auth chains for single events in the db if key.len() == 1 { // Check DB cache - let chain = self - .shorteventid_authchain - .get(&key[0].to_be_bytes())? - .map(|chain| { - chain - .chunks_exact(size_of::()) - .map(utils::u64_from_u8) - .collect::>() - }); + let chain = self.shorteventid_authchain.qry(&key[0]).await.map(|chain| { + chain + .chunks_exact(size_of::()) + .map(utils::u64_from_u8) + .collect::>() + }); - if let Some(chain) = chain { + if let Ok(chain) = chain { // Cache in RAM self.auth_chain_cache .lock() @@ -66,7 +63,7 @@ impl Data { .iter() .flat_map(|s| s.to_be_bytes().to_vec()) .collect::>(), - )?; + ); } // Cache in RAM diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index d0bc425fc..7bc239d7b 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -5,7 +5,8 @@ use std::{ sync::Arc, }; -use conduit::{debug, error, trace, validated, warn, Err, Result}; +use conduit::{debug, debug_error, trace, utils::IterStream, validated, warn, Err, Result}; +use futures::{FutureExt, Stream, StreamExt}; use ruma::{EventId, RoomId}; use self::data::Data; @@ -38,7 +39,7 @@ impl crate::Service for Service { 
impl Service { pub async fn event_ids_iter<'a>( &'a self, room_id: &RoomId, starting_events_: Vec>, - ) -> Result> + 'a> { + ) -> Result> + Send + 'a> { let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len()); for starting_event in &starting_events_ { starting_events.push(starting_event); @@ -48,7 +49,13 @@ impl Service { .get_auth_chain(room_id, &starting_events) .await? .into_iter() - .filter_map(move |sid| self.services.short.get_eventid_from_short(sid).ok())) + .stream() + .filter_map(|sid| { + self.services + .short + .get_eventid_from_short(sid) + .map(Result::ok) + })) } #[tracing::instrument(skip_all, name = "auth_chain")] @@ -61,7 +68,8 @@ impl Service { for (i, &short) in self .services .short - .multi_get_or_create_shorteventid(starting_events)? + .multi_get_or_create_shorteventid(starting_events) + .await .iter() .enumerate() { @@ -85,7 +93,7 @@ impl Service { } let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); - if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key)? { + if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key).await? { trace!("Found cache entry for whole chunk"); full_auth_chain.extend(cached.iter().copied()); hits = hits.saturating_add(1); @@ -96,12 +104,12 @@ impl Service { let mut misses2: usize = 0; let mut chunk_cache = Vec::with_capacity(chunk.len()); for (sevent_id, event_id) in chunk { - if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id])? { + if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await? 
{ trace!(?event_id, "Found cache entry for event"); chunk_cache.extend(cached.iter().copied()); hits2 = hits2.saturating_add(1); } else { - let auth_chain = self.get_auth_chain_inner(room_id, event_id)?; + let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?; self.cache_auth_chain(vec![sevent_id], &auth_chain)?; chunk_cache.extend(auth_chain.iter()); misses2 = misses2.saturating_add(1); @@ -143,15 +151,16 @@ impl Service { } #[tracing::instrument(skip(self, room_id))] - fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result> { + async fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result> { let mut todo = vec![Arc::from(event_id)]; let mut found = HashSet::new(); while let Some(event_id) = todo.pop() { trace!(?event_id, "processing auth event"); - match self.services.timeline.get_pdu(&event_id) { - Ok(Some(pdu)) => { + match self.services.timeline.get_pdu(&event_id).await { + Err(e) => debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events"), + Ok(pdu) => { if pdu.room_id != room_id { return Err!(Request(Forbidden( "auth event {event_id:?} for incorrect room {} which is not {}", @@ -160,7 +169,11 @@ impl Service { ))); } for auth_event in &pdu.auth_events { - let sauthevent = self.services.short.get_or_create_shorteventid(auth_event)?; + let sauthevent = self + .services + .short + .get_or_create_shorteventid(auth_event) + .await; if found.insert(sauthevent) { trace!(?event_id, ?auth_event, "adding auth event to processing queue"); @@ -168,20 +181,14 @@ impl Service { } } }, - Ok(None) => { - warn!(?event_id, "Could not find pdu mentioned in auth events"); - }, - Err(error) => { - error!(?event_id, ?error, "Could not load event in auth chain"); - }, } } Ok(found) } - pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { - self.db.get_cached_eventid_authchain(key) + pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { + 
self.db.get_cached_eventid_authchain(key).await } #[tracing::instrument(skip(self), level = "debug")] diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs deleted file mode 100644 index 713ee0576..000000000 --- a/src/service/rooms/directory/data.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, Error, Result}; -use database::{Database, Map}; -use ruma::{OwnedRoomId, RoomId}; - -pub(super) struct Data { - publicroomids: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - publicroomids: db["publicroomids"].clone(), - } - } - - pub(super) fn set_public(&self, room_id: &RoomId) -> Result<()> { - self.publicroomids.insert(room_id.as_bytes(), &[]) - } - - pub(super) fn set_not_public(&self, room_id: &RoomId) -> Result<()> { - self.publicroomids.remove(room_id.as_bytes()) - } - - pub(super) fn is_public_room(&self, room_id: &RoomId) -> Result { - Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) - } - - pub(super) fn public_rooms<'a>(&'a self) -> Box> + 'a> { - Box::new(self.publicroomids.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid.")) - })) - } -} diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index 706e6c2e5..3585205d3 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -1,36 +1,44 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use ruma::{OwnedRoomId, RoomId}; - -use self::data::Data; +use conduit::{implement, utils::stream::TryIgnore, Result}; +use database::{Ignore, Map}; +use futures::{Stream, StreamExt}; +use ruma::RoomId; pub struct Service { db: Data, } +struct Data { + publicroomids: Arc, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - 
db: Data::new(args.db), + db: Data { + publicroomids: args.db["publicroomids"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - #[tracing::instrument(skip(self), level = "debug")] - pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) } +#[implement(Service)] +pub fn set_public(&self, room_id: &RoomId) { self.db.publicroomids.insert(room_id.as_bytes(), &[]); } - #[tracing::instrument(skip(self), level = "debug")] - pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) } +#[implement(Service)] +pub fn set_not_public(&self, room_id: &RoomId) { self.db.publicroomids.remove(room_id.as_bytes()); } - #[tracing::instrument(skip(self), level = "debug")] - pub fn is_public_room(&self, room_id: &RoomId) -> Result { self.db.is_public_room(room_id) } +#[implement(Service)] +pub async fn is_public_room(&self, room_id: &RoomId) -> bool { self.db.publicroomids.qry(room_id).await.is_ok() } - #[tracing::instrument(skip(self), level = "debug")] - pub fn public_rooms(&self) -> impl Iterator> + '_ { self.db.public_rooms() } +#[implement(Service)] +pub fn public_rooms(&self) -> impl Stream + Send { + self.db + .publicroomids + .keys() + .ignore_err() + .map(|(room_id, _): (&RoomId, Ignore)| room_id) } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index bee986deb..07d6e4db9 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -3,17 +3,18 @@ mod parse_incoming_pdu; use std::{ collections::{hash_map, BTreeMap, HashMap, HashSet}, fmt::Write, - pin::Pin, sync::{Arc, RwLock as StdRwLock}, time::Instant, }; use conduit::{ - debug, debug_error, debug_info, err, error, info, pdu, trace, - utils::{math::continue_exponential_backoff_secs, MutexMap}, - warn, Error, PduEvent, Result, + debug, debug_error, debug_info, debug_warn, err, info, pdu, + result::LogErr, + 
trace, + utils::{math::continue_exponential_backoff_secs, IterStream, MutexMap}, + warn, Err, Error, PduEvent, Result, }; -use futures_util::Future; +use futures::{future, future::ready, FutureExt, StreamExt, TryFutureExt}; use ruma::{ api::{ client::error::ErrorKind, @@ -27,7 +28,7 @@ use ruma::{ }, int, serde::Base64, - state_res::{self, RoomVersion, StateMap}, + state_res::{self, EventTypeExt, RoomVersion, StateMap}, uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, RoomVersionId, ServerName, }; @@ -60,14 +61,6 @@ struct Services { type RoomMutexMap = MutexMap; type HandleTimeMap = HashMap; -// We use some AsyncRecursiveType hacks here so we can call async funtion -// recursively. -type AsyncRecursiveType<'a, T> = Pin + 'a + Send>>; -type AsyncRecursiveCanonicalJsonVec<'a> = - AsyncRecursiveType<'a, Vec<(Arc, Option>)>>; -type AsyncRecursiveCanonicalJsonResult<'a> = - AsyncRecursiveType<'a, Result<(Arc, BTreeMap)>>; - impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { @@ -142,17 +135,17 @@ impl Service { pub_key_map: &'a RwLock>>, ) -> Result>> { // 1. Skip the PDU if we already have it as a timeline event - if let Some(pdu_id) = self.services.timeline.get_pdu_id(event_id)? { + if let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await { return Ok(Some(pdu_id.to_vec())); } // 1.1 Check the server is in the room - if !self.services.metadata.exists(room_id)? { + if !self.services.metadata.exists(room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); } // 1.2 Check if the room is disabled - if self.services.metadata.is_disabled(room_id)? 
{ + if self.services.metadata.is_disabled(room_id).await { return Err(Error::BadRequest( ErrorKind::forbidden(), "Federation of this room is currently disabled on this server.", @@ -160,7 +153,7 @@ impl Service { } // 1.3.1 Check room ACL on origin field/server - self.acl_check(origin, room_id)?; + self.acl_check(origin, room_id).await?; // 1.3.2 Check room ACL on sender's server name let sender: OwnedUserId = serde_json::from_value( @@ -172,26 +165,23 @@ impl Service { ) .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "User ID in sender is invalid"))?; - self.acl_check(sender.server_name(), room_id)?; + self.acl_check(sender.server_name(), room_id).await?; // Fetch create event let create_event = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .ok_or_else(|| Error::bad_database("Failed to find create event in db."))?; + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await?; // Procure the room version let room_version_id = Self::get_room_version_id(&create_event)?; - let first_pdu_in_room = self - .services - .timeline - .first_pdu_in_room(room_id)? - .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; + let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?; let (incoming_pdu, val) = self .handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false, pub_key_map) + .boxed() .await?; Self::check_room_id(room_id, &incoming_pdu)?; @@ -235,7 +225,7 @@ impl Service { { Ok(()) => continue, Err(e) => { - warn!("Prev event {} failed: {}", prev_id, e); + warn!("Prev event {prev_id} failed: {e}"); match self .services .globals @@ -287,7 +277,7 @@ impl Service { create_event: &Arc, first_pdu_in_room: &Arc, prev_id: &EventId, ) -> Result<()> { // Check for disabled again because it might have changed - if self.services.metadata.is_disabled(room_id)? 
{ + if self.services.metadata.is_disabled(room_id).await { debug!( "Federaton of room {room_id} is currently disabled on this server. Request by origin {origin} and \ event ID {event_id}" @@ -349,149 +339,153 @@ impl Service { } #[allow(clippy::too_many_arguments)] - fn handle_outlier_pdu<'a>( - &'a self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId, + async fn handle_outlier_pdu<'a>( + &self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId, mut value: BTreeMap, auth_events_known: bool, pub_key_map: &'a RwLock>>, - ) -> AsyncRecursiveCanonicalJsonResult<'a> { - Box::pin(async move { - // 1. Remove unsigned field - value.remove("unsigned"); + ) -> Result<(Arc, BTreeMap)> { + // 1. Remove unsigned field + value.remove("unsigned"); - // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json + // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json - // 2. Check signatures, otherwise drop - // 3. check content hash, redact if doesn't match - let room_version_id = Self::get_room_version_id(create_event)?; + // 2. Check signatures, otherwise drop + // 3. 
check content hash, redact if doesn't match + let room_version_id = Self::get_room_version_id(create_event)?; - let guard = pub_key_map.read().await; - let mut val = match ruma::signatures::verify_event(&guard, &value, &room_version_id) { - Err(e) => { - // Drop - warn!("Dropping bad event {}: {}", event_id, e,); - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Signature verification failed")); - }, - Ok(ruma::signatures::Verified::Signatures) => { - // Redact - debug_info!("Calculated hash does not match (redaction): {event_id}"); - let Ok(obj) = ruma::canonical_json::redact(value, &room_version_id, None) else { - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Redaction failed")); - }; + let guard = pub_key_map.read().await; + let mut val = match ruma::signatures::verify_event(&guard, &value, &room_version_id) { + Err(e) => { + // Drop + warn!("Dropping bad event {event_id}: {e}"); + return Err!(Request(InvalidParam("Signature verification failed"))); + }, + Ok(ruma::signatures::Verified::Signatures) => { + // Redact + debug_info!("Calculated hash does not match (redaction): {event_id}"); + let Ok(obj) = ruma::canonical_json::redact(value, &room_version_id, None) else { + return Err!(Request(InvalidParam("Redaction failed"))); + }; - // Skip the PDU if it is redacted and we already have it as an outlier event - if self.services.timeline.get_pdu_json(event_id)?.is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Event was redacted and we already knew about it", - )); - } + // Skip the PDU if it is redacted and we already have it as an outlier event + if self.services.timeline.get_pdu_json(event_id).await.is_ok() { + return Err!(Request(InvalidParam("Event was redacted and we already knew about it"))); + } - obj - }, - Ok(ruma::signatures::Verified::All) => value, - }; + obj + }, + Ok(ruma::signatures::Verified::All) => value, + }; - drop(guard); + drop(guard); - // Now that we have checked the signature and hashes we can add 
the eventID and - // convert to our PduEvent type - val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); - let incoming_pdu = serde_json::from_value::( - serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), - ) - .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; + // Now that we have checked the signature and hashes we can add the eventID and + // convert to our PduEvent type + val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); + let incoming_pdu = serde_json::from_value::( + serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), + ) + .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; - Self::check_room_id(room_id, &incoming_pdu)?; + Self::check_room_id(room_id, &incoming_pdu)?; - if !auth_events_known { - // 4. fetch any missing auth events doing all checks listed here starting at 1. - // These are not timeline events - // 5. Reject "due to auth events" if can't get all the auth events or some of - // the auth events are also rejected "due to auth events" - // NOTE: Step 5 is not applied anymore because it failed too often - debug!("Fetching auth events"); + if !auth_events_known { + // 4. fetch any missing auth events doing all checks listed here starting at 1. + // These are not timeline events + // 5. Reject "due to auth events" if can't get all the auth events or some of + // the auth events are also rejected "due to auth events" + // NOTE: Step 5 is not applied anymore because it failed too often + debug!("Fetching auth events"); + Box::pin( self.fetch_and_handle_outliers( origin, &incoming_pdu .auth_events .iter() .map(|x| Arc::from(&**x)) - .collect::>(), + .collect::>>(), create_event, room_id, &room_version_id, pub_key_map, - ) - .await; - } + ), + ) + .await; + } - // 6. 
Reject "due to auth events" if the event doesn't pass auth based on the - // auth events - debug!("Checking based on auth events"); - // Build map of auth events - let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len()); - for id in &incoming_pdu.auth_events { - let Some(auth_event) = self.services.timeline.get_pdu(id)? else { - warn!("Could not find auth event {}", id); - continue; - }; + // 6. Reject "due to auth events" if the event doesn't pass auth based on the + // auth events + debug!("Checking based on auth events"); + // Build map of auth events + let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len()); + for id in &incoming_pdu.auth_events { + let Ok(auth_event) = self.services.timeline.get_pdu(id).await else { + warn!("Could not find auth event {id}"); + continue; + }; - Self::check_room_id(room_id, &auth_event)?; - - match auth_events.entry(( - auth_event.kind.to_string().into(), - auth_event - .state_key - .clone() - .expect("all auth events have state keys"), - )) { - hash_map::Entry::Vacant(v) => { - v.insert(auth_event); - }, - hash_map::Entry::Occupied(_) => { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Auth event's type and state_key combination exists multiple times.", - )); - }, - } + Self::check_room_id(room_id, &auth_event)?; + + match auth_events.entry(( + auth_event.kind.to_string().into(), + auth_event + .state_key + .clone() + .expect("all auth events have state keys"), + )) { + hash_map::Entry::Vacant(v) => { + v.insert(auth_event); + }, + hash_map::Entry::Occupied(_) => { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Auth event's type and state_key combination exists multiple times.", + )); + }, } + } - // The original create event must be in the auth events - if !matches!( - auth_events - .get(&(StateEventType::RoomCreate, String::new())) - .map(AsRef::as_ref), - Some(_) | None - ) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Incoming event 
refers to wrong create event.", - )); - } + // The original create event must be in the auth events + if !matches!( + auth_events + .get(&(StateEventType::RoomCreate, String::new())) + .map(AsRef::as_ref), + Some(_) | None + ) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Incoming event refers to wrong create event.", + )); + } - if !state_res::event_auth::auth_check( - &Self::to_room_version(&room_version_id), - &incoming_pdu, - None::, // TODO: third party invite - |k, s| auth_events.get(&(k.to_string().into(), s.to_owned())), - ) - .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed"))? - { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Auth check failed")); - } + let state_fetch = |ty: &'static StateEventType, sk: &str| { + let key = ty.with_state_key(sk); + ready(auth_events.get(&key)) + }; - trace!("Validation successful."); + let auth_check = state_res::event_auth::auth_check( + &Self::to_room_version(&room_version_id), + &incoming_pdu, + None, // TODO: third party invite + state_fetch, + ) + .await + .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; - // 7. Persist the event as an outlier. - self.services - .outlier - .add_pdu_outlier(&incoming_pdu.event_id, &val)?; + if !auth_check { + return Err!(Request(Forbidden("Auth check failed"))); + } + + trace!("Validation successful."); + + // 7. Persist the event as an outlier. 
+ self.services + .outlier + .add_pdu_outlier(&incoming_pdu.event_id, &val); - trace!("Added pdu as outlier."); + trace!("Added pdu as outlier."); - Ok((Arc::new(incoming_pdu), val)) - }) + Ok((Arc::new(incoming_pdu), val)) } pub async fn upgrade_outlier_to_timeline_pdu( @@ -499,16 +493,22 @@ impl Service { origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock>>, ) -> Result>> { // Skip the PDU if we already have it as a timeline event - if let Ok(Some(pduid)) = self.services.timeline.get_pdu_id(&incoming_pdu.event_id) { + if let Ok(pduid) = self + .services + .timeline + .get_pdu_id(&incoming_pdu.event_id) + .await + { return Ok(Some(pduid.to_vec())); } if self .services .pdu_metadata - .is_event_soft_failed(&incoming_pdu.event_id)? + .is_event_soft_failed(&incoming_pdu.event_id) + .await { - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); + return Err!(Request(InvalidParam("Event has been soft failed"))); } debug!("Upgrading to timeline pdu"); @@ -545,57 +545,69 @@ impl Service { debug!("Performing auth check"); // 11. 
Check the auth of the event passes based on the state of the event - let check_result = state_res::event_auth::auth_check( + let state_fetch_state = &state_at_incoming_event; + let state_fetch = |k: &'static StateEventType, s: String| async move { + let shortstatekey = self.services.short.get_shortstatekey(k, &s).await.ok()?; + + let event_id = state_fetch_state.get(&shortstatekey)?; + self.services.timeline.get_pdu(event_id).await.ok() + }; + + let auth_check = state_res::event_auth::auth_check( &room_version, &incoming_pdu, - None::, // TODO: third party invite - |k, s| { - self.services - .short - .get_shortstatekey(&k.to_string().into(), s) - .ok() - .flatten() - .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) - .and_then(|event_id| self.services.timeline.get_pdu(event_id).ok().flatten()) - }, + None, // TODO: third party invite + |k, s| state_fetch(k, s.to_owned()), ) - .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))?; + .await + .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; - if !check_result { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Event has failed auth check with state at the event.", - )); + if !auth_check { + return Err!(Request(Forbidden("Event has failed auth check with state at the event."))); } debug!("Gathering auth events"); - let auth_events = self.services.state.get_auth_events( - room_id, - &incoming_pdu.kind, - &incoming_pdu.sender, - incoming_pdu.state_key.as_deref(), - &incoming_pdu.content, - )?; + let auth_events = self + .services + .state + .get_auth_events( + room_id, + &incoming_pdu.kind, + &incoming_pdu.sender, + incoming_pdu.state_key.as_deref(), + &incoming_pdu.content, + ) + .await?; + + let state_fetch = |k: &'static StateEventType, s: &str| { + let key = k.with_state_key(s); + ready(auth_events.get(&key).cloned()) + }; + + let auth_check = state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + None, // third-party 
invite + state_fetch, + ) + .await + .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; // Soft fail check before doing state res debug!("Performing soft-fail check"); let soft_fail = { use RoomVersionId::*; - !state_res::event_auth::auth_check(&room_version, &incoming_pdu, None::, |k, s| { - auth_events.get(&(k.clone(), s.to_owned())) - }) - .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))? + !auth_check || incoming_pdu.kind == TimelineEventType::RoomRedaction && match room_version_id { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &incoming_pdu.redacts { - !self.services.state_accessor.user_can_redact( - redact_id, - &incoming_pdu.sender, - &incoming_pdu.room_id, - true, - )? + !self + .services + .state_accessor + .user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true) + .await? } else { false } @@ -605,12 +617,11 @@ impl Service { .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; if let Some(redact_id) = &content.redacts { - !self.services.state_accessor.user_can_redact( - redact_id, - &incoming_pdu.sender, - &incoming_pdu.room_id, - true, - )? + !self + .services + .state_accessor + .user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true) + .await? } else { false } @@ -627,28 +638,52 @@ impl Service { // Now we calculate the set of extremities this room has after the incoming // event has been applied. 
We start with the previous extremities (aka leaves) trace!("Calculating extremities"); - let mut extremities = self.services.state.get_forward_extremities(room_id)?; - trace!("Calculated {} extremities", extremities.len()); + let mut extremities: HashSet<_> = self + .services + .state + .get_forward_extremities(room_id) + .map(ToOwned::to_owned) + .collect() + .await; // Remove any forward extremities that are referenced by this incoming event's // prev_events + trace!( + "Calculated {} extremities; checking against {} prev_events", + extremities.len(), + incoming_pdu.prev_events.len() + ); for prev_event in &incoming_pdu.prev_events { - extremities.remove(prev_event); + extremities.remove(&(**prev_event)); } // Only keep those extremities were not referenced yet - extremities.retain(|id| !matches!(self.services.pdu_metadata.is_event_referenced(room_id, id), Ok(true))); + let mut retained = HashSet::new(); + for id in &extremities { + if !self + .services + .pdu_metadata + .is_event_referenced(room_id, id) + .await + { + retained.insert(id.clone()); + } + } + + extremities.retain(|id| retained.contains(id)); debug!("Retained {} extremities. Compressing state", extremities.len()); - let state_ids_compressed = Arc::new( - state_at_incoming_event - .iter() - .map(|(shortstatekey, id)| { - self.services - .state_compressor - .compress_state_event(*shortstatekey, id) - }) - .collect::>()?, - ); + + let mut state_ids_compressed = HashSet::new(); + for (shortstatekey, id) in &state_at_incoming_event { + state_ids_compressed.insert( + self.services + .state_compressor + .compress_state_event(*shortstatekey, id) + .await, + ); + } + + let state_ids_compressed = Arc::new(state_ids_compressed); if incoming_pdu.state_key.is_some() { debug!("Event is a state-event. 
Deriving new room state"); @@ -659,9 +694,11 @@ impl Service { let shortstatekey = self .services .short - .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key) + .await; - state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); + let event_id = &incoming_pdu.event_id; + state_after.insert(shortstatekey, event_id.clone()); } let new_room_state = self @@ -673,7 +710,8 @@ impl Service { let (sstatehash, new, removed) = self .services .state_compressor - .save_state(room_id, new_room_state)?; + .save_state(room_id, new_room_state) + .await?; self.services .state @@ -698,16 +736,16 @@ impl Service { .await?; // Soft fail, we keep the event as an outlier but don't add it to the timeline - warn!("Event was soft failed: {:?}", incoming_pdu); + warn!("Event was soft failed: {incoming_pdu:?}"); self.services .pdu_metadata - .mark_event_soft_failed(&incoming_pdu.event_id)?; + .mark_event_soft_failed(&incoming_pdu.event_id); return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); } trace!("Appending pdu to timeline"); - extremities.insert(incoming_pdu.event_id.clone()); + extremities.insert(incoming_pdu.event_id.clone().into()); // Now that the event has passed all auth it is added into the timeline. // We use the `state_at_event` instead of `state_after` so we accurately @@ -718,7 +756,7 @@ impl Service { .append_incoming_pdu( &incoming_pdu, val, - extremities.iter().map(|e| (**e).to_owned()).collect(), + extremities.into_iter().collect(), state_ids_compressed, soft_fail, &state_lock, @@ -742,8 +780,9 @@ impl Service { let current_sstatehash = self .services .state - .get_room_shortstatehash(room_id)? 
- .expect("every room has state"); + .get_room_shortstatehash(room_id) + .await + .map_err(|e| err!(Database(error!("No state for {room_id:?}: {e:?}"))))?; let current_state_ids = self .services @@ -752,7 +791,6 @@ impl Service { .await?; let fork_states = [current_state_ids, incoming_state]; - let mut auth_chain_sets = Vec::with_capacity(fork_states.len()); for state in &fork_states { auth_chain_sets.push( @@ -760,62 +798,59 @@ impl Service { .auth_chain .event_ids_iter(room_id, state.iter().map(|(_, id)| id.clone()).collect()) .await? - .collect(), + .collect::>>() + .await, ); } debug!("Loading fork states"); - let fork_states: Vec<_> = fork_states + let fork_states: Vec>> = fork_states .into_iter() - .map(|map| { - map.into_iter() + .stream() + .then(|fork_state| { + fork_state + .into_iter() + .stream() .filter_map(|(k, id)| { self.services .short .get_statekey_from_short(k) - .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id)) - .ok() + .map_ok_or_else(|_| None, move |(ty, st_key)| Some(((ty, st_key), id))) }) - .collect::>() + .collect() }) - .collect(); - - let lock = self.services.globals.stateres_mutex.lock(); + .collect() + .boxed() + .await; debug!("Resolving state"); - let state_resolve = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = self.services.timeline.get_pdu(id); - if let Err(e) = &res { - error!("Failed to fetch event: {}", e); - } - res.ok().flatten() - }); + let lock = self.services.globals.stateres_mutex.lock(); - let state = match state_resolve { - Ok(new_state) => new_state, - Err(e) => { - error!("State resolution failed: {}", e); - return Err(Error::bad_database( - "State resolution failed, either an event could not be found or deserialization", - )); - }, - }; + let event_fetch = |event_id| self.event_fetch(event_id); + let event_exists = |event_id| self.event_exists(event_id); + let state = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists) + 
.await + .map_err(|e| err!(Database(error!("State resolution failed: {e:?}"))))?; drop(lock); debug!("State resolution done. Compressing state"); - let new_room_state = state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; - self.services - .state_compressor - .compress_state_event(shortstatekey, &event_id) - }) - .collect::>()?; + let mut new_room_state = HashSet::new(); + for ((event_type, state_key), event_id) in state { + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key) + .await; + + let compressed = self + .services + .state_compressor + .compress_state_event(shortstatekey, &event_id) + .await; + + new_room_state.insert(compressed); + } Ok(Arc::new(new_room_state)) } @@ -827,46 +862,47 @@ impl Service { &self, incoming_pdu: &Arc, ) -> Result>>> { let prev_event = &*incoming_pdu.prev_events[0]; - let prev_event_sstatehash = self + let Ok(prev_event_sstatehash) = self .services .state_accessor - .pdu_shortstatehash(prev_event)?; - - let state = if let Some(shortstatehash) = prev_event_sstatehash { - Some( - self.services - .state_accessor - .state_full_ids(shortstatehash) - .await, - ) - } else { - None + .pdu_shortstatehash(prev_event) + .await + else { + return Ok(None); }; - if let Some(Ok(mut state)) = state { - debug!("Using cached state"); - let prev_pdu = self - .services - .timeline - .get_pdu(prev_event) - .ok() - .flatten() - .ok_or_else(|| Error::bad_database("Could not find prev event, but we know the state."))?; + let Ok(mut state) = self + .services + .state_accessor + .state_full_ids(prev_event_sstatehash) + .await + .log_err() + else { + return Ok(None); + }; - if let Some(state_key) = &prev_pdu.state_key { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key)?; + 
debug!("Using cached state"); + let prev_pdu = self + .services + .timeline + .get_pdu(prev_event) + .await + .map_err(|e| err!(Database("Could not find prev event, but we know the state: {e:?}")))?; - state.insert(shortstatekey, Arc::from(prev_event)); - // Now it's the state after the pdu - } + if let Some(state_key) = &prev_pdu.state_key { + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key) + .await; - return Ok(Some(state)); + state.insert(shortstatekey, Arc::from(prev_event)); + // Now it's the state after the pdu } - Ok(None) + debug_assert!(!state.is_empty(), "should be returning None for empty HashMap result"); + + Ok(Some(state)) } #[tracing::instrument(skip_all, name = "state")] @@ -878,15 +914,16 @@ impl Service { let mut okay = true; for prev_eventid in &incoming_pdu.prev_events { - let Ok(Some(prev_event)) = self.services.timeline.get_pdu(prev_eventid) else { + let Ok(prev_event) = self.services.timeline.get_pdu(prev_eventid).await else { okay = false; break; }; - let Ok(Some(sstatehash)) = self + let Ok(sstatehash) = self .services .state_accessor .pdu_shortstatehash(prev_eventid) + .await else { okay = false; break; @@ -901,20 +938,25 @@ impl Service { let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); - for (sstatehash, prev_event) in extremity_sstatehashes { - let mut leaf_state: HashMap<_, _> = self + let Ok(mut leaf_state) = self .services .state_accessor .state_full_ids(sstatehash) - .await?; + .await + else { + continue; + }; if let Some(state_key) = &prev_event.state_key { let shortstatekey = self .services .short - .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key)?; - leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); + .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key) + .await; + + let event_id = 
&prev_event.event_id; + leaf_state.insert(shortstatekey, event_id.clone()); // Now it's the state after the pdu } @@ -922,13 +964,18 @@ impl Service { let mut starting_events = Vec::with_capacity(leaf_state.len()); for (k, id) in leaf_state { - if let Ok((ty, st_key)) = self.services.short.get_statekey_from_short(k) { + if let Ok((ty, st_key)) = self + .services + .short + .get_statekey_from_short(k) + .await + .log_err() + { // FIXME: Undo .to_string().into() when StateMap // is updated to use StateEventType state.insert((ty.to_string().into(), st_key), id.clone()); - } else { - warn!("Failed to get_statekey_from_short."); } + starting_events.push(id); } @@ -937,43 +984,40 @@ impl Service { .auth_chain .event_ids_iter(room_id, starting_events) .await? - .collect(), + .collect() + .await, ); fork_states.push(state); } let lock = self.services.globals.stateres_mutex.lock(); - let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = self.services.timeline.get_pdu(id); - if let Err(e) = &res { - error!("Failed to fetch event: {}", e); - } - res.ok().flatten() - }); + + let event_fetch = |event_id| self.event_fetch(event_id); + let event_exists = |event_id| self.event_exists(event_id); + let result = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists) + .await + .map_err(|e| err!(Database(warn!(?e, "State resolution on prev events failed.")))); + drop(lock); - Ok(match result { - Ok(new_state) => Some( - new_state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; - Ok((shortstatekey, event_id)) - }) - .collect::>()?, - ), - Err(e) => { - warn!( - "State resolution on prev events failed, either an event could not be found or deserialization: {}", - e - ); - None - }, - }) + let Ok(new_state) = result else { + return Ok(None); + }; + + new_state + 
.iter() + .stream() + .then(|((event_type, state_key), event_id)| { + self.services + .short + .get_or_create_shortstatekey(event_type, state_key) + .map(move |shortstatekey| (shortstatekey, event_id.clone())) + }) + .collect() + .map(Some) + .map(Ok) + .await } /// Call /state_ids to find out what the state at this pdu is. We trust the @@ -985,7 +1029,7 @@ impl Service { pub_key_map: &RwLock>>, event_id: &EventId, ) -> Result>>> { debug!("Fetching state ids"); - match self + let res = self .services .sending .send_federation_request( @@ -996,61 +1040,57 @@ impl Service { }, ) .await - { - Ok(res) => { - debug!("Fetching state events"); - let collect = res - .pdu_ids - .iter() - .map(|x| Arc::from(&**x)) - .collect::>(); - - let state_vec = self - .fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id, pub_key_map) - .await; - - let mut state: HashMap<_, Arc> = HashMap::with_capacity(state_vec.len()); - for (pdu, _) in state_vec { - let state_key = pdu - .state_key - .clone() - .ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?; + .inspect_err(|e| warn!("Fetching state for event failed: {e}"))?; + + debug!("Fetching state events"); + let collect = res + .pdu_ids + .iter() + .map(|x| Arc::from(&**x)) + .collect::>(); + + let state_vec = self + .fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id, pub_key_map) + .boxed() + .await; - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key)?; + let mut state: HashMap<_, Arc> = HashMap::with_capacity(state_vec.len()); + for (pdu, _) in state_vec { + let state_key = pdu + .state_key + .clone() + .ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?; - match state.entry(shortstatekey) { - hash_map::Entry::Vacant(v) => { - v.insert(Arc::from(&*pdu.event_id)); - }, - hash_map::Entry::Occupied(_) => { - return Err(Error::bad_database( - "State event's type 
and state_key combination exists multiple times.", - )) - }, - } - } + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key) + .await; - // The original create event must still be in the state - let create_shortstatekey = self - .services - .short - .get_shortstatekey(&StateEventType::RoomCreate, "")? - .expect("Room exists"); + match state.entry(shortstatekey) { + hash_map::Entry::Vacant(v) => { + v.insert(Arc::from(&*pdu.event_id)); + }, + hash_map::Entry::Occupied(_) => { + return Err(Error::bad_database( + "State event's type and state_key combination exists multiple times.", + )) + }, + } + } - if state.get(&create_shortstatekey).map(AsRef::as_ref) != Some(&create_event.event_id) { - return Err(Error::bad_database("Incoming event refers to wrong create event.")); - } + // The original create event must still be in the state + let create_shortstatekey = self + .services + .short + .get_shortstatekey(&StateEventType::RoomCreate, "") + .await?; - Ok(Some(state)) - }, - Err(e) => { - warn!("Fetching state for event failed: {}", e); - Err(e) - }, + if state.get(&create_shortstatekey).map(AsRef::as_ref) != Some(&create_event.event_id) { + return Err!(Database("Incoming event refers to wrong create event.")); } + + Ok(Some(state)) } /// Find the event and auth it. Once the event is validated (steps 1 - 8) @@ -1062,191 +1102,196 @@ impl Service { /// b. Look at outlier pdu tree /// c. Ask origin server over federation /// d. TODO: Ask other servers over federation? 
- pub fn fetch_and_handle_outliers<'a>( - &'a self, origin: &'a ServerName, events: &'a [Arc], create_event: &'a PduEvent, room_id: &'a RoomId, + pub async fn fetch_and_handle_outliers<'a>( + &self, origin: &'a ServerName, events: &'a [Arc], create_event: &'a PduEvent, room_id: &'a RoomId, room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock>>, - ) -> AsyncRecursiveCanonicalJsonVec<'a> { - Box::pin(async move { - let back_off = |id| async { - match self + ) -> Vec<(Arc, Option>)> { + let back_off = |id| match self + .services + .globals + .bad_event_ratelimiter + .write() + .expect("locked") + .entry(id) + { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)), + }; + + let mut events_with_auth_events = Vec::with_capacity(events.len()); + for id in events { + // a. Look in the main timeline (pduid_pdu tree) + // b. Look at outlier pdu tree + // (get_pdu_json checks both) + if let Ok(local_pdu) = self.services.timeline.get_pdu(id).await { + trace!("Found {id} in db"); + events_with_auth_events.push((id, Some(local_pdu), vec![])); + continue; + } + + // c. Ask origin server over federation + // We also handle its auth chain here so we don't get a stack overflow in + // handle_outlier_pdu. 
+ let mut todo_auth_events = vec![Arc::clone(id)]; + let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len()); + let mut events_all = HashSet::with_capacity(todo_auth_events.len()); + let mut i: u64 = 0; + while let Some(next_id) = todo_auth_events.pop() { + if let Some((time, tries)) = self .services .globals .bad_event_ratelimiter - .write() + .read() .expect("locked") - .entry(id) + .get(&*next_id) { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - }, - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)), + // Exponential backoff + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { + info!("Backing off from {next_id}"); + continue; + } } - }; - let mut events_with_auth_events = Vec::with_capacity(events.len()); - for id in events { - // a. Look in the main timeline (pduid_pdu tree) - // b. Look at outlier pdu tree - // (get_pdu_json checks both) - if let Ok(Some(local_pdu)) = self.services.timeline.get_pdu(id) { - trace!("Found {} in db", id); - events_with_auth_events.push((id, Some(local_pdu), vec![])); + if events_all.contains(&next_id) { continue; } - // c. Ask origin server over federation - // We also handle its auth chain here so we don't get a stack overflow in - // handle_outlier_pdu. 
- let mut todo_auth_events = vec![Arc::clone(id)]; - let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len()); - let mut events_all = HashSet::with_capacity(todo_auth_events.len()); - let mut i: u64 = 0; - while let Some(next_id) = todo_auth_events.pop() { - if let Some((time, tries)) = self - .services - .globals - .bad_event_ratelimiter - .read() - .expect("locked") - .get(&*next_id) - { - // Exponential backoff - const MIN_DURATION: u64 = 5 * 60; - const MAX_DURATION: u64 = 60 * 60 * 24; - if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { - info!("Backing off from {next_id}"); - continue; - } - } - - if events_all.contains(&next_id) { - continue; - } - - i = i.saturating_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } + i = i.saturating_add(1); + if i % 100 == 0 { + tokio::task::yield_now().await; + } - if let Ok(Some(_)) = self.services.timeline.get_pdu(&next_id) { - trace!("Found {} in db", next_id); - continue; - } + if self.services.timeline.get_pdu(&next_id).await.is_ok() { + trace!("Found {next_id} in db"); + continue; + } - debug!("Fetching {} over federation.", next_id); - match self - .services - .sending - .send_federation_request( - origin, - get_event::v1::Request { - event_id: (*next_id).to_owned(), - }, - ) - .await - { - Ok(res) => { - debug!("Got {} over federation", next_id); - let Ok((calculated_event_id, value)) = - pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) - else { - back_off((*next_id).to_owned()).await; - continue; - }; - - if calculated_event_id != *next_id { - warn!( - "Server didn't return event id we requested: requested: {}, we got {}. 
Event: {:?}", - next_id, calculated_event_id, &res.pdu - ); - } + debug!("Fetching {next_id} over federation."); + match self + .services + .sending + .send_federation_request( + origin, + get_event::v1::Request { + event_id: (*next_id).to_owned(), + }, + ) + .await + { + Ok(res) => { + debug!("Got {next_id} over federation"); + let Ok((calculated_event_id, value)) = + pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) + else { + back_off((*next_id).to_owned()); + continue; + }; + + if calculated_event_id != *next_id { + warn!( + "Server didn't return event id we requested: requested: {next_id}, we got \ + {calculated_event_id}. Event: {:?}", + &res.pdu + ); + } - if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array()) { - for auth_event in auth_events { - if let Ok(auth_event) = serde_json::from_value(auth_event.clone().into()) { - let a: Arc = auth_event; - todo_auth_events.push(a); - } else { - warn!("Auth event id is not valid"); - } + if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array()) { + for auth_event in auth_events { + if let Ok(auth_event) = serde_json::from_value(auth_event.clone().into()) { + let a: Arc = auth_event; + todo_auth_events.push(a); + } else { + warn!("Auth event id is not valid"); } - } else { - warn!("Auth event list invalid"); } + } else { + warn!("Auth event list invalid"); + } - events_in_reverse_order.push((next_id.clone(), value)); - events_all.insert(next_id); - }, - Err(e) => { - debug_error!("Failed to fetch event {next_id}: {e}"); - back_off((*next_id).to_owned()).await; - }, - } + events_in_reverse_order.push((next_id.clone(), value)); + events_all.insert(next_id); + }, + Err(e) => { + debug_error!("Failed to fetch event {next_id}: {e}"); + back_off((*next_id).to_owned()); + }, } - events_with_auth_events.push((id, None, events_in_reverse_order)); } + events_with_auth_events.push((id, None, events_in_reverse_order)); + } - // We go through all the signatures we see on the 
PDUs and their unresolved - // dependencies and fetch the corresponding signing keys - self.services - .server_keys - .fetch_required_signing_keys( - events_with_auth_events - .iter() - .flat_map(|(_id, _local_pdu, events)| events) - .map(|(_event_id, event)| event), + // We go through all the signatures we see on the PDUs and their unresolved + // dependencies and fetch the corresponding signing keys + self.services + .server_keys + .fetch_required_signing_keys( + events_with_auth_events + .iter() + .flat_map(|(_id, _local_pdu, events)| events) + .map(|(_event_id, event)| event), + pub_key_map, + ) + .await + .unwrap_or_else(|e| { + warn!("Could not fetch all signatures for PDUs from {origin}: {e:?}"); + }); + + let mut pdus = Vec::with_capacity(events_with_auth_events.len()); + for (id, local_pdu, events_in_reverse_order) in events_with_auth_events { + // a. Look in the main timeline (pduid_pdu tree) + // b. Look at outlier pdu tree + // (get_pdu_json checks both) + if let Some(local_pdu) = local_pdu { + trace!("Found {id} in db"); + pdus.push((local_pdu.clone(), None)); + } + + for (next_id, value) in events_in_reverse_order.into_iter().rev() { + if let Some((time, tries)) = self + .services + .globals + .bad_event_ratelimiter + .read() + .expect("locked") + .get(&*next_id) + { + // Exponential backoff + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { + debug!("Backing off from {next_id}"); + continue; + } + } + + match Box::pin(self.handle_outlier_pdu( + origin, + create_event, + &next_id, + room_id, + value.clone(), + true, pub_key_map, - ) + )) .await - .unwrap_or_else(|e| { - warn!("Could not fetch all signatures for PDUs from {}: {:?}", origin, e); - }); - - let mut pdus = Vec::with_capacity(events_with_auth_events.len()); - for (id, local_pdu, events_in_reverse_order) in events_with_auth_events { - // a. 
Look in the main timeline (pduid_pdu tree) - // b. Look at outlier pdu tree - // (get_pdu_json checks both) - if let Some(local_pdu) = local_pdu { - trace!("Found {} in db", id); - pdus.push((local_pdu, None)); - } - for (next_id, value) in events_in_reverse_order.iter().rev() { - if let Some((time, tries)) = self - .services - .globals - .bad_event_ratelimiter - .read() - .expect("locked") - .get(&**next_id) - { - // Exponential backoff - const MIN_DURATION: u64 = 5 * 60; - const MAX_DURATION: u64 = 60 * 60 * 24; - if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { - debug!("Backing off from {next_id}"); - continue; + { + Ok((pdu, json)) => { + if next_id == *id { + pdus.push((pdu, Some(json))); } - } - - match self - .handle_outlier_pdu(origin, create_event, next_id, room_id, value.clone(), true, pub_key_map) - .await - { - Ok((pdu, json)) => { - if next_id == id { - pdus.push((pdu, Some(json))); - } - }, - Err(e) => { - warn!("Authentication of event {} failed: {:?}", next_id, e); - back_off((**next_id).to_owned()).await; - }, - } + }, + Err(e) => { + warn!("Authentication of event {next_id} failed: {e:?}"); + back_off(next_id.into()); + }, } } - pdus - }) + } + pdus } #[allow(clippy::type_complexity)] @@ -1262,16 +1307,12 @@ impl Service { let mut eventid_info = HashMap::new(); let mut todo_outlier_stack: Vec> = initial_set; - let first_pdu_in_room = self - .services - .timeline - .first_pdu_in_room(room_id)? 
- .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; + let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?; let mut amount = 0; while let Some(prev_event_id) = todo_outlier_stack.pop() { - if let Some((pdu, json_opt)) = self + if let Some((pdu, mut json_opt)) = self .fetch_and_handle_outliers( origin, &[prev_event_id.clone()], @@ -1280,28 +1321,29 @@ impl Service { room_version_id, pub_key_map, ) + .boxed() .await .pop() { Self::check_room_id(room_id, &pdu)?; - if amount > self.services.globals.max_fetch_prev_events() { - // Max limit reached - debug!( - "Max prev event limit reached! Limit: {}", - self.services.globals.max_fetch_prev_events() - ); + let limit = self.services.globals.max_fetch_prev_events(); + if amount > limit { + debug_warn!("Max prev event limit reached! Limit: {limit}"); graph.insert(prev_event_id.clone(), HashSet::new()); continue; } - if let Some(json) = json_opt.or_else(|| { - self.services + if json_opt.is_none() { + json_opt = self + .services .outlier .get_outlier_pdu_json(&prev_event_id) - .ok() - .flatten() - }) { + .await + .ok(); + } + + if let Some(json) = json_opt { if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { amount = amount.saturating_add(1); for prev_prev in &pdu.prev_events { @@ -1327,56 +1369,42 @@ impl Service { } } - let sorted = state_res::lexicographical_topological_sort(&graph, |event_id| { + let event_fetch = |event_id| { + let origin_server_ts = eventid_info + .get(&event_id) + .cloned() + .map_or_else(|| uint!(0), |info| info.0.origin_server_ts); + // This return value is the key used for sorting events, // events are then sorted by power level, time, // and lexically by event_id. 
- Ok(( - int!(0), - MilliSecondsSinceUnixEpoch( - eventid_info - .get(event_id) - .map_or_else(|| uint!(0), |info| info.0.origin_server_ts), - ), - )) - }) - .map_err(|e| { - error!("Error sorting prev events: {e}"); - Error::bad_database("Error sorting prev events") - })?; + future::ok((int!(0), MilliSecondsSinceUnixEpoch(origin_server_ts))) + }; + + let sorted = state_res::lexicographical_topological_sort(&graph, &event_fetch) + .await + .map_err(|e| err!(Database(error!("Error sorting prev events: {e}"))))?; Ok((sorted, eventid_info)) } /// Returns Ok if the acl allows the server #[tracing::instrument(skip_all)] - pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { - let acl_event = if let Some(acl) = - self.services - .state_accessor - .room_state_get(room_id, &StateEventType::RoomServerAcl, "")? - { - trace!("ACL event found: {acl:?}"); - acl - } else { - trace!("No ACL event found"); + pub async fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { + let Ok(acl_event_content) = self + .services + .state_accessor + .room_state_get_content(room_id, &StateEventType::RoomServerAcl, "") + .await + .map(|c: RoomServerAclEventContent| c) + .inspect(|acl| trace!("ACL content found: {acl:?}")) + .inspect_err(|e| trace!("No ACL content found: {e:?}")) + else { return Ok(()); }; - let acl_event_content: RoomServerAclEventContent = match serde_json::from_str(acl_event.content.get()) { - Ok(content) => { - trace!("Found ACL event contents: {content:?}"); - content - }, - Err(e) => { - warn!("Invalid ACL event: {e}"); - return Ok(()); - }, - }; - if acl_event_content.allow.is_empty() { warn!("Ignoring broken ACL event (allow key is empty)"); - // Ignore broken acl events return Ok(()); } @@ -1384,16 +1412,18 @@ impl Service { trace!("server {server_name} is allowed by ACL"); Ok(()) } else { - debug!("Server {} was denied by room ACL in {}", server_name, room_id); - Err(Error::BadRequest(ErrorKind::forbidden(), 
"Server was denied by room ACL")) + debug!("Server {server_name} was denied by room ACL in {room_id}"); + Err!(Request(Forbidden("Server was denied by room ACL"))) } } fn check_room_id(room_id: &RoomId, pdu: &PduEvent) -> Result<()> { if pdu.room_id != room_id { - warn!("Found event from room {} in room {}", pdu.room_id, room_id); - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has wrong room id")); + return Err!(Request(InvalidParam( + warn!(pdu_event_id = ?pdu.event_id, pdu_room_id = ?pdu.room_id, ?room_id, "Found event from room in room") + ))); } + Ok(()) } @@ -1408,4 +1438,10 @@ impl Service { fn to_room_version(room_version_id: &RoomVersionId) -> RoomVersion { RoomVersion::new(room_version_id).expect("room version is supported") } + + async fn event_exists(&self, event_id: Arc) -> bool { self.services.timeline.pdu_exists(&event_id).await } + + async fn event_fetch(&self, event_id: Arc) -> Option> { + self.services.timeline.get_pdu(&event_id).await.ok() + } } diff --git a/src/service/rooms/event_handler/parse_incoming_pdu.rs b/src/service/rooms/event_handler/parse_incoming_pdu.rs index a7ffe1930..2de3e28ef 100644 --- a/src/service/rooms/event_handler/parse_incoming_pdu.rs +++ b/src/service/rooms/event_handler/parse_incoming_pdu.rs @@ -3,7 +3,9 @@ use ruma::{CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId}; use serde_json::value::RawValue as RawJsonValue; impl super::Service { - pub fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { + pub async fn parse_incoming_pdu( + &self, pdu: &RawJsonValue, + ) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { debug_warn!("Error parsing incoming event {pdu:#?}"); err!(BadServerResponse("Error parsing incoming event {e:?}")) @@ -14,7 +16,7 @@ impl super::Service { .and_then(|id| RoomId::parse(id.as_str()?).ok()) .ok_or(err!(Request(InvalidParam("Invalid 
room id in pdu"))))?; - let Ok(room_version_id) = self.services.state.get_room_version(&room_id) else { + let Ok(room_version_id) = self.services.state.get_room_version(&room_id).await else { return Err!("Server is not in room {room_id}"); }; diff --git a/src/service/rooms/lazy_loading/data.rs b/src/service/rooms/lazy_loading/data.rs deleted file mode 100644 index 073d45f56..000000000 --- a/src/service/rooms/lazy_loading/data.rs +++ /dev/null @@ -1,65 +0,0 @@ -use std::sync::Arc; - -use conduit::Result; -use database::{Database, Map}; -use ruma::{DeviceId, RoomId, UserId}; - -pub(super) struct Data { - lazyloadedids: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - lazyloadedids: db["lazyloadedids"].clone(), - } - } - - pub(super) fn lazy_load_was_sent_before( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, - ) -> Result { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(ll_user.as_bytes()); - Ok(self.lazyloadedids.get(&key)?.is_some()) - } - - pub(super) fn lazy_load_confirm_delivery( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, - confirmed_user_ids: &mut dyn Iterator, - ) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xFF); - - for ll_id in confirmed_user_ids { - let mut key = prefix.clone(); - key.extend_from_slice(ll_id.as_bytes()); - self.lazyloadedids.insert(&key, &[])?; - } - - Ok(()) - } - - pub(super) fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - 
prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xFF); - - for (key, _) in self.lazyloadedids.scan_prefix(prefix) { - self.lazyloadedids.remove(&key)?; - } - - Ok(()) - } -} diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index 0a9d4cf29..e0816d3f3 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -1,21 +1,26 @@ -mod data; - use std::{ collections::{HashMap, HashSet}, fmt::Write, sync::{Arc, Mutex}, }; -use conduit::{PduCount, Result}; +use conduit::{ + implement, + utils::{stream::TryIgnore, ReadyExt}, + PduCount, Result, +}; +use database::{Interfix, Map}; use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; -use self::data::Data; - pub struct Service { - pub lazy_load_waiting: Mutex, + lazy_load_waiting: Mutex, db: Data, } +struct Data { + lazyloadedids: Arc, +} + type LazyLoadWaiting = HashMap; type LazyLoadWaitingKey = (OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount); type LazyLoadWaitingVal = HashSet; @@ -23,8 +28,10 @@ type LazyLoadWaitingVal = HashSet; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - lazy_load_waiting: Mutex::new(HashMap::new()), - db: Data::new(args.db), + lazy_load_waiting: LazyLoadWaiting::new().into(), + db: Data { + lazyloadedids: args.db["lazyloadedids"].clone(), + }, })) } @@ -40,47 +47,60 @@ impl crate::Service for Service { fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - #[tracing::instrument(skip(self), level = "debug")] - pub fn lazy_load_was_sent_before( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, - ) -> Result { - self.db - .lazy_load_was_sent_before(user_id, device_id, room_id, ll_user) - } +#[implement(Service)] +#[tracing::instrument(skip(self), level = "debug")] +#[inline] +pub async fn lazy_load_was_sent_before( + &self, user_id: &UserId, device_id: 
&DeviceId, room_id: &RoomId, ll_user: &UserId, +) -> bool { + let key = (user_id, device_id, room_id, ll_user); + self.db.lazyloadedids.qry(&key).await.is_ok() +} - #[tracing::instrument(skip(self), level = "debug")] - pub async fn lazy_load_mark_sent( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet, - count: PduCount, - ) { - self.lazy_load_waiting - .lock() - .expect("locked") - .insert((user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count), lazy_load); - } +#[implement(Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn lazy_load_mark_sent( + &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet, count: PduCount, +) { + let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count); - #[tracing::instrument(skip(self), level = "debug")] - pub async fn lazy_load_confirm_delivery( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount, - ) -> Result<()> { - if let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&( - user_id.to_owned(), - device_id.to_owned(), - room_id.to_owned(), - since, - )) { - self.db - .lazy_load_confirm_delivery(user_id, device_id, room_id, &mut user_ids.iter().map(|u| &**u))?; - } else { - // Ignore - } + self.lazy_load_waiting + .lock() + .expect("locked") + .insert(key, lazy_load); +} - Ok(()) - } +#[implement(Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn lazy_load_confirm_delivery(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount) { + let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), since); + + let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&key) else { + return; + }; - #[tracing::instrument(skip(self), level = "debug")] - pub fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { - self.db.lazy_load_reset(user_id, 
device_id, room_id) + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + prefix.extend_from_slice(room_id.as_bytes()); + prefix.push(0xFF); + + for ll_id in &user_ids { + let mut key = prefix.clone(); + key.extend_from_slice(ll_id.as_bytes()); + self.db.lazyloadedids.insert(&key, &[]); } } + +#[implement(Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub async fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) { + let prefix = (user_id, device_id, room_id, Interfix); + self.db + .lazyloadedids + .keys_raw_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.db.lazyloadedids.remove(key)) + .await; +} diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs deleted file mode 100644 index efe681b1b..000000000 --- a/src/service/rooms/metadata/data.rs +++ /dev/null @@ -1,110 +0,0 @@ -use std::sync::Arc; - -use conduit::{error, utils, Error, Result}; -use database::Map; -use ruma::{OwnedRoomId, RoomId}; - -use crate::{rooms, Dep}; - -pub(super) struct Data { - disabledroomids: Arc, - bannedroomids: Arc, - roomid_shortroomid: Arc, - pduid_pdu: Arc, - services: Services, -} - -struct Services { - short: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - disabledroomids: db["disabledroomids"].clone(), - bannedroomids: db["bannedroomids"].clone(), - roomid_shortroomid: db["roomid_shortroomid"].clone(), - pduid_pdu: db["pduid_pdu"].clone(), - services: Services { - short: args.depend::("rooms::short"), - }, - } - } - - pub(super) fn exists(&self, room_id: &RoomId) -> Result { - let prefix = match self.services.short.get_shortroomid(room_id)? { - Some(b) => b.to_be_bytes().to_vec(), - None => return Ok(false), - }; - - // Look for PDUs in that room. 
- Ok(self - .pduid_pdu - .iter_from(&prefix, false) - .next() - .filter(|(k, _)| k.starts_with(&prefix)) - .is_some()) - } - - pub(super) fn iter_ids<'a>(&'a self) -> Box> + 'a> { - Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid.")) - })) - } - - #[inline] - pub(super) fn is_disabled(&self, room_id: &RoomId) -> Result { - Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) - } - - #[inline] - pub(super) fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { - if disabled { - self.disabledroomids.insert(room_id.as_bytes(), &[])?; - } else { - self.disabledroomids.remove(room_id.as_bytes())?; - } - - Ok(()) - } - - #[inline] - pub(super) fn is_banned(&self, room_id: &RoomId) -> Result { - Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) - } - - #[inline] - pub(super) fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { - if banned { - self.bannedroomids.insert(room_id.as_bytes(), &[])?; - } else { - self.bannedroomids.remove(room_id.as_bytes())?; - } - - Ok(()) - } - - pub(super) fn list_banned_rooms<'a>(&'a self) -> Box> + 'a> { - Box::new(self.bannedroomids.iter().map( - |(room_id_bytes, _ /* non-banned rooms should not be in this table */)| { - let room_id = utils::string_from_bytes(&room_id_bytes) - .map_err(|e| { - error!("Invalid room_id bytes in bannedroomids: {e}"); - Error::bad_database("Invalid room_id in bannedroomids.") - })? 
- .try_into() - .map_err(|e| { - error!("Invalid room_id in bannedroomids: {e}"); - Error::bad_database("Invalid room_id in bannedroomids") - })?; - - Ok(room_id) - }, - )) - } -} diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index 7415c53b7..5d4a47c71 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -1,51 +1,92 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use ruma::{OwnedRoomId, RoomId}; +use conduit::{implement, utils::stream::TryIgnore, Result}; +use database::Map; +use futures::{Stream, StreamExt}; +use ruma::RoomId; -use self::data::Data; +use crate::{rooms, Dep}; pub struct Service { db: Data, + services: Services, +} + +struct Data { + disabledroomids: Arc, + bannedroomids: Arc, + roomid_shortroomid: Arc, + pduid_pdu: Arc, +} + +struct Services { + short: Dep, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + disabledroomids: args.db["disabledroomids"].clone(), + bannedroomids: args.db["bannedroomids"].clone(), + roomid_shortroomid: args.db["roomid_shortroomid"].clone(), + pduid_pdu: args.db["pduid_pdu"].clone(), + }, + services: Services { + short: args.depend::("rooms::short"), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Checks if a room exists. - #[inline] - pub fn exists(&self, room_id: &RoomId) -> Result { self.db.exists(room_id) } +#[implement(Service)] +pub async fn exists(&self, room_id: &RoomId) -> bool { + let Ok(prefix) = self.services.short.get_shortroomid(room_id).await else { + return false; + }; + + // Look for PDUs in that room. 
+ self.db + .pduid_pdu + .keys_raw_prefix(&prefix) + .ignore_err() + .next() + .await + .is_some() +} - #[must_use] - pub fn iter_ids<'a>(&'a self) -> Box> + 'a> { self.db.iter_ids() } +#[implement(Service)] +pub fn iter_ids(&self) -> impl Stream + Send + '_ { self.db.roomid_shortroomid.keys().ignore_err() } - #[inline] - pub fn is_disabled(&self, room_id: &RoomId) -> Result { self.db.is_disabled(room_id) } +#[implement(Service)] +#[inline] +pub fn disable_room(&self, room_id: &RoomId, disabled: bool) { + if disabled { + self.db.disabledroomids.insert(room_id.as_bytes(), &[]); + } else { + self.db.disabledroomids.remove(room_id.as_bytes()); + } +} - #[inline] - pub fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { - self.db.disable_room(room_id, disabled) +#[implement(Service)] +#[inline] +pub fn ban_room(&self, room_id: &RoomId, banned: bool) { + if banned { + self.db.bannedroomids.insert(room_id.as_bytes(), &[]); + } else { + self.db.bannedroomids.remove(room_id.as_bytes()); } +} - #[inline] - pub fn is_banned(&self, room_id: &RoomId) -> Result { self.db.is_banned(room_id) } +#[implement(Service)] +pub fn list_banned_rooms(&self) -> impl Stream + Send + '_ { self.db.bannedroomids.keys().ignore_err() } - #[inline] - pub fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { self.db.ban_room(room_id, banned) } +#[implement(Service)] +#[inline] +pub async fn is_disabled(&self, room_id: &RoomId) -> bool { self.db.disabledroomids.qry(room_id).await.is_ok() } - #[inline] - #[must_use] - pub fn list_banned_rooms<'a>(&'a self) -> Box> + 'a> { - self.db.list_banned_rooms() - } -} +#[implement(Service)] +#[inline] +pub async fn is_banned(&self, room_id: &RoomId) -> bool { self.db.bannedroomids.qry(room_id).await.is_ok() } diff --git a/src/service/rooms/outlier/data.rs b/src/service/rooms/outlier/data.rs deleted file mode 100644 index aa804721b..000000000 --- a/src/service/rooms/outlier/data.rs +++ /dev/null @@ -1,42 +0,0 @@ -use 
std::sync::Arc; - -use conduit::{Error, Result}; -use database::{Database, Map}; -use ruma::{CanonicalJsonObject, EventId}; - -use crate::PduEvent; - -pub(super) struct Data { - eventid_outlierpdu: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - eventid_outlierpdu: db["eventid_outlierpdu"].clone(), - } - } - - pub(super) fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - } - - pub(super) fn get_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - } - - pub(super) fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { - self.eventid_outlierpdu.insert( - event_id.as_bytes(), - &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), - ) - } -} diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index 22bd2092a..277b59826 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -1,9 +1,7 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use data::Data; +use conduit::{implement, Result}; +use database::{Deserialized, Map}; use ruma::{CanonicalJsonObject, EventId}; use crate::PduEvent; @@ -12,31 +10,48 @@ pub struct Service { db: Data, } +struct Data { + eventid_outlierpdu: Arc, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data { + eventid_outlierpdu: args.db["eventid_outlierpdu"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Returns the pdu from the outlier tree. 
- pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.db.get_outlier_pdu_json(event_id) - } +/// Returns the pdu from the outlier tree. +#[implement(Service)] +pub async fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result { + self.db + .eventid_outlierpdu + .qry(event_id) + .await + .deserialized_json() +} - /// Returns the pdu from the outlier tree. - /// - /// TODO: use this? - #[allow(dead_code)] - pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result> { self.db.get_outlier_pdu(event_id) } +/// Returns the pdu from the outlier tree. +#[implement(Service)] +pub async fn get_pdu_outlier(&self, event_id: &EventId) -> Result { + self.db + .eventid_outlierpdu + .qry(event_id) + .await + .deserialized_json() +} - /// Append the PDU as an outlier. - #[tracing::instrument(skip(self, pdu), level = "debug")] - pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { - self.db.add_pdu_outlier(event_id, pdu) - } +/// Append the PDU as an outlier. 
+#[implement(Service)] +#[tracing::instrument(skip(self, pdu), level = "debug")] +pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) { + self.db.eventid_outlierpdu.insert( + event_id.as_bytes(), + &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), + ); } diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index d1649da81..f23234752 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -1,7 +1,13 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{utils, Error, PduCount, PduEvent, Result}; +use conduit::{ + result::LogErr, + utils, + utils::{stream::TryIgnore, ReadyExt}, + PduCount, PduEvent, +}; use database::Map; +use futures::{Stream, StreamExt}; use ruma::{EventId, RoomId, UserId}; use crate::{rooms, Dep}; @@ -17,8 +23,7 @@ struct Services { timeline: Dep, } -type PdusIterItem = Result<(PduCount, PduEvent)>; -type PdusIterator<'a> = Box + 'a>; +pub(super) type PdusIterItem = (PduCount, PduEvent); impl Data { pub(super) fn new(args: &crate::Args<'_>) -> Self { @@ -33,19 +38,17 @@ impl Data { } } - pub(super) fn add_relation(&self, from: u64, to: u64) -> Result<()> { + pub(super) fn add_relation(&self, from: u64, to: u64) { let mut key = to.to_be_bytes().to_vec(); key.extend_from_slice(&from.to_be_bytes()); - self.tofrom_relation.insert(&key, &[])?; - Ok(()) + self.tofrom_relation.insert(&key, &[]); } pub(super) fn relations_until<'a>( &'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount, - ) -> Result> { + ) -> impl Stream + Send + 'a + '_ { let prefix = target.to_be_bytes().to_vec(); let mut current = prefix.clone(); - let count_raw = match until { PduCount::Normal(x) => x.saturating_sub(1), PduCount::Backfilled(x) => { @@ -55,53 +58,42 @@ impl Data { }; current.extend_from_slice(&count_raw.to_be_bytes()); - Ok(Box::new( - self.tofrom_relation - .iter_from(¤t, true) - .take_while(move |(k, _)| 
k.starts_with(&prefix)) - .map(move |(tofrom, _data)| { - let from = utils::u64_from_bytes(&tofrom[(size_of::())..]) - .map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?; - - let mut pduid = shortroomid.to_be_bytes().to_vec(); - pduid.extend_from_slice(&from.to_be_bytes()); - - let mut pdu = self - .services - .timeline - .get_pdu_from_id(&pduid)? - .ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((PduCount::Normal(from), pdu)) - }), - )) + self.tofrom_relation + .rev_raw_keys_from(¤t) + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix)) + .map(|to_from| utils::u64_from_u8(&to_from[(size_of::())..])) + .filter_map(move |from| async move { + let mut pduid = shortroomid.to_be_bytes().to_vec(); + pduid.extend_from_slice(&from.to_be_bytes()); + let mut pdu = self.services.timeline.get_pdu_from_id(&pduid).await.ok()?; + + if pdu.sender != user_id { + pdu.remove_transaction_id().log_err().ok(); + } + + Some((PduCount::Normal(from), pdu)) + }) } - pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { + pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) { for prev in event_ids { let mut key = room_id.as_bytes().to_vec(); key.extend_from_slice(prev.as_bytes()); - self.referencedevents.insert(&key, &[])?; + self.referencedevents.insert(&key, &[]); } - - Ok(()) } - pub(super) fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.extend_from_slice(event_id.as_bytes()); - Ok(self.referencedevents.get(&key)?.is_some()) + pub(super) async fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> bool { + let key = (room_id, event_id); + self.referencedevents.qry(&key).await.is_ok() } - pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { - 
self.softfailedeventids.insert(event_id.as_bytes(), &[]) + pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) { + self.softfailedeventids.insert(event_id.as_bytes(), &[]); } - pub(super) fn is_event_soft_failed(&self, event_id: &EventId) -> Result { - self.softfailedeventids - .get(event_id.as_bytes()) - .map(|o| o.is_some()) + pub(super) async fn is_event_soft_failed(&self, event_id: &EventId) -> bool { + self.softfailedeventids.qry(event_id).await.is_ok() } } diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index d9eaf3244..dbaebfbf3 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -1,8 +1,8 @@ mod data; - use std::sync::Arc; -use conduit::{PduCount, PduEvent, Result}; +use conduit::{utils::stream::IterStream, PduCount, Result}; +use futures::StreamExt; use ruma::{ api::{client::relations::get_relating_events, Direction}, events::{relation::RelationType, TimelineEventType}, @@ -10,7 +10,7 @@ use ruma::{ }; use serde::Deserialize; -use self::data::Data; +use self::data::{Data, PdusIterItem}; use crate::{rooms, Dep}; pub struct Service { @@ -51,21 +51,19 @@ impl crate::Service for Service { impl Service { #[tracing::instrument(skip(self, from, to), level = "debug")] - pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> { + pub fn add_relation(&self, from: PduCount, to: PduCount) { match (from, to) { (PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t), _ => { // TODO: Relations with backfilled pdus - - Ok(()) }, } } #[allow(clippy::too_many_arguments)] - pub fn paginate_relations_with_filter( - &self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: &Option, - filter_rel_type: &Option, from: &Option, to: &Option, limit: &Option, + pub async fn paginate_relations_with_filter( + &self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: Option, + filter_rel_type: Option, 
from: Option<&String>, to: Option<&String>, limit: Option, recurse: bool, dir: Direction, ) -> Result { let from = match from { @@ -76,7 +74,7 @@ impl Service { }, }; - let to = to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); + let to = to.and_then(|t| PduCount::try_from_string(t).ok()); // Use limit or else 10, with maximum 100 let limit = limit @@ -92,30 +90,32 @@ impl Service { 1 }; - let relations_until = &self.relations_until(sender_user, room_id, target, from, depth)?; - let events: Vec<_> = relations_until // TODO: should be relations_after - .iter() - .filter(|(_, pdu)| { - filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t) - && if let Ok(content) = - serde_json::from_str::(pdu.content.get()) - { - filter_rel_type - .as_ref() - .map_or(true, |r| &content.relates_to.rel_type == r) - } else { - false - } - }) - .take(limit) - .filter(|(_, pdu)| { - self.services - .state_accessor - .user_can_see_event(sender_user, room_id, &pdu.event_id) - .unwrap_or(false) - }) - .take_while(|(k, _)| Some(k) != to.as_ref()) // Stop at `to` - .collect(); + let relations_until: Vec = self + .relations_until(sender_user, room_id, target, from, depth) + .await?; + + // TODO: should be relations_after + let events: Vec<_> = relations_until + .into_iter() + .filter(move |(_, pdu): &PdusIterItem| { + if !filter_event_type.as_ref().map_or(true, |t| pdu.kind == *t) { + return false; + } + + let Ok(content) = serde_json::from_str::(pdu.content.get()) else { + return false; + }; + + filter_rel_type + .as_ref() + .map_or(true, |r| *r == content.relates_to.rel_type) + }) + .take(limit) + .take_while(|(k, _)| Some(*k) != to) + .stream() + .filter_map(|item| self.visibility_filter(sender_user, item)) + .collect() + .await; let next_token = events.last().map(|(count, _)| count).copied(); @@ -125,9 +125,9 @@ impl Service { .map(|(_, pdu)| pdu.to_message_like_event()) .collect(), Direction::Backward => events - .into_iter() - .rev() // relations are always most recent 
first - .map(|(_, pdu)| pdu.to_message_like_event()) + .into_iter() + .rev() // relations are always most recent first + .map(|(_, pdu)| pdu.to_message_like_event()) .collect(), }; @@ -135,68 +135,85 @@ impl Service { chunk: events_chunk, next_batch: next_token.map(|t| t.stringify()), prev_batch: Some(from.stringify()), - recursion_depth: if recurse { - Some(depth.into()) - } else { - None - }, + recursion_depth: recurse.then_some(depth.into()), }) } - pub fn relations_until<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, max_depth: u8, - ) -> Result> { - let room_id = self.services.short.get_or_create_shortroomid(room_id)?; - #[allow(unknown_lints)] - #[allow(clippy::manual_unwrap_or_default)] - let target = match self.services.timeline.get_pdu_count(target)? { - Some(PduCount::Normal(c)) => c, + async fn visibility_filter(&self, sender_user: &UserId, item: PdusIterItem) -> Option { + let (_, pdu) = &item; + + self.services + .state_accessor + .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) + .await + .then_some(item) + } + + pub async fn relations_until( + &self, user_id: &UserId, room_id: &RoomId, target: &EventId, until: PduCount, max_depth: u8, + ) -> Result> { + let room_id = self.services.short.get_or_create_shortroomid(room_id).await; + + let target = match self.services.timeline.get_pdu_count(target).await { + Ok(PduCount::Normal(c)) => c, // TODO: Support backfilled relations _ => 0, // This will result in an empty iterator }; - self.db + let mut pdus: Vec = self + .db .relations_until(user_id, room_id, target, until) - .map(|mut relations| { - let mut pdus: Vec<_> = (*relations).into_iter().filter_map(Result::ok).collect(); - let mut stack: Vec<_> = pdus.clone().iter().map(|pdu| (pdu.to_owned(), 1)).collect(); - - while let Some(stack_pdu) = stack.pop() { - let target = match stack_pdu.0 .0 { - PduCount::Normal(c) => c, - // TODO: Support backfilled relations - PduCount::Backfilled(_) => 0, 
// This will result in an empty iterator - }; - - if let Ok(relations) = self.db.relations_until(user_id, room_id, target, until) { - for relation in relations.flatten() { - if stack_pdu.1 < max_depth { - stack.push((relation.clone(), stack_pdu.1.saturating_add(1))); - } - - pdus.push(relation); - } - } + .collect() + .await; + + let mut stack: Vec<_> = pdus.clone().into_iter().map(|pdu| (pdu, 1)).collect(); + + while let Some(stack_pdu) = stack.pop() { + let target = match stack_pdu.0 .0 { + PduCount::Normal(c) => c, + // TODO: Support backfilled relations + PduCount::Backfilled(_) => 0, // This will result in an empty iterator + }; + + let relations: Vec = self + .db + .relations_until(user_id, room_id, target, until) + .collect() + .await; + + for relation in relations { + if stack_pdu.1 < max_depth { + stack.push((relation.clone(), stack_pdu.1.saturating_add(1))); } - pdus.sort_by(|a, b| a.0.cmp(&b.0)); - pdus - }) + pdus.push(relation); + } + } + + pdus.sort_by(|a, b| a.0.cmp(&b.0)); + + Ok(pdus) } + #[inline] #[tracing::instrument(skip_all, level = "debug")] - pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { - self.db.mark_as_referenced(room_id, event_ids) + pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) { + self.db.mark_as_referenced(room_id, event_ids); } + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { - self.db.is_event_referenced(room_id, event_id) + pub async fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> bool { + self.db.is_event_referenced(room_id, event_id).await } + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { self.db.mark_event_soft_failed(event_id) } + pub fn mark_event_soft_failed(&self, event_id: &EventId) { self.db.mark_event_soft_failed(event_id) } + #[inline] 
#[tracing::instrument(skip(self), level = "debug")] - pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result { self.db.is_event_soft_failed(event_id) } + pub async fn is_event_soft_failed(&self, event_id: &EventId) -> bool { + self.db.is_event_soft_failed(event_id).await + } } diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 0c156df38..a2c0fabca 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -1,10 +1,18 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, RoomId, UserId}; +use conduit::{ + utils, + utils::{stream::TryIgnore, ReadyExt}, + Error, Result, +}; +use database::{Deserialized, Map}; +use futures::{Stream, StreamExt}; +use ruma::{ + events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent}, + serde::Raw, + CanonicalJsonObject, OwnedUserId, RoomId, UserId, +}; -use super::AnySyncEphemeralRoomEventIter; use crate::{globals, Dep}; pub(super) struct Data { @@ -18,6 +26,8 @@ struct Services { globals: Dep, } +pub(super) type ReceiptItem = (OwnedUserId, u64, Raw); + impl Data { pub(super) fn new(args: &crate::Args<'_>) -> Self { let db = &args.db; @@ -31,7 +41,9 @@ impl Data { } } - pub(super) fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> { + pub(super) async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) { + type KeyVal<'a> = (&'a RoomId, u64, &'a UserId); + let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -39,108 +51,90 @@ impl Data { last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); // Remove old entry - if let Some((old, _)) = self - .readreceiptid_readreceipt - .iter_from(&last_possible_key, true) - .take_while(|(key, _)| key.starts_with(&prefix)) - .find(|(key, _)| { - key.rsplit(|&b| b == 0xFF) - 
.next() - .expect("rsplit always returns an element") - == user_id.as_bytes() - }) { - // This is the old room_latest - self.readreceiptid_readreceipt.remove(&old)?; - } + self.readreceiptid_readreceipt + .rev_keys_from_raw(&last_possible_key) + .ignore_err() + .ready_take_while(|(r, ..): &KeyVal<'_>| *r == room_id) + .ready_filter_map(|(r, c, u): KeyVal<'_>| (u == user_id).then_some((r, c, u))) + .ready_for_each(|old: KeyVal<'_>| { + // This is the old room_latest + self.readreceiptid_readreceipt.del(&old); + }) + .await; let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); + room_latest_id.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes()); room_latest_id.push(0xFF); room_latest_id.extend_from_slice(user_id.as_bytes()); self.readreceiptid_readreceipt.insert( &room_latest_id, &serde_json::to_vec(event).expect("EduEvent::to_string always works"), - )?; - - Ok(()) + ); } - pub(super) fn readreceipts_since<'a>(&'a self, room_id: &RoomId, since: u64) -> AnySyncEphemeralRoomEventIter<'a> { + pub(super) fn readreceipts_since<'a>( + &'a self, room_id: &'a RoomId, since: u64, + ) -> impl Stream + Send + 'a { + let after_since = since.saturating_add(1); // +1 so we don't send the event at since + let first_possible_edu = (room_id, after_since); + let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); let prefix2 = prefix.clone(); - let mut first_possible_edu = prefix.clone(); - first_possible_edu.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); // +1 so we don't send the event at since - - Box::new( - self.readreceiptid_readreceipt - .iter_from(&first_possible_edu, false) - .take_while(move |(k, _)| k.starts_with(&prefix2)) - .map(move |(k, v)| { - let count_offset = prefix.len().saturating_add(size_of::()); - let count = utils::u64_from_bytes(&k[prefix.len()..count_offset]) - .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; - let 
user_id_offset = count_offset.saturating_add(1); - let user_id = UserId::parse( - utils::string_from_bytes(&k[user_id_offset..]) - .map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?, - ) + self.readreceiptid_readreceipt + .stream_raw_from(&first_possible_edu) + .ignore_err() + .ready_take_while(move |(k, _)| k.starts_with(&prefix2)) + .map(move |(k, v)| { + let count_offset = prefix.len().saturating_add(size_of::()); + let user_id_offset = count_offset.saturating_add(1); + + let count = utils::u64_from_bytes(&k[prefix.len()..count_offset]) + .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; + + let user_id_str = utils::string_from_bytes(&k[user_id_offset..]) + .map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?; + + let user_id = UserId::parse(user_id_str) .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; - let mut json = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?; - json.remove("room_id"); - - Ok(( - user_id, - count, - Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")), - )) - }), - ) - } + let mut json = serde_json::from_slice::(v) + .map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?; - pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); + json.remove("room_id"); - self.roomuserid_privateread - .insert(&key, &count.to_be_bytes())?; + let event = Raw::from_json(serde_json::value::to_raw_value(&json)?); - self.roomuserid_lastprivatereadupdate - .insert(&key, &self.services.globals.next_count()?.to_be_bytes()) + Ok((user_id, count, event)) + }) + .ignore_err() } - pub(super) fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> 
Result> { + pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) { let mut key = room_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); self.roomuserid_privateread - .get(&key)? - .map_or(Ok(None), |v| { - Ok(Some( - utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?, - )) - }) + .insert(&key, &count.to_be_bytes()); + + self.roomuserid_lastprivatereadupdate + .insert(&key, &self.services.globals.next_count().unwrap().to_be_bytes()); } - pub(super) fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); + pub(super) async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result { + let key = (room_id, user_id); + self.roomuserid_privateread.qry(&key).await.deserialized() + } - Ok(self - .roomuserid_lastprivatereadupdate - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")) - }) - .transpose()? 
- .unwrap_or(0)) + pub(super) async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (room_id, user_id); + self.roomuserid_lastprivatereadupdate + .qry(&key) + .await + .deserialized() + .unwrap_or(0) } } diff --git a/src/service/rooms/read_receipt/mod.rs b/src/service/rooms/read_receipt/mod.rs index da11e2a0f..ec34361e0 100644 --- a/src/service/rooms/read_receipt/mod.rs +++ b/src/service/rooms/read_receipt/mod.rs @@ -3,16 +3,17 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; use conduit::{debug, Result}; -use data::Data; +use futures::Stream; use ruma::{ events::{ receipt::{ReceiptEvent, ReceiptEventContent}, - AnySyncEphemeralRoomEvent, SyncEphemeralRoomEvent, + SyncEphemeralRoomEvent, }, serde::Raw, - OwnedUserId, RoomId, UserId, + RoomId, UserId, }; +use self::data::{Data, ReceiptItem}; use crate::{sending, Dep}; pub struct Service { @@ -24,9 +25,6 @@ struct Services { sending: Dep, } -type AnySyncEphemeralRoomEventIter<'a> = - Box)>> + 'a>; - impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { @@ -42,44 +40,53 @@ impl crate::Service for Service { impl Service { /// Replaces the previous read receipt. - pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> { - self.db.readreceipt_update(user_id, room_id, event)?; - self.services.sending.flush_room(room_id)?; - - Ok(()) + pub async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) { + self.db.readreceipt_update(user_id, room_id, event).await; + self.services + .sending + .flush_room(room_id) + .await + .expect("room flush failed"); } /// Returns an iterator over the most recent read_receipts in a room that /// happened after the event with id `since`. 
+ #[inline] #[tracing::instrument(skip(self), level = "debug")] pub fn readreceipts_since<'a>( - &'a self, room_id: &RoomId, since: u64, - ) -> impl Iterator)>> + 'a { + &'a self, room_id: &'a RoomId, since: u64, + ) -> impl Stream + Send + 'a { self.db.readreceipts_since(room_id, since) } /// Sets a private read marker at `count`. + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { - self.db.private_read_set(room_id, user_id, count) + pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) { + self.db.private_read_set(room_id, user_id, count); } /// Returns the private read marker. + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - self.db.private_read_get(room_id, user_id) + pub async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result { + self.db.private_read_get(room_id, user_id).await } /// Returns the count of the last typing update in this room. 
- pub fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.last_privateread_update(user_id, room_id) + #[inline] + pub async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + self.db.last_privateread_update(user_id, room_id).await } } #[must_use] -pub fn pack_receipts(receipts: AnySyncEphemeralRoomEventIter<'_>) -> Raw> { +pub fn pack_receipts(receipts: I) -> Raw> +where + I: Iterator, +{ let mut json = BTreeMap::new(); - for (_user, _count, value) in receipts.flatten() { + for (_, _, value) in receipts { let receipt = serde_json::from_str::>(value.json().get()); if let Ok(value) = receipt { for (event, receipt) in value.content { diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index a0086095b..de98beeeb 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -1,13 +1,12 @@ use std::sync::Arc; -use conduit::{utils, Result}; +use conduit::utils::{set, stream::TryIgnore, IterStream, ReadyExt}; use database::Map; +use futures::StreamExt; use ruma::RoomId; use crate::{rooms, Dep}; -type SearchPdusResult<'a> = Result> + 'a>, Vec)>>; - pub(super) struct Data { tokenids: Arc, services: Services, @@ -28,7 +27,7 @@ impl Data { } } - pub(super) fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { + pub(super) fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { let batch = tokenize(message_body) .map(|word| { let mut key = shortroomid.to_be_bytes().to_vec(); @@ -39,11 +38,10 @@ impl Data { }) .collect::>(); - self.tokenids - .insert_batch(batch.iter().map(database::KeyVal::from)) + self.tokenids.insert_batch(batch.iter()); } - pub(super) fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { + pub(super) fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { let batch = tokenize(message_body).map(|word| { let mut key = 
shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(word.as_bytes()); @@ -53,46 +51,53 @@ impl Data { }); for token in batch { - self.tokenids.remove(&token)?; + self.tokenids.remove(&token); } - - Ok(()) } - pub(super) fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> { + pub(super) async fn search_pdus( + &self, room_id: &RoomId, search_string: &str, + ) -> Option<(Vec>, Vec)> { let prefix = self .services .short - .get_shortroomid(room_id)? - .expect("room exists") + .get_shortroomid(room_id) + .await + .ok()? .to_be_bytes() .to_vec(); let words: Vec<_> = tokenize(search_string).collect(); - let iterators = words.clone().into_iter().map(move |word| { - let mut prefix2 = prefix.clone(); - prefix2.extend_from_slice(word.as_bytes()); - prefix2.push(0xFF); - let prefix3 = prefix2.clone(); - - let mut last_possible_id = prefix2.clone(); - last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.tokenids - .iter_from(&last_possible_id, true) // Newest pdus first - .take_while(move |(k, _)| k.starts_with(&prefix2)) - .map(move |(key, _)| key[prefix3.len()..].to_vec()) - }); - - let Some(common_elements) = utils::common_elements(iterators, |a, b| { - // We compare b with a because we reversed the iterator earlier - b.cmp(a) - }) else { - return Ok(None); - }; - - Ok(Some((Box::new(common_elements), words))) + let bufs: Vec<_> = words + .clone() + .into_iter() + .stream() + .then(move |word| { + let mut prefix2 = prefix.clone(); + prefix2.extend_from_slice(word.as_bytes()); + prefix2.push(0xFF); + let prefix3 = prefix2.clone(); + + let mut last_possible_id = prefix2.clone(); + last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); + + self.tokenids + .rev_raw_keys_from(&last_possible_id) // Newest pdus first + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix2)) + .map(move |key| key[prefix3.len()..].to_vec()) + .collect::>() + }) + .collect() + .await; + + Some(( + 
set::intersection(bufs.iter().map(|buf| buf.iter())) + .cloned() + .collect(), + words, + )) } } @@ -100,7 +105,7 @@ impl Data { /// /// This may be used to tokenize both message bodies (for indexing) or search /// queries (for querying). -fn tokenize(body: &str) -> impl Iterator + '_ { +fn tokenize(body: &str) -> impl Iterator + Send + '_ { body.split_terminator(|c: char| !c.is_alphanumeric()) .filter(|s| !s.is_empty()) .filter(|word| word.len() <= 50) diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 8caa0ce35..80b588044 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -21,20 +21,21 @@ impl crate::Service for Service { } impl Service { + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { - self.db.index_pdu(shortroomid, pdu_id, message_body) + pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { + self.db.index_pdu(shortroomid, pdu_id, message_body); } + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { - self.db.deindex_pdu(shortroomid, pdu_id, message_body) + pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { + self.db.deindex_pdu(shortroomid, pdu_id, message_body); } + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn search_pdus<'a>( - &'a self, room_id: &RoomId, search_string: &str, - ) -> Result> + 'a, Vec)>> { - self.db.search_pdus(room_id, search_string) + pub async fn search_pdus(&self, room_id: &RoomId, search_string: &str) -> Option<(Vec>, Vec)> { + self.db.search_pdus(room_id, search_string).await } } diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs index 17fbb64e8..f6a824883 100644 --- a/src/service/rooms/short/data.rs +++ b/src/service/rooms/short/data.rs @@ -1,7 
+1,7 @@ use std::sync::Arc; -use conduit::{utils, warn, Error, Result}; -use database::Map; +use conduit::{err, utils, Error, Result}; +use database::{Deserialized, Map}; use ruma::{events::StateEventType, EventId, RoomId}; use crate::{globals, Dep}; @@ -36,44 +36,46 @@ impl Data { } } - pub(super) fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { - let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? { - utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))? - } else { - let shorteventid = self.services.globals.next_count()?; - self.eventid_shorteventid - .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; - shorteventid - }; - - Ok(short) + pub(super) async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 { + if let Ok(shorteventid) = self.eventid_shorteventid.qry(event_id).await.deserialized() { + return shorteventid; + } + + let shorteventid = self.services.globals.next_count().unwrap(); + self.eventid_shorteventid + .insert(event_id.as_bytes(), &shorteventid.to_be_bytes()); + self.shorteventid_eventid + .insert(&shorteventid.to_be_bytes(), event_id.as_bytes()); + + shorteventid } - pub(super) fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result> { + pub(super) async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec { let mut ret: Vec = Vec::with_capacity(event_ids.len()); let keys = event_ids .iter() .map(|id| id.as_bytes()) .collect::>(); + for (i, short) in self .eventid_shorteventid - .multi_get(&keys)? 
+ .multi_get(keys.iter()) .iter() .enumerate() { #[allow(clippy::single_match_else)] match short { Some(short) => ret.push( - utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, + utils::u64_from_bytes(short) + .map_err(|_| Error::bad_database("Invalid shorteventid in db.")) + .unwrap(), ), None => { - let short = self.services.globals.next_count()?; + let short = self.services.globals.next_count().unwrap(); self.eventid_shorteventid - .insert(keys[i], &short.to_be_bytes())?; + .insert(keys[i], &short.to_be_bytes()); self.shorteventid_eventid - .insert(&short.to_be_bytes(), keys[i])?; + .insert(&short.to_be_bytes(), keys[i]); debug_assert!(ret.len() == i, "position of result must match input"); ret.push(short); @@ -81,115 +83,85 @@ impl Data { } } - Ok(ret) + ret } - pub(super) fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result> { - let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); - statekey_vec.push(0xFF); - statekey_vec.extend_from_slice(state_key.as_bytes()); - - let short = self - .statekey_shortstatekey - .get(&statekey_vec)? - .map(|shortstatekey| { - utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) - }) - .transpose()?; - - Ok(short) + pub(super) async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { + let key = (event_type, state_key); + self.statekey_shortstatekey.qry(&key).await.deserialized() } - pub(super) fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { - let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); - statekey_vec.push(0xFF); - statekey_vec.extend_from_slice(state_key.as_bytes()); - - let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? { - utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))? 
- } else { - let shortstatekey = self.services.globals.next_count()?; - self.statekey_shortstatekey - .insert(&statekey_vec, &shortstatekey.to_be_bytes())?; - self.shortstatekey_statekey - .insert(&shortstatekey.to_be_bytes(), &statekey_vec)?; - shortstatekey - }; - - Ok(short) - } + pub(super) async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> u64 { + let key = (event_type.to_string(), state_key); + if let Ok(shortstatekey) = self.statekey_shortstatekey.qry(&key).await.deserialized() { + return shortstatekey; + } - pub(super) fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { - let bytes = self - .shorteventid_eventid - .get(&shorteventid.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; + let mut key = event_type.to_string().as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(state_key.as_bytes()); - let event_id = EventId::parse_arc( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; + let shortstatekey = self.services.globals.next_count().unwrap(); + self.statekey_shortstatekey + .insert(&key, &shortstatekey.to_be_bytes()); + self.shortstatekey_statekey + .insert(&shortstatekey.to_be_bytes(), &key); - Ok(event_id) + shortstatekey } - pub(super) fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - let bytes = self - .shortstatekey_statekey - .get(&shortstatekey.to_be_bytes())? 
- .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; - - let mut parts = bytes.splitn(2, |&b| b == 0xFF); - let eventtype_bytes = parts.next().expect("split always returns one entry"); - let statekey_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; + pub(super) async fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { + self.shorteventid_eventid + .qry(&shorteventid) + .await + .deserialized() + .map_err(|e| err!(Database("Failed to find EventId from short {shorteventid:?}: {e:?}"))) + } - let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| { - warn!("Event type in shortstatekey_statekey is invalid: {}", e); - Error::bad_database("Event type in shortstatekey_statekey is invalid.") - })?); + pub(super) async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { + self.shortstatekey_statekey + .qry(&shortstatekey) + .await + .deserialized() + .map_err(|e| { + err!(Database( + "Failed to find (StateEventType, state_key) from short {shortstatekey:?}: {e:?}" + )) + }) + } - let state_key = utils::string_from_bytes(statekey_bytes) - .map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?; + /// Returns (shortstatehash, already_existed) + pub(super) async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, bool) { + if let Ok(shortstatehash) = self + .statehash_shortstatehash + .qry(state_hash) + .await + .deserialized() + { + return (shortstatehash, true); + } - let result = (event_type, state_key); + let shortstatehash = self.services.globals.next_count().unwrap(); + self.statehash_shortstatehash + .insert(state_hash, &shortstatehash.to_be_bytes()); - Ok(result) + (shortstatehash, false) } - /// Returns (shortstatehash, already_existed) - pub(super) fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { - Ok(if let 
Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? { - ( - utils::u64_from_bytes(&shortstatehash) - .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, - true, - ) - } else { - let shortstatehash = self.services.globals.next_count()?; - self.statehash_shortstatehash - .insert(state_hash, &shortstatehash.to_be_bytes())?; - (shortstatehash, false) - }) + pub(super) async fn get_shortroomid(&self, room_id: &RoomId) -> Result { + self.roomid_shortroomid.qry(room_id).await.deserialized() } - pub(super) fn get_shortroomid(&self, room_id: &RoomId) -> Result> { + pub(super) async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> u64 { self.roomid_shortroomid - .get(room_id.as_bytes())? - .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db."))) - .transpose() - } - - pub(super) fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { - Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? { - utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))? 
- } else { - let short = self.services.globals.next_count()?; - self.roomid_shortroomid - .insert(room_id.as_bytes(), &short.to_be_bytes())?; - short - }) + .qry(room_id) + .await + .deserialized() + .unwrap_or_else(|_| { + let short = self.services.globals.next_count().unwrap(); + self.roomid_shortroomid + .insert(room_id.as_bytes(), &short.to_be_bytes()); + short + }) } } diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index bfe0e9a0e..00bb7cb13 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -22,38 +22,40 @@ impl crate::Service for Service { } impl Service { - pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { - self.db.get_or_create_shorteventid(event_id) + pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 { + self.db.get_or_create_shorteventid(event_id).await } - pub fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result> { - self.db.multi_get_or_create_shorteventid(event_ids) + pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec { + self.db.multi_get_or_create_shorteventid(event_ids).await } - pub fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result> { - self.db.get_shortstatekey(event_type, state_key) + pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { + self.db.get_shortstatekey(event_type, state_key).await } - pub fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { - self.db.get_or_create_shortstatekey(event_type, state_key) + pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> u64 { + self.db + .get_or_create_shortstatekey(event_type, state_key) + .await } - pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { - self.db.get_eventid_from_short(shorteventid) + pub async fn get_eventid_from_short(&self, shorteventid: 
u64) -> Result> { + self.db.get_eventid_from_short(shorteventid).await } - pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - self.db.get_statekey_from_short(shortstatekey) + pub async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { + self.db.get_statekey_from_short(shortstatekey).await } /// Returns (shortstatehash, already_existed) - pub fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { - self.db.get_or_create_shortstatehash(state_hash) + pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, bool) { + self.db.get_or_create_shortstatehash(state_hash).await } - pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { self.db.get_shortroomid(room_id) } + pub async fn get_shortroomid(&self, room_id: &RoomId) -> Result { self.db.get_shortroomid(room_id).await } - pub fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { - self.db.get_or_create_shortroomid(room_id) + pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> u64 { + self.db.get_or_create_shortroomid(room_id).await } } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 24d612d87..17fbf0ef0 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -7,7 +7,12 @@ use std::{ sync::Arc, }; -use conduit::{checked, debug, debug_info, err, utils::math::usize_from_f64, warn, Error, Result}; +use conduit::{ + checked, debug, debug_info, err, + utils::{math::usize_from_f64, IterStream}, + Error, Result, +}; +use futures::{StreamExt, TryFutureExt}; use lru_cache::LruCache; use ruma::{ api::{ @@ -211,12 +216,15 @@ impl Service { .as_ref() { return Ok(if let Some(cached) = cached { - if self.is_accessible_child( - current_room, - &cached.summary.join_rule, - &identifier, - &cached.summary.allowed_room_ids, - ) { + if self + .is_accessible_child( + current_room, + 
&cached.summary.join_rule, + &identifier, + &cached.summary.allowed_room_ids, + ) + .await + { Some(SummaryAccessibility::Accessible(Box::new(cached.summary.clone()))) } else { Some(SummaryAccessibility::Inaccessible) @@ -228,7 +236,9 @@ impl Service { Ok( if let Some(children_pdus) = self.get_stripped_space_child_events(current_room).await? { - let summary = self.get_room_summary(current_room, children_pdus, &identifier); + let summary = self + .get_room_summary(current_room, children_pdus, &identifier) + .await; if let Ok(summary) = summary { self.roomid_spacehierarchy_cache.lock().await.insert( current_room.clone(), @@ -322,12 +332,15 @@ impl Service { ); } } - if self.is_accessible_child( - current_room, - &response.room.join_rule, - &Identifier::UserId(user_id), - &response.room.allowed_room_ids, - ) { + if self + .is_accessible_child( + current_room, + &response.room.join_rule, + &Identifier::UserId(user_id), + &response.room.allowed_room_ids, + ) + .await + { return Ok(Some(SummaryAccessibility::Accessible(Box::new(summary.clone())))); } @@ -358,7 +371,7 @@ impl Service { } } - fn get_room_summary( + async fn get_room_summary( &self, current_room: &OwnedRoomId, children_state: Vec>, identifier: &Identifier<'_>, ) -> Result { @@ -367,48 +380,43 @@ impl Service { let join_rule = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")? - .map(|s| { + .room_state_get(room_id, &StateEventType::RoomJoinRules, "") + .await + .map_or(JoinRule::Invite, |s| { serde_json::from_str(s.content.get()) .map(|c: RoomJoinRulesEventContent| c.join_rule) .map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}")))) - }) - .transpose()? 
- .unwrap_or(JoinRule::Invite); + .unwrap() + }); let allowed_room_ids = self .services .state_accessor .allowed_room_ids(join_rule.clone()); - if !self.is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) { + if !self + .is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) + .await + { debug!("User is not allowed to see room {room_id}"); // This error will be caught later return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room")); } - let join_rule = join_rule.into(); - Ok(SpaceHierarchyParentSummary { canonical_alias: self .services .state_accessor .get_canonical_alias(room_id) - .unwrap_or(None), - name: self - .services - .state_accessor - .get_name(room_id) - .unwrap_or(None), + .await + .ok(), + name: self.services.state_accessor.get_name(room_id).await.ok(), num_joined_members: self .services .state_cache .room_joined_count(room_id) - .unwrap_or_default() - .unwrap_or_else(|| { - warn!("Room {room_id} has no member count"); - 0 - }) + .await + .unwrap_or(0) .try_into() .expect("user count should not be that big"), room_id: room_id.to_owned(), @@ -416,18 +424,29 @@ impl Service { .services .state_accessor .get_room_topic(room_id) - .unwrap_or(None), - world_readable: self.services.state_accessor.is_world_readable(room_id)?, - guest_can_join: self.services.state_accessor.guest_can_join(room_id)?, + .await + .ok(), + world_readable: self + .services + .state_accessor + .is_world_readable(room_id) + .await, + guest_can_join: self.services.state_accessor.guest_can_join(room_id).await, avatar_url: self .services .state_accessor - .get_avatar(room_id)? 
+ .get_avatar(room_id) + .await .into_option() .unwrap_or_default() .url, - join_rule, - room_type: self.services.state_accessor.get_room_type(room_id)?, + join_rule: join_rule.into(), + room_type: self + .services + .state_accessor + .get_room_type(room_id) + .await + .ok(), children_state, allowed_room_ids, }) @@ -474,21 +493,22 @@ impl Service { results.push(summary_to_chunk(*summary.clone())); } else { children = children - .into_iter() - .rev() - .skip_while(|(room, _)| { - if let Ok(short) = self.services.short.get_shortroomid(room) - { - short.as_ref() != short_room_ids.get(parents.len()) - } else { - false - } - }) - .collect::>() - // skip_while doesn't implement DoubleEndedIterator, which is needed for rev - .into_iter() - .rev() - .collect(); + .iter() + .rev() + .stream() + .skip_while(|(room, _)| { + self.services + .short + .get_shortroomid(room) + .map_ok(|short| Some(&short) != short_room_ids.get(parents.len())) + .unwrap_or_else(|_| false) + }) + .map(Clone::clone) + .collect::)>>() + .await + .into_iter() + .rev() + .collect(); if children.is_empty() { return Err(Error::BadRequest( @@ -531,7 +551,7 @@ impl Service { let mut short_room_ids = vec![]; for room in parents { - short_room_ids.push(self.services.short.get_or_create_shortroomid(&room)?); + short_room_ids.push(self.services.short.get_or_create_shortroomid(&room).await); } Some( @@ -554,7 +574,7 @@ impl Service { async fn get_stripped_space_child_events( &self, room_id: &RoomId, ) -> Result>>, Error> { - let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? 
else { + let Ok(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id).await else { return Ok(None); }; @@ -562,10 +582,13 @@ impl Service { .services .state_accessor .state_full_ids(current_shortstatehash) - .await?; + .await + .map_err(|e| err!(Database("State in space not found: {e}")))?; + let mut children_pdus = Vec::new(); for (key, id) in state { - let (event_type, state_key) = self.services.short.get_statekey_from_short(key)?; + let (event_type, state_key) = self.services.short.get_statekey_from_short(key).await?; + if event_type != StateEventType::SpaceChild { continue; } @@ -573,8 +596,9 @@ impl Service { let pdu = self .services .timeline - .get_pdu(&id)? - .ok_or_else(|| Error::bad_database("Event in space state not found"))?; + .get_pdu(&id) + .await + .map_err(|e| err!(Database("Event {id:?} in space state not found: {e:?}")))?; if serde_json::from_str::(pdu.content.get()) .ok() @@ -593,7 +617,7 @@ impl Service { } /// With the given identifier, checks if a room is accessable - fn is_accessible_child( + async fn is_accessible_child( &self, current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>, allowed_room_ids: &Vec, ) -> bool { @@ -607,6 +631,7 @@ impl Service { .services .event_handler .acl_check(server_name, room_id) + .await .is_err() { return false; @@ -617,12 +642,11 @@ impl Service { .services .state_cache .is_joined(user_id, current_room) - .unwrap_or_default() - || self - .services - .state_cache - .is_invited(user_id, current_room) - .unwrap_or_default() + .await || self + .services + .state_cache + .is_invited(user_id, current_room) + .await { return true; } @@ -633,22 +657,12 @@ impl Service { for room in allowed_room_ids { match identifier { Identifier::UserId(user) => { - if self - .services - .state_cache - .is_joined(user, room) - .unwrap_or_default() - { + if self.services.state_cache.is_joined(user, room).await { return true; } }, Identifier::ServerName(server) => { - if self - 
.services - .state_cache - .server_in_room(server, room) - .unwrap_or_default() - { + if self.services.state_cache.server_in_room(server, room).await { return true; } }, diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index 3c110afc6..ccf7509a8 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -1,34 +1,31 @@ -use std::{collections::HashSet, sync::Arc}; +use std::sync::Arc; -use conduit::{utils, Error, Result}; -use database::{Database, Map}; -use ruma::{EventId, OwnedEventId, RoomId}; +use conduit::{ + utils::{stream::TryIgnore, ReadyExt}, + Result, +}; +use database::{Database, Deserialized, Interfix, Map}; +use ruma::{OwnedEventId, RoomId}; use super::RoomMutexGuard; pub(super) struct Data { shorteventid_shortstatehash: Arc, - roomid_pduleaves: Arc, roomid_shortstatehash: Arc, + pub(super) roomid_pduleaves: Arc, } impl Data { pub(super) fn new(db: &Arc) -> Self { Self { shorteventid_shortstatehash: db["shorteventid_shortstatehash"].clone(), - roomid_pduleaves: db["roomid_pduleaves"].clone(), roomid_shortstatehash: db["roomid_shortstatehash"].clone(), + roomid_pduleaves: db["roomid_pduleaves"].clone(), } } - pub(super) fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { - self.roomid_shortstatehash - .get(room_id.as_bytes())? 
- .map_or(Ok(None), |bytes| { - Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") - })?)) - }) + pub(super) async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result { + self.roomid_shortstatehash.qry(room_id).await.deserialized() } #[inline] @@ -37,53 +34,35 @@ impl Data { room_id: &RoomId, new_shortstatehash: u64, _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { + ) { self.roomid_shortstatehash - .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; - Ok(()) + .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes()); } - pub(super) fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { + pub(super) fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) { self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; - Ok(()) + .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes()); } - pub(super) fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - self.roomid_pduleaves - .scan_prefix(prefix) - .map(|(_, bytes)| { - EventId::parse_arc( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) - }) - .collect() - } - - pub(super) fn set_forward_extremities( + pub(super) async fn set_forward_extremities( &self, room_id: &RoomId, event_ids: Vec, _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { + ) { + let prefix = (room_id, Interfix); + self.roomid_pduleaves + .keys_raw_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.roomid_pduleaves.remove(key)) + .await; + let mut prefix = room_id.as_bytes().to_vec(); 
prefix.push(0xFF); - - for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { - self.roomid_pduleaves.remove(&key)?; - } - for event_id in event_ids { let mut key = prefix.clone(); key.extend_from_slice(event_id.as_bytes()); - self.roomid_pduleaves.insert(&key, event_id.as_bytes())?; + self.roomid_pduleaves.insert(&key, event_id.as_bytes()); } - - Ok(()) } } diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index cb219bc03..c7f6605c7 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -7,12 +7,14 @@ use std::{ }; use conduit::{ - utils::{calculate_hash, MutexMap, MutexMapGuard}, - warn, Error, PduEvent, Result, + err, + utils::{calculate_hash, stream::TryIgnore, IterStream, MutexMap, MutexMapGuard}, + warn, PduEvent, Result, }; use data::Data; +use database::{Ignore, Interfix}; +use futures::{pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; use ruma::{ - api::client::error::ErrorKind, events::{ room::{create::RoomCreateEventContent, member::RoomMemberEventContent}, AnyStrippedStateEvent, StateEventType, TimelineEventType, @@ -81,14 +83,16 @@ impl Service { _statediffremoved: Arc>, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { - for event_id in statediffnew.iter().filter_map(|new| { + let event_ids = statediffnew.iter().stream().filter_map(|new| { self.services .state_compressor .parse_compressed_state_event(new) - .ok() - .map(|(_, id)| id) - }) { - let Some(pdu) = self.services.timeline.get_pdu_json(&event_id)? 
else { + .map_ok_or_else(|_| None, |(_, event_id)| Some(event_id)) + }); + + pin_mut!(event_ids); + while let Some(event_id) = event_ids.next().await { + let Ok(pdu) = self.services.timeline.get_pdu_json(&event_id).await else { continue; }; @@ -113,15 +117,10 @@ impl Service { continue; }; - self.services.state_cache.update_membership( - room_id, - &user_id, - membership_event, - &pdu.sender, - None, - None, - false, - )?; + self.services + .state_cache + .update_membership(room_id, &user_id, membership_event, &pdu.sender, None, None, false) + .await?; }, TimelineEventType::SpaceChild => { self.services @@ -135,10 +134,9 @@ impl Service { } } - self.services.state_cache.update_joined_count(room_id)?; + self.services.state_cache.update_joined_count(room_id).await; - self.db - .set_room_state(room_id, shortstatehash, state_lock)?; + self.db.set_room_state(room_id, shortstatehash, state_lock); Ok(()) } @@ -148,12 +146,16 @@ impl Service { /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. 
#[tracing::instrument(skip(self, state_ids_compressed), level = "debug")] - pub fn set_event_state( + pub async fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc>, ) -> Result { - let shorteventid = self.services.short.get_or_create_shorteventid(event_id)?; + let shorteventid = self + .services + .short + .get_or_create_shorteventid(event_id) + .await; - let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; + let previous_shortstatehash = self.db.get_room_shortstatehash(room_id).await; let state_hash = calculate_hash( &state_ids_compressed @@ -165,13 +167,18 @@ impl Service { let (shortstatehash, already_existed) = self .services .short - .get_or_create_shortstatehash(&state_hash)?; + .get_or_create_shortstatehash(&state_hash) + .await; if !already_existed { - let states_parents = previous_shortstatehash.map_or_else( - || Ok(Vec::new()), - |p| self.services.state_compressor.load_shortstatehash_info(p), - )?; + let states_parents = if let Ok(p) = previous_shortstatehash { + self.services + .state_compressor + .load_shortstatehash_info(p) + .await? + } else { + Vec::new() + }; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = state_ids_compressed @@ -198,7 +205,7 @@ impl Service { )?; } - self.db.set_event_state(shorteventid, shortstatehash)?; + self.db.set_event_state(shorteventid, shortstatehash); Ok(shortstatehash) } @@ -208,34 +215,40 @@ impl Service { /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. 
#[tracing::instrument(skip(self, new_pdu), level = "debug")] - pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result { + pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result { let shorteventid = self .services .short - .get_or_create_shorteventid(&new_pdu.event_id)?; + .get_or_create_shorteventid(&new_pdu.event_id) + .await; - let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; + let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id).await; - if let Some(p) = previous_shortstatehash { - self.db.set_event_state(shorteventid, p)?; + if let Ok(p) = previous_shortstatehash { + self.db.set_event_state(shorteventid, p); } if let Some(state_key) = &new_pdu.state_key { - let states_parents = previous_shortstatehash.map_or_else( - || Ok(Vec::new()), - #[inline] - |p| self.services.state_compressor.load_shortstatehash_info(p), - )?; + let states_parents = if let Ok(p) = previous_shortstatehash { + self.services + .state_compressor + .load_shortstatehash_info(p) + .await? + } else { + Vec::new() + }; let shortstatekey = self .services .short - .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key) + .await; let new = self .services .state_compressor - .compress_state_event(shortstatekey, &new_pdu.event_id)?; + .compress_state_event(shortstatekey, &new_pdu.event_id) + .await; let replaces = states_parents .last() @@ -276,49 +289,55 @@ impl Service { } #[tracing::instrument(skip(self, invite_event), level = "debug")] - pub fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result>> { + pub async fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result>> { let mut state = Vec::new(); // Add recommended events - if let Some(e) = - self.services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? 
+ if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "") + .await { state.push(e.to_stripped_state_event()); } - if let Some(e) = - self.services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "") + .await { state.push(e.to_stripped_state_event()); } - if let Some(e) = self.services.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomCanonicalAlias, - "", - )? { + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomCanonicalAlias, "") + .await + { state.push(e.to_stripped_state_event()); } - if let Some(e) = - self.services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "") + .await { state.push(e.to_stripped_state_event()); } - if let Some(e) = - self.services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomName, "") + .await { state.push(e.to_stripped_state_event()); } - if let Some(e) = self.services.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomMember, - invite_event.sender.as_str(), - )? 
{ + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomMember, invite_event.sender.as_str()) + .await + { state.push(e.to_stripped_state_event()); } @@ -333,101 +352,108 @@ impl Service { room_id: &RoomId, shortstatehash: u64, mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { - self.db.set_room_state(room_id, shortstatehash, mutex_lock) + ) { + self.db.set_room_state(room_id, shortstatehash, mutex_lock); } /// Returns the room's version. #[tracing::instrument(skip(self), level = "debug")] - pub fn get_room_version(&self, room_id: &RoomId) -> Result { - let create_event = self - .services + pub async fn get_room_version(&self, room_id: &RoomId) -> Result { + self.services .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")?; - - let create_event_content: RoomCreateEventContent = create_event - .as_ref() - .map(|create_event| { - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::bad_database("Invalid create event in db.") - }) - }) - .transpose()? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "No create event found"))?; - - Ok(create_event_content.room_version) + .room_state_get_content(room_id, &StateEventType::RoomCreate, "") + .await + .map(|content: RoomCreateEventContent| content.room_version) + .map_err(|e| err!(Request(NotFound("No create event found: {e:?}")))) } #[inline] - pub fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { - self.db.get_room_shortstatehash(room_id) + pub async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result { + self.db.get_room_shortstatehash(room_id).await } - pub fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { - self.db.get_forward_extremities(room_id) + pub fn get_forward_extremities<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + '_ { + let prefix = (room_id, Interfix); + + self.db + .roomid_pduleaves + .keys_prefix(&prefix) + .map_ok(|(_, event_id): (Ignore, &EventId)| event_id) + .ignore_err() } - pub fn set_forward_extremities( + pub async fn set_forward_extremities( &self, room_id: &RoomId, event_ids: Vec, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { + ) { self.db .set_forward_extremities(room_id, event_ids, state_lock) + .await; } /// This fetches auth events from the current state. #[tracing::instrument(skip(self), level = "debug")] - pub fn get_auth_events( + pub async fn get_auth_events( &self, room_id: &RoomId, kind: &TimelineEventType, sender: &UserId, state_key: Option<&str>, content: &serde_json::value::RawValue, ) -> Result>> { - let Some(shortstatehash) = self.get_room_shortstatehash(room_id)? 
else { + let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else { return Ok(HashMap::new()); }; - let auth_events = - state_res::auth_types_for_event(kind, sender, state_key, content).expect("content is a valid JSON object"); + let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content)?; - let mut sauthevents = auth_events - .into_iter() + let mut sauthevents: HashMap<_, _> = auth_events + .iter() + .stream() .filter_map(|(event_type, state_key)| { self.services .short - .get_shortstatekey(&event_type.to_string().into(), &state_key) - .ok() - .flatten() - .map(|s| (s, (event_type, state_key))) + .get_shortstatekey(event_type, state_key) + .map_ok(move |s| (s, (event_type, state_key))) + .map(Result::ok) }) - .collect::>(); + .collect() + .await; let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await + .map_err(|e| { + err!(Database( + "Missing shortstatehash info for {room_id:?} at {shortstatehash:?}: {e:?}" + )) + })? 
.pop() .expect("there is always one layer") .1; - Ok(full_state - .iter() - .filter_map(|compressed| { - self.services - .state_compressor - .parse_compressed_state_event(compressed) - .ok() - }) - .filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id))) - .filter_map(|(k, event_id)| { - self.services - .timeline - .get_pdu(&event_id) - .ok() - .flatten() - .map(|pdu| (k, pdu)) - }) - .collect()) + let mut ret = HashMap::new(); + for compressed in full_state.iter() { + let Ok((shortstatekey, event_id)) = self + .services + .state_compressor + .parse_compressed_state_event(compressed) + .await + else { + continue; + }; + + let Some((ty, state_key)) = sauthevents.remove(&shortstatekey) else { + continue; + }; + + let Ok(pdu) = self.services.timeline.get_pdu(&event_id).await else { + continue; + }; + + ret.insert((ty.to_owned(), state_key.to_owned()), pdu); + } + + Ok(ret) } } diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index 4c85148db..79a983257 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -1,7 +1,8 @@ use std::{collections::HashMap, sync::Arc}; -use conduit::{utils, Error, PduEvent, Result}; -use database::Map; +use conduit::{err, PduEvent, Result}; +use database::{Deserialized, Map}; +use futures::TryFutureExt; use ruma::{events::StateEventType, EventId, RoomId}; use crate::{rooms, Dep}; @@ -39,17 +40,22 @@ impl Data { let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await + .map_err(|e| err!(Database("Missing state IDs: {e}")))? 
.pop() .expect("there is always one layer") .1; + let mut result = HashMap::new(); let mut i: u8 = 0; for compressed in full_state.iter() { let parsed = self .services .state_compressor - .parse_compressed_state_event(compressed)?; + .parse_compressed_state_event(compressed) + .await?; + result.insert(parsed.0, parsed.1); i = i.wrapping_add(1); @@ -57,6 +63,7 @@ impl Data { tokio::task::yield_now().await; } } + Ok(result) } @@ -67,7 +74,8 @@ impl Data { let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await? .pop() .expect("there is always one layer") .1; @@ -78,18 +86,13 @@ impl Data { let (_, eventid) = self .services .state_compressor - .parse_compressed_state_event(compressed)?; - if let Some(pdu) = self.services.timeline.get_pdu(&eventid)? { - result.insert( - ( - pdu.kind.to_string().into(), - pdu.state_key - .as_ref() - .ok_or_else(|| Error::bad_database("State event has no state key."))? - .clone(), - ), - pdu, - ); + .parse_compressed_state_event(compressed) + .await?; + + if let Ok(pdu) = self.services.timeline.get_pdu(&eventid).await { + if let Some(state_key) = pdu.state_key.as_ref() { + result.insert((pdu.kind.to_string().into(), state_key.clone()), pdu); + } } i = i.wrapping_add(1); @@ -101,61 +104,63 @@ impl Data { Ok(result) } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). #[allow(clippy::unused_self)] - pub(super) fn state_get_id( + pub(super) async fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - let Some(shortstatekey) = self + ) -> Result> { + let shortstatekey = self .services .short - .get_shortstatekey(event_type, state_key)? 
- else { - return Ok(None); - }; + .get_shortstatekey(event_type, state_key) + .await?; + let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await + .map_err(|e| err!(Database(error!(?event_type, ?state_key, "Missing state: {e:?}"))))? .pop() .expect("there is always one layer") .1; - Ok(full_state + + let compressed = full_state .iter() .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) - .and_then(|compressed| { - self.services - .state_compressor - .parse_compressed_state_event(compressed) - .ok() - .map(|(_, id)| id) - })) + .ok_or(err!(Database("No shortstatekey in compressed state")))?; + + self.services + .state_compressor + .parse_compressed_state_event(compressed) + .map_ok(|(_, id)| id) + .map_err(|e| { + err!(Database(error!( + ?event_type, + ?state_key, + ?shortstatekey, + "Failed to parse compressed: {e:?}" + ))) + }) + .await } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - pub(super) fn state_get( + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub(super) async fn state_get( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.state_get_id(shortstatehash, event_type, state_key)? - .map_or(Ok(None), |event_id| self.services.timeline.get_pdu(&event_id)) + ) -> Result> { + self.state_get_id(shortstatehash, event_type, state_key) + .and_then(|event_id| async move { self.services.timeline.get_pdu(&event_id).await }) + .await } /// Returns the state hash for this pdu. - pub(super) fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { + pub(super) async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result { self.eventid_shorteventid - .get(event_id.as_bytes())? - .map_or(Ok(None), |shorteventid| { - self.shorteventid_shortstatehash - .get(&shorteventid)? 
- .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash") - }) - }) - .transpose() - }) + .qry(event_id) + .and_then(|shorteventid| self.shorteventid_shortstatehash.qry(&shorteventid)) + .await + .deserialized() } /// Returns the full room state. @@ -163,34 +168,33 @@ impl Data { pub(super) async fn room_state_full( &self, room_id: &RoomId, ) -> Result>> { - if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? { - self.state_full(current_shortstatehash).await - } else { - Ok(HashMap::new()) - } + self.services + .state + .get_room_shortstatehash(room_id) + .and_then(|shortstatehash| self.state_full(shortstatehash)) + .map_err(|e| err!(Database("Missing state for {room_id:?}: {e:?}"))) + .await } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - pub(super) fn room_state_get_id( + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub(super) async fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? { - self.state_get_id(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } + ) -> Result> { + self.services + .state + .get_room_shortstatehash(room_id) + .and_then(|shortstatehash| self.state_get_id(shortstatehash, event_type, state_key)) + .await } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - pub(super) fn room_state_get( + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub(super) async fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? 
{ - self.state_get(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } + ) -> Result> { + self.services + .state + .get_room_shortstatehash(room_id) + .and_then(|shortstatehash| self.state_get(shortstatehash, event_type, state_key)) + .await } } diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 58fa31b3d..4c28483cb 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -6,8 +6,13 @@ use std::{ sync::{Arc, Mutex as StdMutex, Mutex}, }; -use conduit::{err, error, pdu::PduBuilder, utils::math::usize_from_f64, warn, Error, PduEvent, Result}; -use data::Data; +use conduit::{ + err, error, + pdu::PduBuilder, + utils::{math::usize_from_f64, ReadyExt}, + Error, PduEvent, Result, +}; +use futures::StreamExt; use lru_cache::LruCache; use ruma::{ events::{ @@ -31,8 +36,10 @@ use ruma::{ EventEncryptionAlgorithm, EventId, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; +use serde::Deserialize; use serde_json::value::to_raw_value; +use self::data::Data; use crate::{rooms, rooms::state::RoomMutexGuard, Dep}; pub struct Service { @@ -99,54 +106,58 @@ impl Service { /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). #[tracing::instrument(skip(self), level = "debug")] - pub fn state_get_id( + pub async fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.db.state_get_id(shortstatehash, event_type, state_key) + ) -> Result> { + self.db + .state_get_id(shortstatehash, event_type, state_key) + .await } /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). 
#[inline] - pub fn state_get( + pub async fn state_get( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.db.state_get(shortstatehash, event_type, state_key) + ) -> Result> { + self.db + .state_get(shortstatehash, event_type, state_key) + .await } /// Get membership for given user in state - fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> Result { - self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str())? - .map_or(Ok(MembershipState::Leave), |s| { + async fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> MembershipState { + self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str()) + .await + .map_or(MembershipState::Leave, |s| { serde_json::from_str(s.content.get()) .map(|c: RoomMemberEventContent| c.membership) .map_err(|_| Error::bad_database("Invalid room membership event in database.")) + .unwrap() }) } /// The user was a joined member at this state (potentially in the past) #[inline] - fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool { - self.user_membership(shortstatehash, user_id) - .is_ok_and(|s| s == MembershipState::Join) - // Return sensible default, i.e. - // false + async fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool { + self.user_membership(shortstatehash, user_id).await == MembershipState::Join } /// The user was an invited or joined room member at this state (potentially /// in the past) #[inline] - fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool { - self.user_membership(shortstatehash, user_id) - .is_ok_and(|s| s == MembershipState::Join || s == MembershipState::Invite) - // Return sensible default, i.e. 
false + async fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool { + let s = self.user_membership(shortstatehash, user_id).await; + s == MembershipState::Join || s == MembershipState::Invite } /// Whether a server is allowed to see an event through federation, based on /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, origin, room_id, event_id))] - pub fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_id: &EventId) -> Result { - let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else { + pub async fn server_can_see_event( + &self, origin: &ServerName, room_id: &RoomId, event_id: &EventId, + ) -> Result { + let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else { return Ok(true); }; @@ -160,8 +171,9 @@ impl Service { } let history_visibility = self - .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? - .map_or(Ok(HistoryVisibility::Shared), |s| { + .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "") + .await + .map_or(HistoryVisibility::Shared, |s| { serde_json::from_str(s.content.get()) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) .map_err(|e| { @@ -171,25 +183,28 @@ impl Service { ); Error::bad_database("Invalid history visibility event in database.") }) - }) - .unwrap_or(HistoryVisibility::Shared); + .unwrap() + }); - let mut current_server_members = self + let current_server_members = self .services .state_cache .room_members(room_id) - .filter_map(Result::ok) - .filter(|member| member.server_name() == origin); + .ready_filter(|member| member.server_name() == origin); let visibility = match history_visibility { HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true, HistoryVisibility::Invited => { // Allow if any member on requesting server was AT LEAST invited, else deny - current_server_members.any(|member| self.user_was_invited(shortstatehash, &member)) + 
current_server_members + .any(|member| self.user_was_invited(shortstatehash, member)) + .await }, HistoryVisibility::Joined => { // Allow if any member on requested server was joined, else deny - current_server_members.any(|member| self.user_was_joined(shortstatehash, &member)) + current_server_members + .any(|member| self.user_was_joined(shortstatehash, member)) + .await }, _ => { error!("Unknown history visibility {history_visibility}"); @@ -208,9 +223,9 @@ impl Service { /// Whether a user is allowed to see an event, based on /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, user_id, room_id, event_id))] - pub fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> Result { - let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else { - return Ok(true); + pub async fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> bool { + let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else { + return true; }; if let Some(visibility) = self @@ -219,14 +234,15 @@ impl Service { .unwrap() .get_mut(&(user_id.to_owned(), shortstatehash)) { - return Ok(*visibility); + return *visibility; } - let currently_member = self.services.state_cache.is_joined(user_id, room_id)?; + let currently_member = self.services.state_cache.is_joined(user_id, room_id).await; let history_visibility = self - .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? 
- .map_or(Ok(HistoryVisibility::Shared), |s| { + .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "") + .await + .map_or(HistoryVisibility::Shared, |s| { serde_json::from_str(s.content.get()) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) .map_err(|e| { @@ -236,19 +252,19 @@ impl Service { ); Error::bad_database("Invalid history visibility event in database.") }) - }) - .unwrap_or(HistoryVisibility::Shared); + .unwrap() + }); let visibility = match history_visibility { HistoryVisibility::WorldReadable => true, HistoryVisibility::Shared => currently_member, HistoryVisibility::Invited => { // Allow if any member on requesting server was AT LEAST invited, else deny - self.user_was_invited(shortstatehash, user_id) + self.user_was_invited(shortstatehash, user_id).await }, HistoryVisibility::Joined => { // Allow if any member on requested server was joined, else deny - self.user_was_joined(shortstatehash, user_id) + self.user_was_joined(shortstatehash, user_id).await }, _ => { error!("Unknown history visibility {history_visibility}"); @@ -261,17 +277,18 @@ impl Service { .unwrap() .insert((user_id.to_owned(), shortstatehash), visibility); - Ok(visibility) + visibility } /// Whether a user is allowed to see an event, based on /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, user_id, room_id))] - pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let currently_member = self.services.state_cache.is_joined(user_id, room_id)?; + pub async fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let currently_member = self.services.state_cache.is_joined(user_id, room_id).await; let history_visibility = self - .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? 
+ .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "") + .await .map_or(Ok(HistoryVisibility::Shared), |s| { serde_json::from_str(s.content.get()) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) @@ -285,11 +302,13 @@ impl Service { }) .unwrap_or(HistoryVisibility::Shared); - Ok(currently_member || history_visibility == HistoryVisibility::WorldReadable) + currently_member || history_visibility == HistoryVisibility::WorldReadable } /// Returns the state hash for this pdu. - pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { self.db.pdu_shortstatehash(event_id) } + pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result { + self.db.pdu_shortstatehash(event_id).await + } /// Returns the full room state. #[tracing::instrument(skip(self), level = "debug")] @@ -300,47 +319,61 @@ impl Service { /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). #[tracing::instrument(skip(self), level = "debug")] - pub fn room_state_get_id( + pub async fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.db.room_state_get_id(room_id, event_type, state_key) + ) -> Result> { + self.db + .room_state_get_id(room_id, event_type, state_key) + .await } /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). #[tracing::instrument(skip(self), level = "debug")] - pub fn room_state_get( + pub async fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.db.room_state_get(room_id, event_type, state_key) + ) -> Result> { + self.db.room_state_get(room_id, event_type, state_key).await } - pub fn get_name(&self, room_id: &RoomId) -> Result> { - self.room_state_get(room_id, &StateEventType::RoomName, "")? 
- .map_or(Ok(None), |s| { - Ok(serde_json::from_str(s.content.get()).map_or_else(|_| None, |c: RoomNameEventContent| Some(c.name))) - }) + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub async fn room_state_get_content( + &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, + ) -> Result + where + T: for<'de> Deserialize<'de> + Send, + { + use serde_json::from_str; + + self.room_state_get(room_id, event_type, state_key) + .await + .and_then(|event| from_str::(event.content.get()).map_err(Into::into)) } - pub fn get_avatar(&self, room_id: &RoomId) -> Result> { - self.room_state_get(room_id, &StateEventType::RoomAvatar, "")? - .map_or(Ok(ruma::JsOption::Undefined), |s| { + pub async fn get_name(&self, room_id: &RoomId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomName, "") + .await + .map(|c: RoomNameEventContent| c.name) + } + + pub async fn get_avatar(&self, room_id: &RoomId) -> ruma::JsOption { + self.room_state_get(room_id, &StateEventType::RoomAvatar, "") + .await + .map_or(ruma::JsOption::Undefined, |s| { serde_json::from_str(s.content.get()) .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) + .unwrap() }) } - pub fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - self.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? 
- .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map_err(|_| Error::bad_database("Invalid room member event in database.")) - }) + pub async fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await } - pub fn user_can_invite( + pub async fn user_can_invite( &self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &RoomMutexGuard, - ) -> Result { + ) -> bool { let content = to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite)) .expect("Event content always serializes"); @@ -353,122 +386,101 @@ impl Service { timestamp: None, }; - Ok(self - .services + self.services .timeline .create_hash_and_sign_event(new_event, sender, room_id, state_lock) - .is_ok()) + .await + .is_ok() } /// Checks if guests are able to view room content without joining - pub fn is_world_readable(&self, room_id: &RoomId) -> Result { - self.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? - .map_or(Ok(false), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| { - c.history_visibility == HistoryVisibility::WorldReadable - }) - .map_err(|e| { - error!( - "Invalid room history visibility event in database for room {room_id}, assuming not world \ - readable: {e} " - ); - Error::bad_database("Invalid room history visibility event in database.") - }) - }) + pub async fn is_world_readable(&self, room_id: &RoomId) -> bool { + self.room_state_get_content(room_id, &StateEventType::RoomHistoryVisibility, "") + .await + .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility == HistoryVisibility::WorldReadable) + .unwrap_or(false) } /// Checks if guests are able to join a given room - pub fn guest_can_join(&self, room_id: &RoomId) -> Result { - self.room_state_get(room_id, &StateEventType::RoomGuestAccess, "")? 
- .map_or(Ok(false), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin) - .map_err(|_| Error::bad_database("Invalid room guest access event in database.")) - }) + pub async fn guest_can_join(&self, room_id: &RoomId) -> bool { + self.room_state_get_content(room_id, &StateEventType::RoomGuestAccess, "") + .await + .map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin) + .unwrap_or(false) } /// Gets the primary alias from canonical alias event - pub fn get_canonical_alias(&self, room_id: &RoomId) -> Result, Error> { - self.room_state_get(room_id, &StateEventType::RoomCanonicalAlias, "")? - .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomCanonicalAliasEventContent| c.alias) - .map_err(|_| Error::bad_database("Invalid canonical alias event in database.")) + pub async fn get_canonical_alias(&self, room_id: &RoomId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomCanonicalAlias, "") + .await + .and_then(|c: RoomCanonicalAliasEventContent| { + c.alias + .ok_or_else(|| err!(Request(NotFound("No alias found in event content.")))) }) } /// Gets the room topic - pub fn get_room_topic(&self, room_id: &RoomId) -> Result, Error> { - self.room_state_get(room_id, &StateEventType::RoomTopic, "")? 
- .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomTopicEventContent| Some(c.topic)) - .map_err(|e| { - error!("Invalid room topic event in database for room {room_id}: {e}"); - Error::bad_database("Invalid room topic event in database.") - }) - }) + pub async fn get_room_topic(&self, room_id: &RoomId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomTopic, "") + .await + .map(|c: RoomTopicEventContent| c.topic) } /// Checks if a given user can redact a given event /// /// If federation is true, it allows redaction events from any user of the /// same server as the original event sender - pub fn user_can_redact( + pub async fn user_can_redact( &self, redacts: &EventId, sender: &UserId, room_id: &RoomId, federation: bool, ) -> Result { - self.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? - .map_or_else( - || { - // Falling back on m.room.create to judge power level - if let Some(pdu) = self.room_state_get(room_id, &StateEventType::RoomCreate, "")? 
{ - Ok(pdu.sender == sender - || if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) { - pdu.sender == sender - } else { - false - }) + if let Ok(event) = self + .room_state_get(room_id, &StateEventType::RoomPowerLevels, "") + .await + { + let Ok(event) = serde_json::from_str(event.content.get()) + .map(|content: RoomPowerLevelsEventContent| content.into()) + .map(|event: RoomPowerLevels| event) + else { + return Ok(false); + }; + + Ok(event.user_can_redact_event_of_other(sender) + || event.user_can_redact_own_event(sender) + && if let Ok(pdu) = self.services.timeline.get_pdu(redacts).await { + if federation { + pdu.sender.server_name() == sender.server_name() + } else { + pdu.sender == sender + } + } else { + false + }) + } else { + // Falling back on m.room.create to judge power level + if let Ok(pdu) = self + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await + { + Ok(pdu.sender == sender + || if let Ok(pdu) = self.services.timeline.get_pdu(redacts).await { + pdu.sender == sender } else { - Err(Error::bad_database( - "No m.room.power_levels or m.room.create events in database for room", - )) - } - }, - |event| { - serde_json::from_str(event.content.get()) - .map(|content: RoomPowerLevelsEventContent| content.into()) - .map(|event: RoomPowerLevels| { - event.user_can_redact_event_of_other(sender) - || event.user_can_redact_own_event(sender) - && if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) { - if federation { - pdu.sender.server_name() == sender.server_name() - } else { - pdu.sender == sender - } - } else { - false - } - }) - .map_err(|_| Error::bad_database("Invalid m.room.power_levels event in database")) - }, - ) + false + }) + } else { + Err(Error::bad_database( + "No m.room.power_levels or m.room.create events in database for room", + )) + } + } } /// Returns the join rule (`SpaceRoomJoinRule`) for a given room - pub fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec), Error> { - Ok(self - 
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| { - (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule)) - }) - .map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}")))) - }) - .transpose()? - .unwrap_or((SpaceRoomJoinRule::Invite, vec![]))) + pub async fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec)> { + self.room_state_get_content(room_id, &StateEventType::RoomJoinRules, "") + .await + .map(|c: RoomJoinRulesEventContent| (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule))) + .or_else(|_| Ok((SpaceRoomJoinRule::Invite, vec![]))) } /// Returns an empty vec if not a restricted room @@ -487,25 +499,21 @@ impl Service { room_ids } - pub fn get_room_type(&self, room_id: &RoomId) -> Result> { - Ok(self - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .map(|s| { - serde_json::from_str::(s.content.get()) - .map_err(|e| err!(Database(error!("Invalid room create event in database: {e}")))) + pub async fn get_room_type(&self, room_id: &RoomId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomCreate, "") + .await + .and_then(|content: RoomCreateEventContent| { + content + .room_type + .ok_or_else(|| err!(Request(NotFound("No type found in event content")))) }) - .transpose()? - .and_then(|e| e.room_type)) } /// Gets the room's encryption algorithm if `m.room.encryption` state event /// is found - pub fn get_room_encryption(&self, room_id: &RoomId) -> Result> { - self.room_state_get(room_id, &StateEventType::RoomEncryption, "")? 
- .map_or(Ok(None), |s| { - serde_json::from_str::(s.content.get()) - .map(|content| Some(content.algorithm)) - .map_err(|e| err!(Database(error!("Invalid room encryption event in database: {e}")))) - }) + pub async fn get_room_encryption(&self, room_id: &RoomId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomEncryption, "") + .await + .map(|content: RoomEncryptionEventContent| content.algorithm) } } diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 19c73ea1c..38e504f6b 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -1,43 +1,42 @@ use std::{ - collections::{HashMap, HashSet}, + collections::HashMap, sync::{Arc, RwLock}, }; -use conduit::{utils, Error, Result}; -use database::Map; -use itertools::Itertools; +use conduit::{utils, utils::stream::TryIgnore, Error, Result}; +use database::{Deserialized, Interfix, Map}; +use futures::{Stream, StreamExt}; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + OwnedRoomId, RoomId, UserId, }; -use crate::{appservice::RegistrationInfo, globals, users, Dep}; +use crate::{globals, Dep}; -type StrippedStateEventIter<'a> = Box>)>> + 'a>; -type AnySyncStateEventIter<'a> = Box>)>> + 'a>; type AppServiceInRoomCache = RwLock>>; +type StrippedStateEventItem = (OwnedRoomId, Vec>); +type SyncStateEventItem = (OwnedRoomId, Vec>); pub(super) struct Data { pub(super) appservice_in_room_cache: AppServiceInRoomCache, - roomid_invitedcount: Arc, - roomid_inviteviaservers: Arc, - roomid_joinedcount: Arc, - roomserverids: Arc, - roomuserid_invitecount: Arc, - roomuserid_joined: Arc, - roomuserid_leftcount: Arc, - roomuseroncejoinedids: Arc, - serverroomids: Arc, - userroomid_invitestate: Arc, - userroomid_joined: Arc, - userroomid_leftstate: Arc, + pub(super) roomid_invitedcount: Arc, + pub(super) roomid_inviteviaservers: Arc, + 
pub(super) roomid_joinedcount: Arc, + pub(super) roomserverids: Arc, + pub(super) roomuserid_invitecount: Arc, + pub(super) roomuserid_joined: Arc, + pub(super) roomuserid_leftcount: Arc, + pub(super) roomuseroncejoinedids: Arc, + pub(super) serverroomids: Arc, + pub(super) userroomid_invitestate: Arc, + pub(super) userroomid_joined: Arc, + pub(super) userroomid_leftstate: Arc, services: Services, } struct Services { globals: Dep, - users: Dep, } impl Data { @@ -59,19 +58,18 @@ impl Data { userroomid_leftstate: db["userroomid_leftstate"].clone(), services: Services { globals: args.depend::("globals"), - users: args.depend::("users"), }, } } - pub(super) fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub(super) fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - self.roomuseroncejoinedids.insert(&userroom_id, &[]) + self.roomuseroncejoinedids.insert(&userroom_id, &[]); } - pub(super) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub(super) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) { let roomid = room_id.as_bytes().to_vec(); let mut roomuser_id = roomid.clone(); @@ -82,64 +80,17 @@ impl Data { userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - self.userroomid_joined.insert(&userroom_id, &[])?; - self.roomuserid_joined.insert(&roomuser_id, &[])?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; + self.userroomid_joined.insert(&userroom_id, &[]); + self.roomuserid_joined.insert(&roomuser_id, &[]); + self.userroomid_invitestate.remove(&userroom_id); + self.roomuserid_invitecount.remove(&roomuser_id); + self.userroomid_leftstate.remove(&userroom_id); + 
self.roomuserid_leftcount.remove(&roomuser_id); - self.roomid_inviteviaservers.remove(&roomid)?; - - Ok(()) + self.roomid_inviteviaservers.remove(&roomid); } - pub(super) fn mark_as_invited( - &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, - invite_via: Option>, - ) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_invitestate.insert( - &userroom_id, - &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), - )?; - self.roomuserid_invitecount - .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - if let Some(servers) = invite_via { - let mut prev_servers = self - .servers_invite_via(room_id) - .filter_map(Result::ok) - .collect_vec(); - #[allow(clippy::redundant_clone)] // this is a necessary clone? 
- prev_servers.append(servers.clone().as_mut()); - let servers = prev_servers.iter().rev().unique().rev().collect_vec(); - - let servers = servers - .iter() - .map(|server| server.as_bytes()) - .collect_vec() - .join(&[0xFF][..]); - - self.roomid_inviteviaservers - .insert(room_id.as_bytes(), &servers)?; - } - - Ok(()) - } - - pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { let roomid = room_id.as_bytes().to_vec(); let mut roomuser_id = roomid.clone(); @@ -153,115 +104,20 @@ impl Data { self.userroomid_leftstate.insert( &userroom_id, &serde_json::to_vec(&Vec::>::new()).unwrap(), - )?; // TODO + ); // TODO self.roomuserid_leftcount - .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; - - self.roomid_inviteviaservers.remove(&roomid)?; - - Ok(()) - } - - pub(super) fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { - let mut joinedcount = 0_u64; - let mut invitedcount = 0_u64; - let mut joined_servers = HashSet::new(); - - for joined in self.room_members(room_id).filter_map(Result::ok) { - joined_servers.insert(joined.server_name().to_owned()); - joinedcount = joinedcount.saturating_add(1); - } - - for _invited in self.room_members_invited(room_id).filter_map(Result::ok) { - invitedcount = invitedcount.saturating_add(1); - } - - self.roomid_joinedcount - .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; - - self.roomid_invitedcount - .insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?; - - for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) { - if !joined_servers.remove(&old_joined_server) { - // Server not in room anymore - let mut roomserver_id = room_id.as_bytes().to_vec(); - 
roomserver_id.push(0xFF); - roomserver_id.extend_from_slice(old_joined_server.as_bytes()); - - let mut serverroom_id = old_joined_server.as_bytes().to_vec(); - serverroom_id.push(0xFF); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.remove(&roomserver_id)?; - self.serverroomids.remove(&serverroom_id)?; - } - } - - // Now only new servers are in joined_servers anymore - for server in joined_servers { - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xFF); - roomserver_id.extend_from_slice(server.as_bytes()); - - let mut serverroom_id = server.as_bytes().to_vec(); - serverroom_id.push(0xFF); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; - } - - self.appservice_in_room_cache - .write() - .unwrap() - .remove(room_id); - - Ok(()) - } - - #[tracing::instrument(skip(self, room_id, appservice), level = "debug")] - pub(super) fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { - let maybe = self - .appservice_in_room_cache - .read() - .unwrap() - .get(room_id) - .and_then(|map| map.get(&appservice.registration.id)) - .copied(); - - if let Some(b) = maybe { - Ok(b) - } else { - let bridge_user_id = UserId::parse_with_server_name( - appservice.registration.sender_localpart.as_str(), - self.services.globals.server_name(), - ) - .ok(); - - let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false)) - || self - .room_members(room_id) - .any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str()))); - - self.appservice_in_room_cache - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default() - .insert(appservice.registration.id.clone(), in_room); + .insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes()); + self.userroomid_joined.remove(&userroom_id); + 
self.roomuserid_joined.remove(&roomuser_id); + self.userroomid_invitestate.remove(&userroom_id); + self.roomuserid_invitecount.remove(&roomuser_id); - Ok(in_room) - } + self.roomid_inviteviaservers.remove(&roomid); } /// Makes a user forget a room. #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { + pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -270,397 +126,69 @@ impl Data { roomuser_id.push(0xFF); roomuser_id.extend_from_slice(user_id.as_bytes()); - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - Ok(()) - } - - /// Returns an iterator of all servers participating in this room. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_servers<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| { - ServerName::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Server name in roomserverids is invalid.")) - })) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { - let mut key = server.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - - self.serverroomids.get(&key).map(|o| o.is_some()) - } - - /// Returns an iterator of all rooms a server participates in (as far as we - /// know). 
- #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn server_rooms<'a>( - &'a self, server: &ServerName, - ) -> Box> + 'a> { - let mut prefix = server.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| { - RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid.")) - })) - } - - /// Returns an iterator of all joined members of a room. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_members<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + Send + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid.")) - })) - } - - /// Returns an iterator of all our local users in the room, even if they're - /// deactivated/guests - pub(super) fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box + 'a> { - Box::new( - self.room_members(room_id) - .filter_map(Result::ok) - .filter(|user| self.services.globals.user_is_local(user)), - ) - } - - /// Returns an iterator of all our local joined users in a room who are - /// active (not deactivated, not guest) - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn active_local_users_in_room<'a>( - &'a self, room_id: &RoomId, - ) -> Box + 'a> { - Box::new( - self.local_users_in_room(room_id) - .filter(|user| !self.services.users.is_deactivated(user).unwrap_or(true)), - ) - } - - /// Returns the 
number of users which are currently in a room - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_joined_count(&self, room_id: &RoomId) -> Result> { - self.roomid_joinedcount - .get(room_id.as_bytes())? - .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) - .transpose() - } - - /// Returns the number of users which are currently invited to a room - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_invited_count(&self, room_id: &RoomId) -> Result> { - self.roomid_invitedcount - .get(room_id.as_bytes())? - .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) - .transpose() - } - - /// Returns an iterator over all User IDs who ever joined a room. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_useroncejoined<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.roomuseroncejoinedids - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid.")) - }), - ) - } - - /// Returns an iterator over all invited members of a room. 
- #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_members_invited<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.roomuserid_invitecount - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid.")) - }), - ) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_invitecount - .get(&key)? - .map_or(Ok(None), |bytes| { - Ok(Some( - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?, - )) - }) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_leftcount - .get(&key)? - .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid leftcount in db."))) - .transpose() - } - - /// Returns an iterator over all rooms this user joined. 
- #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn rooms_joined(&self, user_id: &UserId) -> Box> + '_> { - Box::new( - self.userroomid_joined - .scan_prefix(user_id.as_bytes().to_vec()) - .map(|(key, _)| { - RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid.")) - }), - ) + self.userroomid_leftstate.remove(&userroom_id); + self.roomuserid_leftcount.remove(&roomuser_id); } /// Returns an iterator over all rooms a user was invited to. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.userroomid_invitestate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; - - Ok((room_id, state)) - }), - ) + #[inline] + pub(super) fn rooms_invited<'a>( + &'a self, user_id: &'a UserId, + ) -> impl Stream + Send + 'a { + let prefix = (user_id, Interfix); + self.userroomid_invitestate + .stream_raw_prefix(&prefix) + .ignore_err() + .map(|(key, val)| { + let room_id = key.rsplit(|&b| b == 0xFF).next().unwrap(); + let room_id = utils::string_from_bytes(room_id).unwrap(); + let room_id = RoomId::parse(room_id).unwrap(); + let state = serde_json::from_slice(val) + .map_err(|_| 
Error::bad_database("Invalid state in userroomid_invitestate.")) + .unwrap(); + + (room_id, state) + }) } #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn invite_state( + pub(super) async fn invite_state( &self, user_id: &UserId, room_id: &RoomId, - ) -> Result>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - + ) -> Result>> { + let key = (user_id, room_id); self.userroomid_invitestate - .get(&key)? - .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; - - Ok(state) - }) - .transpose() + .qry(&key) + .await + .deserialized_json() } #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn left_state( + pub(super) async fn left_state( &self, user_id: &UserId, room_id: &RoomId, - ) -> Result>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - + ) -> Result>> { + let key = (user_id, room_id); self.userroomid_leftstate - .get(&key)? - .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; - - Ok(state) - }) - .transpose() + .qry(&key) + .await + .deserialized_json() } /// Returns an iterator over all rooms a user left. 
- #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.userroomid_leftstate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; - - Ok((room_id, state)) - }), - ) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - 
userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn servers_invite_via<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a> { - let key = room_id.as_bytes().to_vec(); - - Box::new( - self.roomid_inviteviaservers - .scan_prefix(key) - .map(|(_, servers)| { - ServerName::parse( - utils::string_from_bytes( - servers - .rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Server name in roomid_inviteviaservers is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Server name in roomid_inviteviaservers is invalid.")) - }), - ) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> { - let mut prev_servers = self - .servers_invite_via(room_id) - .filter_map(Result::ok) - .collect_vec(); - prev_servers.extend(servers.to_owned()); - prev_servers.sort_unstable(); - prev_servers.dedup(); - - let servers = prev_servers - .iter() - .map(|server| server.as_bytes()) - .collect_vec() - .join(&[0xFF][..]); - - self.roomid_inviteviaservers - .insert(room_id.as_bytes(), &servers)?; - - Ok(()) + #[inline] + pub(super) fn rooms_left<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { + let prefix = (user_id, Interfix); + self.userroomid_leftstate + .stream_raw_prefix(&prefix) + .ignore_err() + .map(|(key, val)| { + let room_id = key.rsplit(|&b| b == 0xFF).next().unwrap(); + let room_id = utils::string_from_bytes(room_id).unwrap(); + let room_id = RoomId::parse(room_id).unwrap(); + let state = serde_json::from_slice(val) + .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate.")) + .unwrap(); + + (room_id, state) + }) } } diff --git a/src/service/rooms/state_cache/mod.rs 
b/src/service/rooms/state_cache/mod.rs index 71899ceb9..ce5b024b7 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,9 +1,15 @@ mod data; -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; -use conduit::{err, error, warn, Error, Result}; +use conduit::{ + err, + utils::{stream::TryIgnore, ReadyExt}, + warn, Result, +}; use data::Data; +use database::{Deserialized, Ignore, Interfix}; +use futures::{Stream, StreamExt}; use itertools::Itertools; use ruma::{ events::{ @@ -18,7 +24,7 @@ use ruma::{ }, int, serde::Raw, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + OwnedRoomId, OwnedServerName, RoomId, ServerName, UserId, }; use crate::{account_data, appservice::RegistrationInfo, globals, rooms, users, Dep}; @@ -55,7 +61,7 @@ impl Service { /// Update current membership data. #[tracing::instrument(skip(self, last_state))] #[allow(clippy::too_many_arguments)] - pub fn update_membership( + pub async fn update_membership( &self, room_id: &RoomId, user_id: &UserId, membership_event: RoomMemberEventContent, sender: &UserId, last_state: Option>>, invite_via: Option>, update_joined_count: bool, @@ -68,7 +74,7 @@ impl Service { // update #[allow(clippy::collapsible_if)] if !self.services.globals.user_is_local(user_id) { - if !self.services.users.exists(user_id)? { + if !self.services.users.exists(user_id).await { self.services.users.create(user_id, None)?; } @@ -100,17 +106,17 @@ impl Service { match &membership { MembershipState::Join => { // Check if the user never joined this room - if !self.once_joined(user_id, room_id)? 
{ + if !self.once_joined(user_id, room_id).await { // Add the user ID to the join list then - self.db.mark_as_once_joined(user_id, room_id)?; + self.db.mark_as_once_joined(user_id, room_id); // Check if the room has a predecessor - if let Some(predecessor) = self + if let Ok(Some(predecessor)) = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .and_then(|create| serde_json::from_str(create.content.get()).ok()) - .and_then(|content: RoomCreateEventContent| content.predecessor) + .room_state_get_content(room_id, &StateEventType::RoomCreate, "") + .await + .map(|content: RoomCreateEventContent| content.predecessor) { // Copy user settings from predecessor to the current room: // - Push rules @@ -138,32 +144,33 @@ impl Service { // .ok(); // Copy old tags to new room - if let Some(tag_event) = self + if let Ok(tag_event) = self .services .account_data - .get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)? - .map(|event| { + .get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag) + .await + .and_then(|event| { serde_json::from_str(event.get()) .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) { self.services .account_data - .update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event?) + .update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event) + .await .ok(); }; // Copy direct chat flag - if let Some(direct_event) = self + if let Ok(mut direct_event) = self .services .account_data - .get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())? 
- .map(|event| { + .get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into()) + .await + .and_then(|event| { serde_json::from_str::(event.get()) .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) { - let mut direct_event = direct_event?; let mut room_ids_updated = false; - for room_ids in direct_event.content.0.values_mut() { if room_ids.iter().any(|r| r == &predecessor.room_id) { room_ids.push(room_id.to_owned()); @@ -172,18 +179,21 @@ impl Service { } if room_ids_updated { - self.services.account_data.update( - None, - user_id, - GlobalAccountDataEventType::Direct.to_string().into(), - &serde_json::to_value(&direct_event).expect("to json always works"), - )?; + self.services + .account_data + .update( + None, + user_id, + GlobalAccountDataEventType::Direct.to_string().into(), + &serde_json::to_value(&direct_event).expect("to json always works"), + ) + .await?; } }; } } - self.db.mark_as_joined(user_id, room_id)?; + self.db.mark_as_joined(user_id, room_id); }, MembershipState::Invite => { // We want to know if the sender is ignored by the receiver @@ -196,12 +206,12 @@ impl Service { GlobalAccountDataEventType::IgnoredUserList .to_string() .into(), - )? - .map(|event| { + ) + .await + .and_then(|event| { serde_json::from_str::(event.get()) .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) - .transpose()? 
.map_or(false, |ignored| { ignored .content @@ -214,194 +224,282 @@ impl Service { return Ok(()); } - self.db - .mark_as_invited(user_id, room_id, last_state, invite_via)?; + self.mark_as_invited(user_id, room_id, last_state, invite_via) + .await; }, MembershipState::Leave | MembershipState::Ban => { - self.db.mark_as_left(user_id, room_id)?; + self.db.mark_as_left(user_id, room_id); }, _ => {}, } if update_joined_count { - self.update_joined_count(room_id)?; + self.update_joined_count(room_id).await; } Ok(()) } - #[tracing::instrument(skip(self, room_id), level = "debug")] - pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { self.db.update_joined_count(room_id) } - #[tracing::instrument(skip(self, room_id, appservice), level = "debug")] - pub fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { - self.db.appservice_in_room(room_id, appservice) + pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> bool { + let maybe = self + .db + .appservice_in_room_cache + .read() + .unwrap() + .get(room_id) + .and_then(|map| map.get(&appservice.registration.id)) + .copied(); + + if let Some(b) = maybe { + b + } else { + let bridge_user_id = UserId::parse_with_server_name( + appservice.registration.sender_localpart.as_str(), + self.services.globals.server_name(), + ) + .ok(); + + let in_room = if let Some(id) = &bridge_user_id { + self.is_joined(id, room_id).await + } else { + false + }; + + let in_room = in_room + || self + .room_members(room_id) + .ready_any(|userid| appservice.users.is_match(userid.as_str())) + .await; + + self.db + .appservice_in_room_cache + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default() + .insert(appservice.registration.id.clone(), in_room); + + in_room + } } /// Direct DB function to directly mark a user as left. It is not /// recommended to use this directly. 
You most likely should use /// `update_membership` instead #[tracing::instrument(skip(self), level = "debug")] - pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - self.db.mark_as_left(user_id, room_id) - } + pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { self.db.mark_as_left(user_id, room_id); } /// Direct DB function to directly mark a user as joined. It is not /// recommended to use this directly. You most likely should use /// `update_membership` instead #[tracing::instrument(skip(self), level = "debug")] - pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - self.db.mark_as_joined(user_id, room_id) - } + pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) { self.db.mark_as_joined(user_id, room_id); } /// Makes a user forget a room. #[tracing::instrument(skip(self), level = "debug")] - pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { self.db.forget(room_id, user_id) } + pub fn forget(&self, room_id: &RoomId, user_id: &UserId) { self.db.forget(room_id, user_id); } /// Returns an iterator of all servers participating in this room. 
#[tracing::instrument(skip(self), level = "debug")] - pub fn room_servers(&self, room_id: &RoomId) -> impl Iterator> + '_ { - self.db.room_servers(room_id) + pub fn room_servers<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomserverids + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, server): (Ignore, &ServerName)| server) } #[tracing::instrument(skip(self), level = "debug")] - pub fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { - self.db.server_in_room(server, room_id) + pub async fn server_in_room<'a>(&'a self, server: &'a ServerName, room_id: &'a RoomId) -> bool { + let key = (server, room_id); + self.db.serverroomids.qry(&key).await.is_ok() } /// Returns an iterator of all rooms a server participates in (as far as we /// know). #[tracing::instrument(skip(self), level = "debug")] - pub fn server_rooms(&self, server: &ServerName) -> impl Iterator> + '_ { - self.db.server_rooms(server) + pub fn server_rooms<'a>(&'a self, server: &'a ServerName) -> impl Stream + Send + 'a { + let prefix = (server, Interfix); + self.db + .serverroomids + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, room_id): (Ignore, &RoomId)| room_id) } /// Returns true if server can see user by sharing at least one room. #[tracing::instrument(skip(self), level = "debug")] - pub fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> Result { - Ok(self - .server_rooms(server) - .filter_map(Result::ok) - .any(|room_id: OwnedRoomId| self.is_joined(user_id, &room_id).unwrap_or(false))) + pub async fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> bool { + self.server_rooms(server) + .any(|room_id| self.is_joined(user_id, room_id)) + .await } /// Returns true if user_a and user_b share at least one room. 
#[tracing::instrument(skip(self), level = "debug")] - pub fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> Result { + pub async fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> bool { // Minimize number of point-queries by iterating user with least nr rooms - let (a, b) = if self.rooms_joined(user_a).count() < self.rooms_joined(user_b).count() { + let (a, b) = if self.rooms_joined(user_a).count().await < self.rooms_joined(user_b).count().await { (user_a, user_b) } else { (user_b, user_a) }; - Ok(self - .rooms_joined(a) - .filter_map(Result::ok) - .any(|room_id| self.is_joined(b, &room_id).unwrap_or(false))) + self.rooms_joined(a) + .any(|room_id| self.is_joined(b, room_id)) + .await } - /// Returns an iterator over all joined members of a room. + /// Returns an iterator of all joined members of a room. #[tracing::instrument(skip(self), level = "debug")] - pub fn room_members(&self, room_id: &RoomId) -> impl Iterator> + Send + '_ { - self.db.room_members(room_id) + pub fn room_members<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuserid_joined + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) } /// Returns the number of users which are currently in a room #[tracing::instrument(skip(self), level = "debug")] - pub fn room_joined_count(&self, room_id: &RoomId) -> Result> { self.db.room_joined_count(room_id) } + pub async fn room_joined_count(&self, room_id: &RoomId) -> Result { + self.db.roomid_joinedcount.qry(room_id).await.deserialized() + } #[tracing::instrument(skip(self), level = "debug")] /// Returns an iterator of all our local users in the room, even if they're /// deactivated/guests - pub fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator + 'a { - self.db.local_users_in_room(room_id) + pub fn local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + self.room_members(room_id) + 
.ready_filter(|user| self.services.globals.user_is_local(user)) } #[tracing::instrument(skip(self), level = "debug")] /// Returns an iterator of all our local joined users in a room who are /// active (not deactivated, not guest) - pub fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator + 'a { - self.db.active_local_users_in_room(room_id) + pub fn active_local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + self.local_users_in_room(room_id) + .filter(|user| self.services.users.is_active(user)) } /// Returns the number of users which are currently invited to a room #[tracing::instrument(skip(self), level = "debug")] - pub fn room_invited_count(&self, room_id: &RoomId) -> Result> { self.db.room_invited_count(room_id) } + pub async fn room_invited_count(&self, room_id: &RoomId) -> Result { + self.db + .roomid_invitedcount + .qry(room_id) + .await + .deserialized() + } /// Returns an iterator over all User IDs who ever joined a room. #[tracing::instrument(skip(self), level = "debug")] - pub fn room_useroncejoined(&self, room_id: &RoomId) -> impl Iterator> + '_ { - self.db.room_useroncejoined(room_id) + pub fn room_useroncejoined<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuseroncejoinedids + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) } /// Returns an iterator over all invited members of a room. 
#[tracing::instrument(skip(self), level = "debug")] - pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator> + '_ { - self.db.room_members_invited(room_id) + pub fn room_members_invited<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuserid_invitecount + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) } #[tracing::instrument(skip(self), level = "debug")] - pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - self.db.get_invite_count(room_id, user_id) + pub async fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { + let key = (room_id, user_id); + self.db + .roomuserid_invitecount + .qry(&key) + .await + .deserialized() } #[tracing::instrument(skip(self), level = "debug")] - pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - self.db.get_left_count(room_id, user_id) + pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { + let key = (room_id, user_id); + self.db.roomuserid_leftcount.qry(&key).await.deserialized() } /// Returns an iterator over all rooms this user joined. #[tracing::instrument(skip(self), level = "debug")] - pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator> + '_ { - self.db.rooms_joined(user_id) + pub fn rooms_joined(&self, user_id: &UserId) -> impl Stream + Send { + self.db + .userroomid_joined + .keys_prefix(user_id) + .ignore_err() + .map(|(_, room_id): (Ignore, &RoomId)| room_id) } /// Returns an iterator over all rooms a user was invited to. 
#[tracing::instrument(skip(self), level = "debug")] - pub fn rooms_invited( - &self, user_id: &UserId, - ) -> impl Iterator>)>> + '_ { + pub fn rooms_invited<'a>( + &'a self, user_id: &'a UserId, + ) -> impl Stream>)> + Send + 'a { self.db.rooms_invited(user_id) } #[tracing::instrument(skip(self), level = "debug")] - pub fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { - self.db.invite_state(user_id, room_id) + pub async fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>> { + self.db.invite_state(user_id, room_id).await } #[tracing::instrument(skip(self), level = "debug")] - pub fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { - self.db.left_state(user_id, room_id) + pub async fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>> { + self.db.left_state(user_id, room_id).await } /// Returns an iterator over all rooms a user left. #[tracing::instrument(skip(self), level = "debug")] - pub fn rooms_left( - &self, user_id: &UserId, - ) -> impl Iterator>)>> + '_ { + pub fn rooms_left<'a>( + &'a self, user_id: &'a UserId, + ) -> impl Stream>)> + Send + 'a { self.db.rooms_left(user_id) } #[tracing::instrument(skip(self), level = "debug")] - pub fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.once_joined(user_id, room_id) + pub async fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let key = (user_id, room_id); + self.db.roomuseroncejoinedids.qry(&key).await.is_ok() } #[tracing::instrument(skip(self), level = "debug")] - pub fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_joined(user_id, room_id) } + pub async fn is_joined<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_joined.qry(&key).await.is_ok() + } #[tracing::instrument(skip(self), level = "debug")] - pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { - 
self.db.is_invited(user_id, room_id) + pub async fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_invitestate.qry(&key).await.is_ok() } #[tracing::instrument(skip(self), level = "debug")] - pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_left(user_id, room_id) } + pub async fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_leftstate.qry(&key).await.is_ok() + } #[tracing::instrument(skip(self), level = "debug")] - pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator> + '_ { - self.db.servers_invite_via(room_id) + pub fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> impl Stream + Send + 'a { + self.db + .roomid_inviteviaservers + .stream_prefix(room_id) + .ignore_err() + .map(|(_, servers): (Ignore, Vec<&ServerName>)| &**(servers.last().expect("at least one servername"))) } /// Gets up to three servers that are likely to be in the room in the @@ -409,37 +507,27 @@ impl Service { /// /// See #[tracing::instrument(skip(self))] - pub fn servers_route_via(&self, room_id: &RoomId) -> Result> { + pub async fn servers_route_via(&self, room_id: &RoomId) -> Result> { let most_powerful_user_server = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? 
- .map(|pdu| { - serde_json::from_str(pdu.content.get()).map(|conent: RoomPowerLevelsEventContent| { - conent - .users - .iter() - .max_by_key(|(_, power)| *power) - .and_then(|x| { - if x.1 >= &int!(50) { - Some(x) - } else { - None - } - }) - .map(|(user, _power)| user.server_name().to_owned()) - }) + .room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "") + .await + .map(|content: RoomPowerLevelsEventContent| { + content + .users + .iter() + .max_by_key(|(_, power)| *power) + .and_then(|x| (x.1 >= &int!(50)).then_some(x)) + .map(|(user, _power)| user.server_name().to_owned()) }) - .transpose() - .map_err(|e| { - error!("Invalid power levels event content in database: {e}"); - Error::bad_database("Invalid power levels event content in database") - })? - .flatten(); + .map_err(|e| err!(Database(error!(?e, "Invalid power levels event content in database."))))?; let mut servers: Vec = self .room_members(room_id) - .filter_map(Result::ok) + .collect::>() + .await + .iter() .counts_by(|user| user.server_name().to_owned()) .iter() .sorted_by_key(|(_, users)| *users) @@ -468,4 +556,139 @@ impl Service { .expect("locked") .clear(); } + + pub async fn update_joined_count(&self, room_id: &RoomId) { + let mut joinedcount = 0_u64; + let mut invitedcount = 0_u64; + let mut joined_servers = HashSet::new(); + + self.room_members(room_id) + .ready_for_each(|joined| { + joined_servers.insert(joined.server_name().to_owned()); + joinedcount = joinedcount.saturating_add(1); + }) + .await; + + invitedcount = invitedcount.saturating_add( + self.room_members_invited(room_id) + .count() + .await + .try_into() + .unwrap_or(0), + ); + + self.db + .roomid_joinedcount + .insert(room_id.as_bytes(), &joinedcount.to_be_bytes()); + + self.db + .roomid_invitedcount + .insert(room_id.as_bytes(), &invitedcount.to_be_bytes()); + + self.room_servers(room_id) + .ready_for_each(|old_joined_server| { + if !joined_servers.remove(old_joined_server) { + // Server not in room anymore + let 
mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xFF); + roomserver_id.extend_from_slice(old_joined_server.as_bytes()); + + let mut serverroom_id = old_joined_server.as_bytes().to_vec(); + serverroom_id.push(0xFF); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.db.roomserverids.remove(&roomserver_id); + self.db.serverroomids.remove(&serverroom_id); + } + }) + .await; + + // Now only new servers are in joined_servers anymore + for server in joined_servers { + let mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xFF); + roomserver_id.extend_from_slice(server.as_bytes()); + + let mut serverroom_id = server.as_bytes().to_vec(); + serverroom_id.push(0xFF); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.db.roomserverids.insert(&roomserver_id, &[]); + self.db.serverroomids.insert(&serverroom_id, &[]); + } + + self.db + .appservice_in_room_cache + .write() + .unwrap() + .remove(room_id); + } + + pub async fn mark_as_invited( + &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, + invite_via: Option>, + ) { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xFF); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.db.userroomid_invitestate.insert( + &userroom_id, + &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), + ); + self.db + .roomuserid_invitecount + .insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes()); + self.db.userroomid_joined.remove(&userroom_id); + self.db.roomuserid_joined.remove(&roomuser_id); + self.db.userroomid_leftstate.remove(&userroom_id); + self.db.roomuserid_leftcount.remove(&roomuser_id); + + if let Some(servers) = invite_via { + let mut prev_servers = self + .servers_invite_via(room_id) + .map(ToOwned::to_owned) + .collect::>() 
+ .await; + #[allow(clippy::redundant_clone)] // this is a necessary clone? + prev_servers.append(servers.clone().as_mut()); + let servers = prev_servers.iter().rev().unique().rev().collect_vec(); + + let servers = servers + .iter() + .map(|server| server.as_bytes()) + .collect_vec() + .join(&[0xFF][..]); + + self.db + .roomid_inviteviaservers + .insert(room_id.as_bytes(), &servers); + } + } + + #[tracing::instrument(skip(self), level = "debug")] + pub async fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) { + let mut prev_servers = self + .servers_invite_via(room_id) + .map(ToOwned::to_owned) + .collect::>() + .await; + prev_servers.extend(servers.to_owned()); + prev_servers.sort_unstable(); + prev_servers.dedup(); + + let servers = prev_servers + .iter() + .map(|server| server.as_bytes()) + .collect_vec() + .join(&[0xFF][..]); + + self.db + .roomid_inviteviaservers + .insert(room_id.as_bytes(), &servers); + } } diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index 337730019..9a9f70a28 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -1,6 +1,6 @@ use std::{collections::HashSet, mem::size_of, sync::Arc}; -use conduit::{checked, utils, Error, Result}; +use conduit::{err, expected, utils, Result}; use database::{Database, Map}; use super::CompressedStateEvent; @@ -22,11 +22,13 @@ impl Data { } } - pub(super) fn get_statediff(&self, shortstatehash: u64) -> Result { + pub(super) async fn get_statediff(&self, shortstatehash: u64) -> Result { let value = self .shortstatehash_statediff - .get(&shortstatehash.to_be_bytes())? 
- .ok_or_else(|| Error::bad_database("State hash does not exist"))?; + .qry(&shortstatehash) + .await + .map_err(|e| err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")))?; + let parent = utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); let parent = if parent != 0 { Some(parent) @@ -40,10 +42,10 @@ impl Data { let stride = size_of::(); let mut i = stride; - while let Some(v) = value.get(i..checked!(i + 2 * stride)?) { + while let Some(v) = value.get(i..expected!(i + 2 * stride)) { if add_mode && v.starts_with(&0_u64.to_be_bytes()) { add_mode = false; - i = checked!(i + stride)?; + i = expected!(i + stride); continue; } if add_mode { @@ -51,7 +53,7 @@ impl Data { } else { removed.insert(v.try_into().expect("we checked the size above")); } - i = checked!(i + 2 * stride)?; + i = expected!(i + 2 * stride); } Ok(StateDiff { @@ -61,7 +63,7 @@ impl Data { }) } - pub(super) fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) -> Result<()> { + pub(super) fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) { let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); for new in diff.added.iter() { value.extend_from_slice(&new[..]); @@ -75,6 +77,6 @@ impl Data { } self.shortstatehash_statediff - .insert(&shortstatehash.to_be_bytes(), &value) + .insert(&shortstatehash.to_be_bytes(), &value); } } diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 2550774e1..cd3f2f738 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -27,14 +27,12 @@ type StateInfoLruCache = Mutex< >, >; -type ShortStateInfoResult = Result< - Vec<( - u64, // sstatehash - Arc>, // full state - Arc>, // added - Arc>, // removed - )>, ->; +type ShortStateInfoResult = Vec<( + u64, // sstatehash + Arc>, // full state + Arc>, // added + Arc>, // removed +)>; type ParentStatesVec = Vec<( u64, // sstatehash @@ -43,7 +41,7 @@ 
type ParentStatesVec = Vec<( Arc>, // removed )>; -type HashSetCompressStateEvent = Result<(u64, Arc>, Arc>)>; +type HashSetCompressStateEvent = (u64, Arc>, Arc>); pub type CompressedStateEvent = [u8; 2 * size_of::()]; pub struct Service { @@ -86,12 +84,11 @@ impl crate::Service for Service { impl Service { /// Returns a stack with info on shortstatehash, full state, added diff and /// removed diff for the selected shortstatehash and each parent layer. - #[tracing::instrument(skip(self), level = "debug")] - pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoResult { + pub async fn load_shortstatehash_info(&self, shortstatehash: u64) -> Result { if let Some(r) = self .stateinfo_cache .lock() - .unwrap() + .expect("locked") .get_mut(&shortstatehash) { return Ok(r.clone()); @@ -101,11 +98,11 @@ impl Service { parent, added, removed, - } = self.db.get_statediff(shortstatehash)?; + } = self.db.get_statediff(shortstatehash).await?; if let Some(parent) = parent { - let mut response = self.load_shortstatehash_info(parent)?; - let mut state = (*response.last().unwrap().1).clone(); + let mut response = Box::pin(self.load_shortstatehash_info(parent)).await?; + let mut state = (*response.last().expect("at least one response").1).clone(); state.extend(added.iter().copied()); let removed = (*removed).clone(); for r in &removed { @@ -116,7 +113,7 @@ impl Service { self.stateinfo_cache .lock() - .unwrap() + .expect("locked") .insert(shortstatehash, response.clone()); Ok(response) @@ -124,33 +121,42 @@ impl Service { let response = vec![(shortstatehash, added.clone(), added, removed)]; self.stateinfo_cache .lock() - .unwrap() + .expect("locked") .insert(shortstatehash, response.clone()); + Ok(response) } } - pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result { + pub async fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> CompressedStateEvent { let mut v = shortstatekey.to_be_bytes().to_vec(); 
v.extend_from_slice( &self .services .short - .get_or_create_shorteventid(event_id)? + .get_or_create_shorteventid(event_id) + .await .to_be_bytes(), ); - Ok(v.try_into().expect("we checked the size above")) + + v.try_into().expect("we checked the size above") } /// Returns shortstatekey, event id #[inline] - pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEvent) -> Result<(u64, Arc)> { - Ok(( - utils::u64_from_bytes(&compressed_event[0..size_of::()]).expect("bytes have right length"), - self.services.short.get_eventid_from_short( - utils::u64_from_bytes(&compressed_event[size_of::()..]).expect("bytes have right length"), - )?, - )) + pub async fn parse_compressed_state_event( + &self, compressed_event: &CompressedStateEvent, + ) -> Result<(u64, Arc)> { + use utils::u64_from_u8; + + let shortstatekey = u64_from_u8(&compressed_event[0..size_of::()]); + let event_id = self + .services + .short + .get_eventid_from_short(u64_from_u8(&compressed_event[size_of::()..])) + .await?; + + Ok((shortstatekey, event_id)) } /// Creates a new shortstatehash that often is just a diff to an already @@ -227,7 +233,7 @@ impl Service { added: statediffnew, removed: statediffremoved, }, - )?; + ); return Ok(()); }; @@ -280,7 +286,7 @@ impl Service { added: statediffnew, removed: statediffremoved, }, - )?; + ); } Ok(()) @@ -288,10 +294,15 @@ impl Service { /// Returns the new shortstatehash, and the state diff from the previous /// room state - pub fn save_state( + pub async fn save_state( &self, room_id: &RoomId, new_state_ids_compressed: Arc>, - ) -> HashSetCompressStateEvent { - let previous_shortstatehash = self.services.state.get_room_shortstatehash(room_id)?; + ) -> Result { + let previous_shortstatehash = self + .services + .state + .get_room_shortstatehash(room_id) + .await + .ok(); let state_hash = utils::calculate_hash( &new_state_ids_compressed @@ -303,14 +314,18 @@ impl Service { let (new_shortstatehash, already_existed) = self .services .short - 
.get_or_create_shortstatehash(&state_hash)?; + .get_or_create_shortstatehash(&state_hash) + .await; if Some(new_shortstatehash) == previous_shortstatehash { return Ok((new_shortstatehash, Arc::new(HashSet::new()), Arc::new(HashSet::new()))); } - let states_parents = - previous_shortstatehash.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; + let states_parents = if let Some(p) = previous_shortstatehash { + self.load_shortstatehash_info(p).await.unwrap_or_default() + } else { + ShortStateInfoResult::new() + }; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = new_state_ids_compressed diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index fb279a007..f50b812ca 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -1,13 +1,18 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{checked, utils, Error, PduEvent, Result}; -use database::Map; +use conduit::{ + checked, + result::LogErr, + utils, + utils::{stream::TryIgnore, ReadyExt}, + PduEvent, Result, +}; +use database::{Deserialized, Map}; +use futures::{Stream, StreamExt}; use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; use crate::{rooms, Dep}; -type PduEventIterResult<'a> = Result> + 'a>>; - pub(super) struct Data { threadid_userids: Arc, services: Services, @@ -30,38 +35,37 @@ impl Data { } } - pub(super) fn threads_until<'a>( + pub(super) async fn threads_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, - ) -> PduEventIterResult<'a> { + ) -> Result + Send + 'a> { let prefix = self .services .short - .get_shortroomid(room_id)? - .expect("room exists") + .get_shortroomid(room_id) + .await? 
.to_be_bytes() .to_vec(); let mut current = prefix.clone(); current.extend_from_slice(&(checked!(until - 1)?).to_be_bytes()); - Ok(Box::new( - self.threadid_userids - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pduid, _users)| { - let count = utils::u64_from_bytes(&pduid[(size_of::())..]) - .map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; - let mut pdu = self - .services - .timeline - .get_pdu_from_id(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((count, pdu)) - }), - )) + let stream = self + .threadid_userids + .rev_raw_keys_from(¤t) + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix)) + .map(|pduid| (utils::u64_from_u8(&pduid[(size_of::())..]), pduid)) + .filter_map(move |(count, pduid)| async move { + let mut pdu = self.services.timeline.get_pdu_from_id(pduid).await.ok()?; + + if pdu.sender != user_id { + pdu.remove_transaction_id().log_err().ok(); + } + + Some((count, pdu)) + }); + + Ok(stream) } pub(super) fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { @@ -71,28 +75,12 @@ impl Data { .collect::>() .join(&[0xFF][..]); - self.threadid_userids.insert(root_id, &users)?; + self.threadid_userids.insert(root_id, &users); Ok(()) } - pub(super) fn get_participants(&self, root_id: &[u8]) -> Result>> { - if let Some(users) = self.threadid_userids.get(root_id)? 
{ - Ok(Some( - users - .split(|b| *b == 0xFF) - .map(|bytes| { - UserId::parse( - utils::string_from_bytes(bytes) - .map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?, - ) - .map_err(|_| Error::bad_database("Invalid UserId in threadid_userids.")) - }) - .filter_map(Result::ok) - .collect(), - )) - } else { - Ok(None) - } + pub(super) async fn get_participants(&self, root_id: &[u8]) -> Result> { + self.threadid_userids.qry(root_id).await.deserialized() } } diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index ae51cd0f9..2eafe5d52 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -2,12 +2,12 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; -use conduit::{Error, PduEvent, Result}; +use conduit::{err, PduEvent, Result}; use data::Data; +use futures::Stream; use ruma::{ - api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads}, - events::relation::BundledThread, - uint, CanonicalJsonValue, EventId, RoomId, UserId, + api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, uint, CanonicalJsonValue, + EventId, RoomId, UserId, }; use serde_json::json; @@ -36,30 +36,35 @@ impl crate::Service for Service { } impl Service { - pub fn threads_until<'a>( + pub async fn threads_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads, - ) -> Result> + 'a> { - self.db.threads_until(user_id, room_id, until, include) + ) -> Result + Send + 'a> { + self.db + .threads_until(user_id, room_id, until, include) + .await } - pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { + pub async fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { let root_id = self .services .timeline - .get_pdu_id(root_event_id)? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Invalid event id in thread message"))?; + .get_pdu_id(root_event_id) + .await + .map_err(|e| err!(Request(InvalidParam("Invalid event_id in thread message: {e:?}"))))?; let root_pdu = self .services .timeline - .get_pdu_from_id(&root_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; + .get_pdu_from_id(&root_id) + .await + .map_err(|e| err!(Request(InvalidParam("Thread root not found: {e:?}"))))?; let mut root_pdu_json = self .services .timeline - .get_pdu_json_from_id(&root_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; + .get_pdu_json_from_id(&root_id) + .await + .map_err(|e| err!(Request(InvalidParam("Thread root pdu not found: {e:?}"))))?; if let CanonicalJsonValue::Object(unsigned) = root_pdu_json .entry("unsigned".to_owned()) @@ -103,11 +108,12 @@ impl Service { self.services .timeline - .replace_pdu(&root_id, &root_pdu_json, &root_pdu)?; + .replace_pdu(&root_id, &root_pdu_json, &root_pdu) + .await?; } let mut users = Vec::new(); - if let Some(userids) = self.db.get_participants(&root_id)? 
{ + if let Ok(userids) = self.db.get_participants(&root_id).await { users.extend_from_slice(&userids); } else { users.push(root_pdu.sender); diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 2f0c8f258..cd746be43 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -1,12 +1,20 @@ use std::{ collections::{hash_map, HashMap}, mem::size_of, - sync::{Arc, Mutex}, + sync::Arc, }; -use conduit::{checked, error, utils, Error, PduCount, PduEvent, Result}; -use database::{Database, Map}; -use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; +use conduit::{ + err, expected, + result::{LogErr, NotFound}, + utils, + utils::{stream::TryIgnore, u64_from_u8, ReadyExt}, + Err, PduCount, PduEvent, Result, +}; +use database::{Database, Deserialized, KeyVal, Map}; +use futures::{FutureExt, Stream, StreamExt}; +use ruma::{CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; +use tokio::sync::Mutex; use crate::{rooms, Dep}; @@ -25,8 +33,7 @@ struct Services { short: Dep, } -type PdusIterItem = Result<(PduCount, PduEvent)>; -type PdusIterator<'a> = Box + 'a>; +pub type PdusIterItem = (PduCount, PduEvent); type LastTimelineCountCache = Mutex>; impl Data { @@ -46,23 +53,20 @@ impl Data { } } - pub(super) fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + pub(super) async fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { match self .lasttimelinecount_cache .lock() - .expect("locked") + .await .entry(room_id.to_owned()) { hash_map::Entry::Vacant(v) => { if let Some(last_count) = self - .pdus_until(sender_user, room_id, PduCount::max())? - .find_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) { + .pdus_until(sender_user, room_id, PduCount::max()) + .await? 
+ .next() + .await + { Ok(*v.insert(last_count.0)) } else { Ok(PduCount::Normal(0)) @@ -73,232 +77,215 @@ impl Data { } /// Returns the `count` of this pdu's id. - pub(super) fn get_pdu_count(&self, event_id: &EventId) -> Result> { + pub(super) async fn get_pdu_count(&self, event_id: &EventId) -> Result { self.eventid_pduid - .get(event_id.as_bytes())? + .qry(event_id) + .await .map(|pdu_id| pdu_count(&pdu_id)) - .transpose() } /// Returns the json of a pdu. - pub(super) fn get_pdu_json(&self, event_id: &EventId) -> Result> { - self.get_non_outlier_pdu_json(event_id)?.map_or_else( - || { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() - }, - |x| Ok(Some(x)), - ) + pub(super) async fn get_pdu_json(&self, event_id: &EventId) -> Result { + if let Ok(pdu) = self.get_non_outlier_pdu_json(event_id).await { + return Ok(pdu); + } + + self.eventid_outlierpdu + .qry(event_id) + .await + .deserialized_json() } /// Returns the json of a pdu. - pub(super) fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() + pub(super) async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result { + let pduid = self.get_pdu_id(event_id).await?; + + self.pduid_pdu.qry(&pduid).await.deserialized_json() } /// Returns the pdu's id. #[inline] - pub(super) fn get_pdu_id(&self, event_id: &EventId) -> Result>> { - self.eventid_pduid.get(event_id.as_bytes()) + pub(super) async fn get_pdu_id(&self, event_id: &EventId) -> Result> { + self.eventid_pduid.qry(event_id).await } /// Returns the pdu directly from `eventid_pduid` only. 
- pub(super) fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() + pub(super) async fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result { + let pduid = self.get_pdu_id(event_id).await?; + + self.pduid_pdu.qry(&pduid).await.deserialized_json() + } + + /// Like get_non_outlier_pdu(), but without the expense of fetching and + /// parsing the PduEvent + pub(super) async fn non_outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> { + let pduid = self.get_pdu_id(event_id).await?; + + self.pduid_pdu.qry(&pduid).await?; + + Ok(()) } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub(super) fn get_pdu(&self, event_id: &EventId) -> Result>> { - if let Some(pdu) = self - .get_non_outlier_pdu(event_id)? - .map_or_else( - || { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() - }, - |x| Ok(Some(x)), - )? 
- .map(Arc::new) - { - Ok(Some(pdu)) - } else { - Ok(None) + pub(super) async fn get_pdu(&self, event_id: &EventId) -> Result> { + if let Ok(pdu) = self.get_non_outlier_pdu(event_id).await { + return Ok(Arc::new(pdu)); } + + self.eventid_outlierpdu + .qry(event_id) + .await + .deserialized_json() + .map(Arc::new) + } + + /// Like get_non_outlier_pdu(), but without the expense of fetching and + /// parsing the PduEvent + pub(super) async fn outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> { + self.eventid_outlierpdu.qry(event_id).await?; + + Ok(()) + } + + /// Like get_pdu(), but without the expense of fetching and parsing the data + pub(super) async fn pdu_exists(&self, event_id: &EventId) -> bool { + let non_outlier = self.non_outlier_pdu_exists(event_id).map(|res| res.is_ok()); + let outlier = self.outlier_pdu_exists(event_id).map(|res| res.is_ok()); + + //TODO: parallelize + non_outlier.await || outlier.await } /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub(super) fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) + pub(super) async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result { + self.pduid_pdu.qry(pdu_id).await.deserialized_json() } /// Returns the pdu as a `BTreeMap`. 
- pub(super) fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) + pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result { + self.pduid_pdu.qry(pdu_id).await.deserialized_json() } - pub(super) fn append_pdu( - &self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64, - ) -> Result<()> { + pub(super) async fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) { self.pduid_pdu.insert( pdu_id, &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), - )?; + ); self.lasttimelinecount_cache .lock() - .expect("locked") + .await .insert(pdu.room_id.clone(), PduCount::Normal(count)); - self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; - self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; - - Ok(()) + self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id); + self.eventid_outlierpdu.remove(pdu.event_id.as_bytes()); } - pub(super) fn prepend_backfill_pdu( - &self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject, - ) -> Result<()> { + pub(super) fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) { self.pduid_pdu.insert( pdu_id, &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), - )?; - - self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?; - self.eventid_outlierpdu.remove(event_id.as_bytes())?; + ); - Ok(()) + self.eventid_pduid.insert(event_id.as_bytes(), pdu_id); + self.eventid_outlierpdu.remove(event_id.as_bytes()); } /// Removes a pdu and creates a new one with the same id. 
- pub(super) fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> { - if self.pduid_pdu.get(pdu_id)?.is_some() { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"), - )?; - } else { - return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist.")); + pub(super) async fn replace_pdu( + &self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent, + ) -> Result<()> { + if self.pduid_pdu.qry(pdu_id).await.is_not_found() { + return Err!(Request(NotFound("PDU does not exist."))); } + let pdu = serde_json::to_vec(pdu_json)?; + self.pduid_pdu.insert(pdu_id, &pdu); + Ok(()) } /// Returns an iterator over all events and their tokens in a room that /// happened before the event with id `until` in reverse-chronological /// order. - pub(super) fn pdus_until(&self, user_id: &UserId, room_id: &RoomId, until: PduCount) -> Result> { - let (prefix, current) = self.count_to_id(room_id, until, 1, true)?; - - let user_id = user_id.to_owned(); - - Ok(Box::new( - self.pduid_pdu - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - pdu.add_age()?; - let count = pdu_count(&pdu_id)?; - Ok((count, pdu)) - }), - )) + pub(super) async fn pdus_until<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, + ) -> Result + Send + 'a> { + let (prefix, current) = self.count_to_id(room_id, until, 1, true).await?; + let stream = self + .pduid_pdu + .rev_raw_stream_from(¤t) + .ignore_err() + .ready_take_while(move |(key, _)| key.starts_with(&prefix)) + .map(move |item| Self::each_pdu(item, user_id)); + + Ok(stream) } - pub(super) fn pdus_after(&self, user_id: &UserId, room_id: &RoomId, from: PduCount) -> Result> { - let (prefix, 
current) = self.count_to_id(room_id, from, 1, false)?; - - let user_id = user_id.to_owned(); - - Ok(Box::new( - self.pduid_pdu - .iter_from(¤t, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - pdu.add_age()?; - let count = pdu_count(&pdu_id)?; - Ok((count, pdu)) - }), - )) + pub(super) async fn pdus_after<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount, + ) -> Result + Send + 'a> { + let (prefix, current) = self.count_to_id(room_id, from, 1, false).await?; + let stream = self + .pduid_pdu + .raw_stream_from(¤t) + .ignore_err() + .ready_take_while(move |(key, _)| key.starts_with(&prefix)) + .map(move |item| Self::each_pdu(item, user_id)); + + Ok(stream) + } + + fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: &UserId) -> PdusIterItem { + let mut pdu = + serde_json::from_slice::(pdu).expect("PduEvent in pduid_pdu database column is invalid JSON"); + + if pdu.sender != user_id { + pdu.remove_transaction_id().log_err().ok(); + } + + pdu.add_age().log_err().ok(); + let count = pdu_count(pdu_id); + + (count, pdu) } pub(super) fn increment_notification_counts( &self, room_id: &RoomId, notifies: Vec, highlights: Vec, - ) -> Result<()> { - let mut notifies_batch = Vec::new(); - let mut highlights_batch = Vec::new(); + ) { + let _cork = self.db.cork(); + for user in notifies { let mut userroom_id = user.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - notifies_batch.push(userroom_id); + increment(&self.userroomid_notificationcount, &userroom_id); } + for user in highlights { let mut userroom_id = user.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - highlights_batch.push(userroom_id); + increment(&self.userroomid_highlightcount, &userroom_id); } 
- - self.userroomid_notificationcount - .increment_batch(notifies_batch.iter().map(Vec::as_slice))?; - self.userroomid_highlightcount - .increment_batch(highlights_batch.iter().map(Vec::as_slice))?; - Ok(()) } - pub(super) fn count_to_id( + pub(super) async fn count_to_id( &self, room_id: &RoomId, count: PduCount, offset: u64, subtract: bool, ) -> Result<(Vec, Vec)> { let prefix = self .services .short - .get_shortroomid(room_id)? - .ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))? + .get_shortroomid(room_id) + .await + .map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))? .to_be_bytes() .to_vec(); + let mut pdu_id = prefix.clone(); // +1 so we don't send the base event let count_raw = match count { @@ -326,17 +313,23 @@ impl Data { } /// Returns the `count` of this pdu's id. -pub(super) fn pdu_count(pdu_id: &[u8]) -> Result { - let stride = size_of::(); +pub(super) fn pdu_count(pdu_id: &[u8]) -> PduCount { + const STRIDE: usize = size_of::(); + let pdu_id_len = pdu_id.len(); - let last_u64 = utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - stride)?..]) - .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; - let second_last_u64 = - utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - 2 * stride)?..checked!(pdu_id_len - stride)?]); + let last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - STRIDE)..]); + let second_last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - 2 * STRIDE)..expected!(pdu_id_len - STRIDE)]); - if matches!(second_last_u64, Ok(0)) { - Ok(PduCount::Backfilled(u64::MAX.saturating_sub(last_u64))) + if second_last_u64 == 0 { + PduCount::Backfilled(u64::MAX.saturating_sub(last_u64)) } else { - Ok(PduCount::Normal(last_u64)) + PduCount::Normal(last_u64) } } + +//TODO: this is an ABA +fn increment(db: &Arc, key: &[u8]) { + let old = db.get(key); + let new = utils::increment(old.ok().as_deref()); + db.insert(key, &new); +} diff --git a/src/service/rooms/timeline/mod.rs 
b/src/service/rooms/timeline/mod.rs index 04d9559da..5360d2c96 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,19 +1,20 @@ mod data; use std::{ + cmp, collections::{BTreeMap, HashSet}, fmt::Write, sync::Arc, }; use conduit::{ - debug, error, info, + debug, err, error, info, pdu::{EventHash, PduBuilder, PduCount, PduEvent}, utils, - utils::{MutexMap, MutexMapGuard}, - validated, warn, Error, Result, Server, + utils::{stream::TryIgnore, IterStream, MutexMap, MutexMapGuard, ReadyExt}, + validated, warn, Err, Error, Result, Server, }; -use itertools::Itertools; +use futures::{future, future::ready, Future, Stream, StreamExt, TryStreamExt}; use ruma::{ api::{client::error::ErrorKind, federation}, canonical_json::to_canonical_value, @@ -39,6 +40,7 @@ use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::sync::RwLock; use self::data::Data; +pub use self::data::PdusIterItem; use crate::{ account_data, admin, appservice, appservice::NamespaceRegex, globals, pusher, rooms, rooms::state_compressor::CompressedStateEvent, sending, server_keys, Dep, @@ -129,6 +131,7 @@ impl crate::Service for Service { } fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + /* let lasttimelinecount_cache = self .db .lasttimelinecount_cache @@ -136,6 +139,7 @@ impl crate::Service for Service { .expect("locked") .len(); writeln!(out, "lasttimelinecount_cache: {lasttimelinecount_cache}")?; + */ let mutex_insert = self.mutex_insert.len(); writeln!(out, "insert_mutex: {mutex_insert}")?; @@ -144,11 +148,13 @@ impl crate::Service for Service { } fn clear_cache(&self) { + /* self.db .lasttimelinecount_cache .lock() .expect("locked") .clear(); + */ } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } @@ -156,28 +162,32 @@ impl crate::Service for Service { impl Service { #[tracing::instrument(skip(self), level = "debug")] - pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { - 
self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? + pub async fn first_pdu_in_room(&self, room_id: &RoomId) -> Result> { + self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id) + .await? .next() - .map(|o| o.map(|(_, p)| Arc::new(p))) - .transpose() + .await + .map(|(_, p)| Arc::new(p)) + .ok_or_else(|| err!(Request(NotFound("No PDU found in room")))) } #[tracing::instrument(skip(self), level = "debug")] - pub fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result>> { - self.all_pdus(user_id!("@placeholder:conduwuit.placeholder"), room_id)? - .last() - .map(|o| o.map(|(_, p)| Arc::new(p))) - .transpose() + pub async fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result> { + self.pdus_until(user_id!("@placeholder:conduwuit.placeholder"), room_id, PduCount::max()) + .await? + .next() + .await + .map(|(_, p)| Arc::new(p)) + .ok_or_else(|| err!(Request(NotFound("No PDU found in room")))) } #[tracing::instrument(skip(self), level = "debug")] - pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { - self.db.last_timeline_count(sender_user, room_id) + pub async fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + self.db.last_timeline_count(sender_user, room_id).await } /// Returns the `count` of this pdu's id. - pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { self.db.get_pdu_count(event_id) } + pub async fn get_pdu_count(&self, event_id: &EventId) -> Result { self.db.get_pdu_count(event_id).await } // TODO Is this the same as the function above? /* @@ -203,49 +213,56 @@ impl Service { */ /// Returns the json of a pdu. - pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { - self.db.get_pdu_json(event_id) + pub async fn get_pdu_json(&self, event_id: &EventId) -> Result { + self.db.get_pdu_json(event_id).await } /// Returns the json of a pdu. 
#[inline] - pub fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.db.get_non_outlier_pdu_json(event_id) + pub async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result { + self.db.get_non_outlier_pdu_json(event_id).await } /// Returns the pdu's id. #[inline] - pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { - self.db.get_pdu_id(event_id) + pub async fn get_pdu_id(&self, event_id: &EventId) -> Result> { + self.db.get_pdu_id(event_id).await } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. #[inline] - pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.db.get_non_outlier_pdu(event_id) + pub async fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result { + self.db.get_non_outlier_pdu(event_id).await } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_pdu(&self, event_id: &EventId) -> Result>> { self.db.get_pdu(event_id) } + pub async fn get_pdu(&self, event_id: &EventId) -> Result> { self.db.get_pdu(event_id).await } + + /// Checks if pdu exists + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub fn pdu_exists<'a>(&'a self, event_id: &'a EventId) -> impl Future + Send + 'a { + self.db.pdu_exists(event_id) + } /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { self.db.get_pdu_from_id(pdu_id) } + pub async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result { self.db.get_pdu_from_id(pdu_id).await } /// Returns the pdu as a `BTreeMap`. - pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { - self.db.get_pdu_json_from_id(pdu_id) + pub async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result { + self.db.get_pdu_json_from_id(pdu_id).await } /// Removes a pdu and creates a new one with the same id. 
#[tracing::instrument(skip(self), level = "debug")] - pub fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { - self.db.replace_pdu(pdu_id, pdu_json, pdu) + pub async fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { + self.db.replace_pdu(pdu_id, pdu_json, pdu).await } /// Creates a new persisted data unit and adds it to a room. @@ -268,8 +285,9 @@ impl Service { let shortroomid = self .services .short - .get_shortroomid(&pdu.room_id)? - .expect("room exists"); + .get_shortroomid(&pdu.room_id) + .await + .map_err(|_| err!(Database("Room does not exist")))?; // Make unsigned fields correct. This is not properly documented in the spec, // but state events need to have previous content in the unsigned field, so @@ -279,17 +297,17 @@ impl Service { .entry("unsigned".to_owned()) .or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default())) { - if let Some(shortstatehash) = self + if let Ok(shortstatehash) = self .services .state_accessor .pdu_shortstatehash(&pdu.event_id) - .unwrap() + .await { - if let Some(prev_state) = self + if let Ok(prev_state) = self .services .state_accessor .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) - .unwrap() + .await { unsigned.insert( "prev_content".to_owned(), @@ -318,10 +336,12 @@ impl Service { // We must keep track of all events that have been referenced. 
self.services .pdu_metadata - .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + .mark_as_referenced(&pdu.room_id, &pdu.prev_events); + self.services .state - .set_forward_extremities(&pdu.room_id, leaves, state_lock)?; + .set_forward_extremities(&pdu.room_id, leaves, state_lock) + .await; let insert_lock = self.mutex_insert.lock(&pdu.room_id).await; @@ -330,17 +350,17 @@ impl Service { // appending fails self.services .read_receipt - .private_read_set(&pdu.room_id, &pdu.sender, count1)?; + .private_read_set(&pdu.room_id, &pdu.sender, count1); self.services .user - .reset_notification_counts(&pdu.sender, &pdu.room_id)?; + .reset_notification_counts(&pdu.sender, &pdu.room_id); - let count2 = self.services.globals.next_count()?; + let count2 = self.services.globals.next_count().unwrap(); let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&count2.to_be_bytes()); // Insert pdu - self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2)?; + self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2).await; drop(insert_lock); @@ -348,12 +368,9 @@ impl Service { let power_levels: RoomPowerLevelsEventContent = self .services .state_accessor - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? 
+ .room_state_get_content(&pdu.room_id, &StateEventType::RoomPowerLevels, "") + .await + .map_err(|_| err!(Database("invalid m.room.power_levels event"))) .unwrap_or_default(); let sync_pdu = pdu.to_sync_room_event(); @@ -365,7 +382,9 @@ impl Service { .services .state_cache .active_local_users_in_room(&pdu.room_id) - .collect_vec(); + .map(ToOwned::to_owned) + .collect::>() + .await; if pdu.kind == TimelineEventType::RoomMember { if let Some(state_key) = &pdu.state_key { @@ -386,23 +405,20 @@ impl Service { let rules_for_user = self .services .account_data - .get(None, user, GlobalAccountDataEventType::PushRules.to_string().into())? - .map(|event| { - serde_json::from_str::(event.get()).map_err(|e| { - warn!("Invalid push rules event in db for user ID {user}: {e}"); - Error::bad_database("Invalid push rules event in db.") - }) - }) - .transpose()? - .map_or_else(|| Ruleset::server_default(user), |ev: PushRulesEvent| ev.content.global); + .get(None, user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .and_then(|event| serde_json::from_str::(event.get()).map_err(Into::into)) + .map_err(|e| err!(Database(warn!(?user, ?e, "Invalid push rules event in db for user")))) + .map_or_else(|_| Ruleset::server_default(user), |ev: PushRulesEvent| ev.content.global); let mut highlight = false; let mut notify = false; - for action in - self.services - .pusher - .get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id)? + for action in self + .services + .pusher + .get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id) + .await? 
{ match action { Action::Notify => notify = true, @@ -421,31 +437,36 @@ impl Service { highlights.push(user.clone()); } - for push_key in self.services.pusher.get_pushkeys(user) { - self.services - .sending - .send_pdu_push(&pdu_id, user, push_key?)?; - } + self.services + .pusher + .get_pushkeys(user) + .ready_for_each(|push_key| { + self.services + .sending + .send_pdu_push(&pdu_id, user, push_key.to_owned()) + .expect("TODO: replace with future"); + }) + .await; } self.db - .increment_notification_counts(&pdu.room_id, notifies, highlights)?; + .increment_notification_counts(&pdu.room_id, notifies, highlights); match pdu.kind { TimelineEventType::RoomRedaction => { use RoomVersionId::*; - let room_version_id = self.services.state.get_room_version(&pdu.room_id)?; + let room_version_id = self.services.state.get_room_version(&pdu.room_id).await?; match room_version_id { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { - if self.services.state_accessor.user_can_redact( - redact_id, - &pdu.sender, - &pdu.room_id, - false, - )? { - self.redact_pdu(redact_id, pdu, shortroomid)?; + if self + .services + .state_accessor + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? + { + self.redact_pdu(redact_id, pdu, shortroomid).await?; } } }, @@ -457,13 +478,13 @@ impl Service { })?; if let Some(redact_id) = &content.redacts { - if self.services.state_accessor.user_can_redact( - redact_id, - &pdu.sender, - &pdu.room_id, - false, - )? { - self.redact_pdu(redact_id, pdu, shortroomid)?; + if self + .services + .state_accessor + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? 
+ { + self.redact_pdu(redact_id, pdu, shortroomid).await?; } } }, @@ -492,7 +513,7 @@ impl Service { let invite_state = match content.membership { MembershipState::Invite => { - let state = self.services.state.calculate_invite_state(pdu)?; + let state = self.services.state.calculate_invite_state(pdu).await?; Some(state) }, _ => None, @@ -500,15 +521,18 @@ impl Service { // Update our membership info, we do this here incase a user is invited // and immediately leaves we need the DB to record the invite event for auth - self.services.state_cache.update_membership( - &pdu.room_id, - &target_user_id, - content, - &pdu.sender, - invite_state, - None, - true, - )?; + self.services + .state_cache + .update_membership( + &pdu.room_id, + &target_user_id, + content, + &pdu.sender, + invite_state, + None, + true, + ) + .await?; } }, TimelineEventType::RoomMessage => { @@ -516,9 +540,7 @@ impl Service { .map_err(|_| Error::bad_database("Invalid content in pdu."))?; if let Some(body) = content.body { - self.services - .search - .index_pdu(shortroomid, &pdu_id, &body)?; + self.services.search.index_pdu(shortroomid, &pdu_id, &body); if self.services.admin.is_admin_command(pdu, &body).await { self.services @@ -531,10 +553,10 @@ impl Service { } if let Ok(content) = serde_json::from_str::(pdu.content.get()) { - if let Some(related_pducount) = self.get_pdu_count(&content.relates_to.event_id)? { + if let Ok(related_pducount) = self.get_pdu_count(&content.relates_to.event_id).await { self.services .pdu_metadata - .add_relation(PduCount::Normal(count2), related_pducount)?; + .add_relation(PduCount::Normal(count2), related_pducount); } } @@ -545,14 +567,17 @@ impl Service { } => { // We need to do it again here, because replies don't have // event_id as a top level field - if let Some(related_pducount) = self.get_pdu_count(&in_reply_to.event_id)? 
{ + if let Ok(related_pducount) = self.get_pdu_count(&in_reply_to.event_id).await { self.services .pdu_metadata - .add_relation(PduCount::Normal(count2), related_pducount)?; + .add_relation(PduCount::Normal(count2), related_pducount); } }, Relation::Thread(thread) => { - self.services.threads.add_to_thread(&thread.event_id, pdu)?; + self.services + .threads + .add_to_thread(&thread.event_id, pdu) + .await?; }, _ => {}, // TODO: Aggregate other types } @@ -562,7 +587,8 @@ impl Service { if self .services .state_cache - .appservice_in_room(&pdu.room_id, appservice)? + .appservice_in_room(&pdu.room_id, appservice) + .await { self.services .sending @@ -596,15 +622,14 @@ impl Service { .as_ref() .map_or(false, |state_key| users.is_match(state_key)) }; - let matching_aliases = |aliases: &NamespaceRegex| { + let matching_aliases = |aliases: NamespaceRegex| { self.services .alias .local_aliases_for_room(&pdu.room_id) - .filter_map(Result::ok) - .any(|room_alias| aliases.is_match(room_alias.as_str())) + .ready_any(move |room_alias| aliases.is_match(room_alias.as_str())) }; - if matching_aliases(&appservice.aliases) + if matching_aliases(appservice.aliases.clone()).await || appservice.rooms.is_match(pdu.room_id.as_str()) || matching_users(&appservice.users) { @@ -617,7 +642,7 @@ impl Service { Ok(pdu_id) } - pub fn create_hash_and_sign_event( + pub async fn create_hash_and_sign_event( &self, pdu_builder: PduBuilder, sender: &UserId, @@ -636,47 +661,59 @@ impl Service { let prev_events: Vec<_> = self .services .state - .get_forward_extremities(room_id)? 
- .into_iter() + .get_forward_extremities(room_id) .take(20) - .collect(); + .map(Arc::from) + .collect() + .await; // If there was no create event yet, assume we are creating a room - let room_version_id = self.services.state.get_room_version(room_id).or_else(|_| { - if event_type == TimelineEventType::RoomCreate { - let content = serde_json::from_str::(content.get()) - .expect("Invalid content in RoomCreate pdu."); - Ok(content.room_version) - } else { - Err(Error::InconsistentRoomState( - "non-create event for room of unknown version", - room_id.to_owned(), - )) - } - })?; + let room_version_id = self + .services + .state + .get_room_version(room_id) + .await + .or_else(|_| { + if event_type == TimelineEventType::RoomCreate { + let content = serde_json::from_str::(content.get()) + .expect("Invalid content in RoomCreate pdu."); + Ok(content.room_version) + } else { + Err(Error::InconsistentRoomState( + "non-create event for room of unknown version", + room_id.to_owned(), + )) + } + })?; let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); - let auth_events = - self.services - .state - .get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; + let auth_events = self + .services + .state + .get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content) + .await?; // Our depth is the maximum depth of prev_events + 1 let depth = prev_events .iter() - .filter_map(|event_id| Some(self.get_pdu(event_id).ok()??.depth)) - .max() - .unwrap_or_else(|| uint!(0)) + .stream() + .map(Ok) + .and_then(|event_id| self.get_pdu(event_id)) + .and_then(|pdu| future::ok(pdu.depth)) + .ignore_err() + .ready_fold(uint!(0), cmp::max) + .await .saturating_add(uint!(1)); let mut unsigned = unsigned.unwrap_or_default(); if let Some(state_key) = &state_key { - if let Some(prev_pdu) = - self.services - .state_accessor - .room_state_get(room_id, &event_type.to_string().into(), state_key)? 
+ if let Ok(prev_pdu) = self + .services + .state_accessor + .room_state_get(room_id, &event_type.to_string().into(), state_key) + .await { unsigned.insert( "prev_content".to_owned(), @@ -727,19 +764,22 @@ impl Service { signatures: None, }; + let auth_fetch = |k: &StateEventType, s: &str| { + let key = (k.clone(), s.to_owned()); + ready(auth_events.get(&key)) + }; + let auth_check = state_res::auth_check( &room_version, &pdu, - None::, // TODO: third_party_invite - |k, s| auth_events.get(&(k.clone(), s.to_owned())), + None, // TODO: third_party_invite + auth_fetch, ) - .map_err(|e| { - error!("Auth check failed: {:?}", e); - Error::BadRequest(ErrorKind::forbidden(), "Auth check failed.") - })?; + .await + .map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?; if !auth_check { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Event is not authorized.")); + return Err!(Request(Forbidden("Event is not authorized."))); } // Hash and sign @@ -795,7 +835,8 @@ impl Service { let _shorteventid = self .services .short - .get_or_create_shorteventid(&pdu.event_id)?; + .get_or_create_shorteventid(&pdu.event_id) + .await; Ok((pdu, pdu_json)) } @@ -811,108 +852,117 @@ impl Service { room_id: &RoomId, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result> { - let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; - if let Some(admin_room) = self.services.admin.get_admin_room()? 
{ - if admin_room == room_id { - match pdu.event_type() { - TimelineEventType::RoomEncryption => { - warn!("Encryption is not allowed in the admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Encryption is not allowed in the admins room", - )); - }, - TimelineEventType::RoomMember => { - let target = pdu - .state_key() - .filter(|v| v.starts_with('@')) - .unwrap_or(sender.as_str()); - let server_user = &self.services.globals.server_user.to_string(); - - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu"))?; - - if content.membership == MembershipState::Leave { - if target == server_user { - warn!("Server user cannot leave from admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server user cannot leave from admins room.", - )); - } + let (pdu, pdu_json) = self + .create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock) + .await?; - let count = self - .services - .state_cache - .room_members(room_id) - .filter_map(Result::ok) - .filter(|m| self.services.globals.server_is_ours(m.server_name()) && m != target) - .count(); - if count < 2 { - warn!("Last admin cannot leave from admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Last admin cannot leave from admins room.", - )); - } + if self.services.admin.is_admin_room(&pdu.room_id).await { + match pdu.event_type() { + TimelineEventType::RoomEncryption => { + warn!("Encryption is not allowed in the admins room"); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Encryption is not allowed in the admins room", + )); + }, + TimelineEventType::RoomMember => { + let target = pdu + .state_key() + .filter(|v| v.starts_with('@')) + .unwrap_or(sender.as_str()); + let server_user = &self.services.globals.server_user.to_string(); + + let content = serde_json::from_str::(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in pdu"))?; + + if 
content.membership == MembershipState::Leave { + if target == server_user { + warn!("Server user cannot leave from admins room"); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Server user cannot leave from admins room.", + )); } - if content.membership == MembershipState::Ban && pdu.state_key().is_some() { - if target == server_user { - warn!("Server user cannot be banned in admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server user cannot be banned in admins room.", - )); - } + let count = self + .services + .state_cache + .room_members(&pdu.room_id) + .ready_filter(|user| self.services.globals.user_is_local(user)) + .ready_filter(|user| *user != target) + .boxed() + .count() + .await; + + if count < 2 { + warn!("Last admin cannot leave from admins room"); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Last admin cannot leave from admins room.", + )); + } + } - let count = self - .services - .state_cache - .room_members(room_id) - .filter_map(Result::ok) - .filter(|m| self.services.globals.server_is_ours(m.server_name()) && m != target) - .count(); - if count < 2 { - warn!("Last admin cannot be banned in admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Last admin cannot be banned in admins room.", - )); - } + if content.membership == MembershipState::Ban && pdu.state_key().is_some() { + if target == server_user { + warn!("Server user cannot be banned in admins room"); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Server user cannot be banned in admins room.", + )); } - }, - _ => {}, - } + + let count = self + .services + .state_cache + .room_members(&pdu.room_id) + .ready_filter(|user| self.services.globals.user_is_local(user)) + .ready_filter(|user| *user != target) + .boxed() + .count() + .await; + + if count < 2 { + warn!("Last admin cannot be banned in admins room"); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Last admin cannot be banned in 
admins room.", + )); + } + } + }, + _ => {}, } } // If redaction event is not authorized, do not append it to the timeline if pdu.kind == TimelineEventType::RoomRedaction { use RoomVersionId::*; - match self.services.state.get_room_version(&pdu.room_id)? { + match self.services.state.get_room_version(&pdu.room_id).await? { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { if !self .services .state_accessor - .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)? + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? { - return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event.")); + return Err!(Request(Forbidden("User cannot redact this event."))); } }; }, _ => { let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; + .map_err(|e| err!(Database("Invalid content in redaction pdu: {e:?}")))?; if let Some(redact_id) = &content.redacts { if !self .services .state_accessor - .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)? + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? { - return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event.")); + return Err!(Request(Forbidden("User cannot redact this event."))); } } }, @@ -922,7 +972,7 @@ impl Service { // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't // fail. 
- let statehashid = self.services.state.append_to_state(&pdu)?; + let statehashid = self.services.state.append_to_state(&pdu).await?; let pdu_id = self .append_pdu( @@ -939,14 +989,15 @@ impl Service { // in time where events in the current room state do not exist self.services .state - .set_room_state(room_id, statehashid, state_lock)?; + .set_room_state(&pdu.room_id, statehashid, state_lock); let mut servers: HashSet = self .services .state_cache - .room_servers(room_id) - .filter_map(Result::ok) - .collect(); + .room_servers(&pdu.room_id) + .map(ToOwned::to_owned) + .collect() + .await; // In case we are kicking or banning a user, we need to inform their server of // the change @@ -966,7 +1017,8 @@ impl Service { self.services .sending - .send_pdu_servers(servers.into_iter(), &pdu_id)?; + .send_pdu_servers(servers.iter().map(AsRef::as_ref).stream(), &pdu_id) + .await?; Ok(pdu.event_id) } @@ -988,15 +1040,19 @@ impl Service { // fail. self.services .state - .set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)?; + .set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed) + .await?; if soft_fail { self.services .pdu_metadata - .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + .mark_as_referenced(&pdu.room_id, &pdu.prev_events); + self.services .state - .set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)?; + .set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock) + .await; + return Ok(None); } @@ -1009,71 +1065,71 @@ impl Service { /// Returns an iterator over all PDUs in a room. 
#[inline] - pub fn all_pdus<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, - ) -> Result> + 'a> { - self.pdus_after(user_id, room_id, PduCount::min()) + pub async fn all_pdus<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, + ) -> Result + Send + 'a> { + self.pdus_after(user_id, room_id, PduCount::min()).await } /// Returns an iterator over all events and their tokens in a room that /// happened before the event with id `until` in reverse-chronological /// order. #[tracing::instrument(skip(self), level = "debug")] - pub fn pdus_until<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, - ) -> Result> + 'a> { - self.db.pdus_until(user_id, room_id, until) + pub async fn pdus_until<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, + ) -> Result + Send + 'a> { + self.db.pdus_until(user_id, room_id, until).await } /// Returns an iterator over all events and their token in a room that /// happened after the event with id `from` in chronological order. #[tracing::instrument(skip(self), level = "debug")] - pub fn pdus_after<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount, - ) -> Result> + 'a> { - self.db.pdus_after(user_id, room_id, from) + pub async fn pdus_after<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount, + ) -> Result + Send + 'a> { + self.db.pdus_after(user_id, room_id, from).await } /// Replace a PDU with the redacted form. #[tracing::instrument(skip(self, reason))] - pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: u64) -> Result<()> { + pub async fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: u64) -> Result<()> { // TODO: Don't reserialize, keep original json - if let Some(pdu_id) = self.get_pdu_id(event_id)? { - let mut pdu = self - .get_pdu_from_id(&pdu_id)? 
- .ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; + let Ok(pdu_id) = self.get_pdu_id(event_id).await else { + // If event does not exist, just noop + return Ok(()); + }; - if let Ok(content) = serde_json::from_str::(pdu.content.get()) { - if let Some(body) = content.body { - self.services - .search - .deindex_pdu(shortroomid, &pdu_id, &body)?; - } + let mut pdu = self + .get_pdu_from_id(&pdu_id) + .await + .map_err(|e| err!(Database(error!(?pdu_id, ?event_id, ?e, "PDU ID points to invalid PDU."))))?; + + if let Ok(content) = serde_json::from_str::(pdu.content.get()) { + if let Some(body) = content.body { + self.services + .search + .deindex_pdu(shortroomid, &pdu_id, &body); } + } - let room_version_id = self.services.state.get_room_version(&pdu.room_id)?; + let room_version_id = self.services.state.get_room_version(&pdu.room_id).await?; - pdu.redact(room_version_id, reason)?; + pdu.redact(room_version_id, reason)?; - self.replace_pdu( - &pdu_id, - &utils::to_canonical_object(&pdu).map_err(|e| { - error!("Failed to convert PDU to canonical JSON: {}", e); - Error::bad_database("Failed to convert PDU to canonical JSON.") - })?, - &pdu, - )?; - } - // If event does not exist, just noop - Ok(()) + let obj = utils::to_canonical_object(&pdu) + .map_err(|e| err!(Database(error!(?event_id, ?e, "Failed to convert PDU to canonical JSON"))))?; + + self.replace_pdu(&pdu_id, &obj, &pdu).await } #[tracing::instrument(skip(self))] pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> { let first_pdu = self - .all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? + .all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id) + .await? 
.next() - .expect("Room is not empty")?; + .await + .expect("Room is not empty"); if first_pdu.0 < from { // No backfill required, there are still events between them @@ -1083,17 +1139,18 @@ impl Service { let power_levels: RoomPowerLevelsEventContent = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? + .room_state_get(room_id, &StateEventType::RoomPowerLevels, "") + .await .map(|ev| { serde_json::from_str(ev.content.get()) .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + .unwrap() }) - .transpose()? .unwrap_or_default(); let room_mods = power_levels.users.iter().filter_map(|(user_id, level)| { if level > &power_levels.users_default && !self.services.globals.user_is_local(user_id) { - Some(user_id.server_name().to_owned()) + Some(user_id.server_name()) } else { None } @@ -1103,34 +1160,43 @@ impl Service { .services .alias .local_aliases_for_room(room_id) - .filter_map(|alias| { - alias - .ok() - .filter(|alias| !self.services.globals.server_is_ours(alias.server_name())) - .map(|alias| alias.server_name().to_owned()) + .ready_filter_map(|alias| { + self.services + .globals + .server_is_ours(alias.server_name()) + .then_some(alias.server_name()) }); - let servers = room_mods + let mut servers = room_mods + .stream() .chain(room_alias_servers) - .chain(self.services.server.config.trusted_servers.clone()) - .filter(|server_name| { - if self.services.globals.server_is_ours(server_name) { - return false; - } - + .map(ToOwned::to_owned) + .chain( + self.services + .server + .config + .trusted_servers + .iter() + .map(ToOwned::to_owned) + .stream(), + ) + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)) + .filter_map(|server_name| async move { self.services .state_cache - .server_in_room(server_name, room_id) - .unwrap_or(false) - }); + .server_in_room(&server_name, room_id) + .await + .then_some(server_name) + }) + .boxed(); - for backfill_server in servers { + while 
let Some(ref backfill_server) = servers.next().await { info!("Asking {backfill_server} for backfill"); let response = self .services .sending .send_federation_request( - &backfill_server, + backfill_server, federation::backfill::get_backfill::v1::Request { room_id: room_id.to_owned(), v: vec![first_pdu.1.event_id.as_ref().to_owned()], @@ -1142,7 +1208,7 @@ impl Service { Ok(response) => { let pub_key_map = RwLock::new(BTreeMap::new()); for pdu in response.pdus { - if let Err(e) = self.backfill_pdu(&backfill_server, pdu, &pub_key_map).await { + if let Err(e) = self.backfill_pdu(backfill_server, pdu, &pub_key_map).await { warn!("Failed to add backfilled pdu in room {room_id}: {e}"); } } @@ -1163,7 +1229,7 @@ impl Service { &self, origin: &ServerName, pdu: Box, pub_key_map: &RwLock>>, ) -> Result<()> { - let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu)?; + let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu).await?; // Lock so we cannot backfill the same pdu twice at the same time let mutex_lock = self @@ -1174,7 +1240,7 @@ impl Service { .await; // Skip the PDU if we already have it as a timeline event - if let Some(pdu_id) = self.get_pdu_id(&event_id)? { + if let Ok(pdu_id) = self.get_pdu_id(&event_id).await { let pdu_id = pdu_id.to_vec(); debug!("We already know {event_id} at {pdu_id:?}"); return Ok(()); @@ -1190,36 +1256,38 @@ impl Service { .handle_incoming_pdu(origin, &room_id, &event_id, value, false, pub_key_map) .await?; - let value = self.get_pdu_json(&event_id)?.expect("We just created it"); - let pdu = self.get_pdu(&event_id)?.expect("We just created it"); + let value = self + .get_pdu_json(&event_id) + .await + .expect("We just created it"); + let pdu = self.get_pdu(&event_id).await.expect("We just created it"); let shortroomid = self .services .short - .get_shortroomid(&room_id)? 
+ .get_shortroomid(&room_id) + .await .expect("room exists"); let insert_lock = self.mutex_insert.lock(&room_id).await; let max = u64::MAX; - let count = self.services.globals.next_count()?; + let count = self.services.globals.next_count().unwrap(); let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&0_u64.to_be_bytes()); pdu_id.extend_from_slice(&(validated!(max - count)).to_be_bytes()); // Insert pdu - self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value)?; + self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value); drop(insert_lock); if pdu.kind == TimelineEventType::RoomMessage { let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + .map_err(|e| err!(Database("Invalid content in pdu: {e:?}")))?; if let Some(body) = content.body { - self.services - .search - .index_pdu(shortroomid, &pdu_id, &body)?; + self.services.search.index_pdu(shortroomid, &pdu_id, &body); } } drop(mutex_lock); diff --git a/src/service/rooms/typing/mod.rs b/src/service/rooms/typing/mod.rs index 3cf1cdd59..bcfce6168 100644 --- a/src/service/rooms/typing/mod.rs +++ b/src/service/rooms/typing/mod.rs @@ -46,7 +46,7 @@ impl Service { /// Sets a user as typing until the timeout timestamp is reached or /// roomtyping_remove is called. 
pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { - debug_info!("typing started {:?} in {:?} timeout:{:?}", user_id, room_id, timeout); + debug_info!("typing started {user_id:?} in {room_id:?} timeout:{timeout:?}"); // update clients self.typing .write() @@ -54,17 +54,19 @@ impl Service { .entry(room_id.to_owned()) .or_default() .insert(user_id.to_owned(), timeout); + self.last_typing_update .write() .await .insert(room_id.to_owned(), self.services.globals.next_count()?); + if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation if self.services.globals.user_is_local(user_id) { - self.federation_send(room_id, user_id, true)?; + self.federation_send(room_id, user_id, true).await?; } Ok(()) @@ -72,7 +74,7 @@ impl Service { /// Removes a user from typing before the timeout is reached. pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - debug_info!("typing stopped {:?} in {:?}", user_id, room_id); + debug_info!("typing stopped {user_id:?} in {room_id:?}"); // update clients self.typing .write() @@ -80,31 +82,31 @@ impl Service { .entry(room_id.to_owned()) .or_default() .remove(user_id); + self.last_typing_update .write() .await .insert(room_id.to_owned(), self.services.globals.next_count()?); + if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation if self.services.globals.user_is_local(user_id) { - self.federation_send(room_id, user_id, false)?; + self.federation_send(room_id, user_id, false).await?; } Ok(()) } - pub async fn wait_for_update(&self, room_id: &RoomId) -> Result<()> { + pub async fn wait_for_update(&self, room_id: &RoomId) { let mut receiver = self.typing_update_sender.subscribe(); while let Ok(next) = receiver.recv().await { if next == room_id { break; } 
} - - Ok(()) } /// Makes sure that typing events with old timestamps get removed. @@ -123,30 +125,30 @@ impl Service { removable.push(user.clone()); } } - - drop(typing); }; if !removable.is_empty() { let typing = &mut self.typing.write().await; let room = typing.entry(room_id.to_owned()).or_default(); for user in &removable { - debug_info!("typing timeout {:?} in {:?}", &user, room_id); + debug_info!("typing timeout {user:?} in {room_id:?}"); room.remove(user); } + // update clients self.last_typing_update .write() .await .insert(room_id.to_owned(), self.services.globals.next_count()?); + if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation - for user in removable { - if self.services.globals.user_is_local(&user) { - self.federation_send(room_id, &user, false)?; + for user in &removable { + if self.services.globals.user_is_local(user) { + self.federation_send(room_id, user, false).await?; } } } @@ -183,7 +185,7 @@ impl Service { }) } - fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> { + async fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> { debug_assert!( self.services.globals.user_is_local(user_id), "tried to broadcast typing status of remote user", @@ -197,7 +199,8 @@ impl Service { self.services .sending - .send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing"))?; + .send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing")) + .await?; Ok(()) } diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index c71316153..d4d9874c2 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -1,8 +1,9 @@ use std::sync::Arc; -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; +use conduit::Result; +use 
database::{Deserialized, Map}; +use futures::{Stream, StreamExt}; +use ruma::{RoomId, UserId}; use crate::{globals, rooms, Dep}; @@ -11,13 +12,13 @@ pub(super) struct Data { userroomid_highlightcount: Arc, roomuserid_lastnotificationread: Arc, roomsynctoken_shortstatehash: Arc, - userroomid_joined: Arc, services: Services, } struct Services { globals: Dep, short: Dep, + state_cache: Dep, } impl Data { @@ -28,15 +29,15 @@ impl Data { userroomid_highlightcount: db["userroomid_highlightcount"].clone(), roomuserid_lastnotificationread: db["userroomid_highlightcount"].clone(), //< NOTE: known bug from conduit roomsynctoken_shortstatehash: db["roomsynctoken_shortstatehash"].clone(), - userroomid_joined: db["userroomid_joined"].clone(), services: Services { globals: args.depend::("globals"), short: args.depend::("rooms::short"), + state_cache: args.depend::("rooms::state_cache"), }, } } - pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -45,128 +46,73 @@ impl Data { roomuser_id.extend_from_slice(user_id.as_bytes()); self.userroomid_notificationcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; + .insert(&userroom_id, &0_u64.to_be_bytes()); self.userroomid_highlightcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; + .insert(&userroom_id, &0_u64.to_be_bytes()); self.roomuserid_lastnotificationread - .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; - - Ok(()) + .insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes()); } - pub(super) fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - + pub(super) 
async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (user_id, room_id); self.userroomid_notificationcount - .get(&userroom_id)? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db.")) - }) + .qry(&key) + .await + .deserialized() + .unwrap_or(0) } - pub(super) fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - + pub(super) async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (user_id, room_id); self.userroomid_highlightcount - .get(&userroom_id)? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db.")) - }) + .qry(&key) + .await + .deserialized() + .unwrap_or(0) } - pub(super) fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - Ok(self - .roomuserid_lastnotificationread - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")) - }) - .transpose()? - .unwrap_or(0)) + pub(super) async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (room_id, user_id); + self.roomuserid_lastnotificationread + .qry(&key) + .await + .deserialized() + .unwrap_or(0) } - pub(super) fn associate_token_shortstatehash( - &self, room_id: &RoomId, token: u64, shortstatehash: u64, - ) -> Result<()> { + pub(super) async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) { let shortroomid = self .services .short - .get_shortroomid(room_id)? 
+ .get_shortroomid(room_id) + .await .expect("room exists"); let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(&token.to_be_bytes()); self.roomsynctoken_shortstatehash - .insert(&key, &shortstatehash.to_be_bytes()) + .insert(&key, &shortstatehash.to_be_bytes()); } - pub(super) fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { - let shortroomid = self - .services - .short - .get_shortroomid(room_id)? - .expect("room exists"); - - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); + pub(super) async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result { + let shortroomid = self.services.short.get_shortroomid(room_id).await?; + let key: &[u64] = &[shortroomid, token]; self.roomsynctoken_shortstatehash - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash")) - }) - .transpose() + .qry(key) + .await + .deserialized() } + //TODO: optimize; replace point-queries with dual iteration pub(super) fn get_shared_rooms<'a>( - &'a self, users: Vec, - ) -> Result> + 'a>> { - let iterators = users.into_iter().map(move |user_id| { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - self.userroomid_joined - .scan_prefix(prefix) - .map(|(key, _)| { - let roomid_index = key - .iter() - .enumerate() - .find(|(_, &b)| b == 0xFF) - .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? 
- .0 - .saturating_add(1); // +1 because the room id starts AFTER the separator - - let room_id = key[roomid_index..].to_vec(); - - Ok::<_, Error>(room_id) - }) - .filter_map(Result::ok) - }); - - // We use the default compare function because keys are sorted correctly (not - // reversed) - Ok(Box::new( - utils::common_elements(iterators, Ord::cmp) - .expect("users is not empty") - .map(|bytes| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?, - ) - .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) - }), - )) + &'a self, user_a: &'a UserId, user_b: &'a UserId, + ) -> impl Stream + Send + 'a { + self.services + .state_cache + .rooms_joined(user_a) + .filter(|room_id| self.services.state_cache.is_joined(user_b, room_id)) } } diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 93d38470f..d9d90ecf9 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -3,7 +3,8 @@ mod data; use std::sync::Arc; use conduit::Result; -use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; +use futures::{pin_mut, Stream, StreamExt}; +use ruma::{RoomId, UserId}; use self::data::Data; @@ -22,32 +23,49 @@ impl crate::Service for Service { } impl Service { - pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - self.db.reset_notification_counts(user_id, room_id) + #[inline] + pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { + self.db.reset_notification_counts(user_id, room_id); } - pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.notification_count(user_id, room_id) + #[inline] + pub async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + self.db.notification_count(user_id, room_id).await } - pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - 
self.db.highlight_count(user_id, room_id) + #[inline] + pub async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + self.db.highlight_count(user_id, room_id).await } - pub fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.last_notification_read(user_id, room_id) + #[inline] + pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + self.db.last_notification_read(user_id, room_id).await } - pub fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> { + #[inline] + pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) { self.db .associate_token_shortstatehash(room_id, token, shortstatehash) + .await; } - pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { - self.db.get_token_shortstatehash(room_id, token) + #[inline] + pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result { + self.db.get_token_shortstatehash(room_id, token).await } - pub fn get_shared_rooms(&self, users: Vec) -> Result> + '_> { - self.db.get_shared_rooms(users) + #[inline] + pub fn get_shared_rooms<'a>( + &'a self, user_a: &'a UserId, user_b: &'a UserId, + ) -> impl Stream + Send + 'a { + self.db.get_shared_rooms(user_a, user_b) + } + + pub async fn has_shared_rooms<'a>(&'a self, user_a: &'a UserId, user_b: &'a UserId) -> bool { + let get_shared_rooms = self.get_shared_rooms(user_a, user_b); + + pin_mut!(get_shared_rooms); + get_shared_rooms.next().await.is_some() } } diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 6c8e2544d..b96f9a03c 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -1,14 +1,21 @@ use std::sync::Arc; -use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use conduit::{ + utils, + utils::{stream::TryIgnore, ReadyExt}, + Error, Result, +}; +use 
database::{Database, Deserialized, Map}; +use futures::{Stream, StreamExt}; use ruma::{ServerName, UserId}; use super::{Destination, SendingEvent}; use crate::{globals, Dep}; -type OutgoingSendingIter<'a> = Box, Destination, SendingEvent)>> + 'a>; -type SendingEventIter<'a> = Box, SendingEvent)>> + 'a>; +pub(super) type OutgoingItem = (Key, SendingEvent, Destination); +pub(super) type SendingItem = (Key, SendingEvent); +pub(super) type QueueItem = (Key, SendingEvent); +pub(super) type Key = Vec; pub struct Data { servercurrentevent_data: Arc, @@ -36,58 +43,82 @@ impl Data { } } - #[inline] - pub fn active_requests(&self) -> OutgoingSendingIter<'_> { - Box::new( - self.servercurrentevent_data - .iter() - .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))), - ) + pub(super) fn delete_active_request(&self, key: &[u8]) { self.servercurrentevent_data.remove(key); } + + pub(super) async fn delete_all_active_requests_for(&self, destination: &Destination) { + let prefix = destination.get_prefix(); + self.servercurrentevent_data + .raw_keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.servercurrentevent_data.remove(key)) + .await; } - #[inline] - pub fn active_requests_for<'a>(&'a self, destination: &Destination) -> SendingEventIter<'a> { + pub(super) async fn delete_all_requests_for(&self, destination: &Destination) { let prefix = destination.get_prefix(); - Box::new( - self.servercurrentevent_data - .scan_prefix(prefix) - .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))), - ) + self.servercurrentevent_data + .raw_keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.servercurrentevent_data.remove(key)) + .await; + + self.servernameevent_data + .raw_keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.servernameevent_data.remove(key)) + .await; } - pub(super) fn delete_active_request(&self, key: &[u8]) -> Result<()> { self.servercurrentevent_data.remove(key) } + pub(super) fn 
mark_as_active(&self, events: &[QueueItem]) { + for (key, e) in events { + if key.is_empty() { + continue; + } - pub(super) fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()> { - let prefix = destination.get_prefix(); - for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) { - self.servercurrentevent_data.remove(&key)?; + let value = if let SendingEvent::Edu(value) = &e { + &**value + } else { + &[] + }; + self.servercurrentevent_data.insert(key, value); + self.servernameevent_data.remove(key); } + } - Ok(()) + #[inline] + pub fn active_requests(&self) -> impl Stream + Send + '_ { + self.servercurrentevent_data + .raw_stream() + .ignore_err() + .map(|(key, val)| { + let (dest, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent"); + + (key.to_vec(), event, dest) + }) } - pub(super) fn delete_all_requests_for(&self, destination: &Destination) -> Result<()> { + #[inline] + pub fn active_requests_for<'a>(&'a self, destination: &Destination) -> impl Stream + Send + 'a { let prefix = destination.get_prefix(); - for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { - self.servercurrentevent_data.remove(&key).unwrap(); - } - - for (key, _) in self.servernameevent_data.scan_prefix(prefix) { - self.servernameevent_data.remove(&key).unwrap(); - } + self.servercurrentevent_data + .stream_raw_prefix(&prefix) + .ignore_err() + .map(|(key, val)| { + let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent"); - Ok(()) + (key.to_vec(), event) + }) } - pub(super) fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result>> { + pub(super) fn queue_requests(&self, requests: &[(&SendingEvent, &Destination)]) -> Vec> { let mut batch = Vec::new(); let mut keys = Vec::new(); - for (destination, event) in requests { + for (event, destination) in requests { let mut key = destination.get_prefix(); if let SendingEvent::Pdu(value) = &event { 
key.extend_from_slice(value); } else { - key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); + key.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes()); } let value = if let SendingEvent::Edu(value) = &event { &**value @@ -97,56 +128,39 @@ impl Data { batch.push((key.clone(), value.to_owned())); keys.push(key); } - self.servernameevent_data - .insert_batch(batch.iter().map(database::KeyVal::from))?; - Ok(keys) - } - pub fn queued_requests<'a>( - &'a self, destination: &Destination, - ) -> Box)>> + 'a> { - let prefix = destination.get_prefix(); - return Box::new( - self.servernameevent_data - .scan_prefix(prefix) - .map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))), - ); + self.servernameevent_data.insert_batch(batch.iter()); + keys } - pub(super) fn mark_as_active(&self, events: &[(SendingEvent, Vec)]) -> Result<()> { - for (e, key) in events { - if key.is_empty() { - continue; - } - - let value = if let SendingEvent::Edu(value) = &e { - &**value - } else { - &[] - }; - self.servercurrentevent_data.insert(key, value)?; - self.servernameevent_data.remove(key)?; - } + pub fn queued_requests<'a>(&'a self, destination: &Destination) -> impl Stream + Send + 'a { + let prefix = destination.get_prefix(); + self.servernameevent_data + .stream_raw_prefix(&prefix) + .ignore_err() + .map(|(key, val)| { + let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent"); - Ok(()) + (key.to_vec(), event) + }) } - pub(super) fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { + pub(super) fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) { self.servername_educount - .insert(server_name.as_bytes(), &last_count.to_be_bytes()) + .insert(server_name.as_bytes(), &last_count.to_be_bytes()); } - pub fn get_latest_educount(&self, server_name: &ServerName) -> Result { + pub async fn get_latest_educount(&self, server_name: &ServerName) -> 
u64 { self.servername_educount - .get(server_name.as_bytes())? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) - }) + .qry(server_name) + .await + .deserialized() + .unwrap_or(0) } } #[tracing::instrument(skip(key), level = "debug")] -fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(Destination, SendingEvent)> { +fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, SendingEvent)> { // Appservices start with a plus Ok::<_, Error>(if key.starts_with(b"+") { let mut parts = key[1..].splitn(2, |&b| b == 0xFF); @@ -164,7 +178,7 @@ fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(Destination, if value.is_empty() { SendingEvent::Pdu(event.to_vec()) } else { - SendingEvent::Edu(value) + SendingEvent::Edu(value.to_vec()) }, ) } else if key.starts_with(b"$") { @@ -192,7 +206,7 @@ fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(Destination, SendingEvent::Pdu(event.to_vec()) } else { // I'm pretty sure this should never be called - SendingEvent::Edu(value) + SendingEvent::Edu(value.to_vec()) }, ) } else { @@ -214,7 +228,7 @@ fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(Destination, if value.is_empty() { SendingEvent::Pdu(event.to_vec()) } else { - SendingEvent::Edu(value) + SendingEvent::Edu(value.to_vec()) }, ) }) diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index b90ea3618..e3582f2ea 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -7,10 +7,11 @@ mod sender; use std::{fmt::Debug, sync::Arc}; use async_trait::async_trait; -use conduit::{err, warn, Result, Server}; +use conduit::{err, utils::ReadyExt, warn, Result, Server}; +use futures::{future::ready, Stream, StreamExt, TryStreamExt}; use ruma::{ api::{appservice::Registration, OutgoingRequest}, - OwnedServerName, RoomId, ServerName, UserId, + RoomId, ServerName, UserId, }; use tokio::sync::Mutex; @@ -104,7 
+105,7 @@ impl Service { let dest = Destination::Push(user.to_owned(), pushkey); let event = SendingEvent::Pdu(pdu_id.to_owned()); let _cork = self.db.db.cork(); - let keys = self.db.queue_requests(&[(&dest, event.clone())])?; + let keys = self.db.queue_requests(&[(&event, &dest)]); self.dispatch(Msg { dest, event, @@ -117,7 +118,7 @@ impl Service { let dest = Destination::Appservice(appservice_id); let event = SendingEvent::Pdu(pdu_id); let _cork = self.db.db.cork(); - let keys = self.db.queue_requests(&[(&dest, event.clone())])?; + let keys = self.db.queue_requests(&[(&event, &dest)]); self.dispatch(Msg { dest, event, @@ -126,30 +127,31 @@ impl Service { } #[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")] - pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { + pub async fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { let servers = self .services .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server_name| !self.services.globals.server_is_ours(server_name)); + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)); - self.send_pdu_servers(servers, pdu_id) + self.send_pdu_servers(servers, pdu_id).await } #[tracing::instrument(skip(self, servers, pdu_id), level = "debug")] - pub fn send_pdu_servers>(&self, servers: I, pdu_id: &[u8]) -> Result<()> { - let requests = servers - .into_iter() - .map(|server| (Destination::Normal(server), SendingEvent::Pdu(pdu_id.to_owned()))) - .collect::>(); + pub async fn send_pdu_servers<'a, S>(&self, servers: S, pdu_id: &[u8]) -> Result<()> + where + S: Stream + Send + 'a, + { let _cork = self.db.db.cork(); - let keys = self.db.queue_requests( - &requests - .iter() - .map(|(o, e)| (o, e.clone())) - .collect::>(), - )?; + let requests = servers + .map(|server| (Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.into()))) + .collect::>() + .await; + + let keys = self + .db + 
.queue_requests(&requests.iter().map(|(o, e)| (e, o)).collect::>()); + for ((dest, event), queue_id) in requests.into_iter().zip(keys) { self.dispatch(Msg { dest, @@ -166,7 +168,7 @@ impl Service { let dest = Destination::Normal(server.to_owned()); let event = SendingEvent::Edu(serialized); let _cork = self.db.db.cork(); - let keys = self.db.queue_requests(&[(&dest, event.clone())])?; + let keys = self.db.queue_requests(&[(&event, &dest)]); self.dispatch(Msg { dest, event, @@ -175,30 +177,30 @@ impl Service { } #[tracing::instrument(skip(self, room_id, serialized), level = "debug")] - pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec) -> Result<()> { + pub async fn send_edu_room(&self, room_id: &RoomId, serialized: Vec) -> Result<()> { let servers = self .services .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server_name| !self.services.globals.server_is_ours(server_name)); + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)); - self.send_edu_servers(servers, serialized) + self.send_edu_servers(servers, serialized).await } #[tracing::instrument(skip(self, servers, serialized), level = "debug")] - pub fn send_edu_servers>(&self, servers: I, serialized: Vec) -> Result<()> { - let requests = servers - .into_iter() - .map(|server| (Destination::Normal(server), SendingEvent::Edu(serialized.clone()))) - .collect::>(); + pub async fn send_edu_servers<'a, S>(&self, servers: S, serialized: Vec) -> Result<()> + where + S: Stream + Send + 'a, + { let _cork = self.db.db.cork(); - let keys = self.db.queue_requests( - &requests - .iter() - .map(|(o, e)| (o, e.clone())) - .collect::>(), - )?; + let requests = servers + .map(|server| (Destination::Normal(server.to_owned()), SendingEvent::Edu(serialized.clone()))) + .collect::>() + .await; + + let keys = self + .db + .queue_requests(&requests.iter().map(|(o, e)| (e, o)).collect::>()); for ((dest, event), queue_id) in requests.into_iter().zip(keys) { 
self.dispatch(Msg { @@ -212,29 +214,33 @@ impl Service { } #[tracing::instrument(skip(self, room_id), level = "debug")] - pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { + pub async fn flush_room(&self, room_id: &RoomId) -> Result<()> { let servers = self .services .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server_name| !self.services.globals.server_is_ours(server_name)); + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)); - self.flush_servers(servers) + self.flush_servers(servers).await } #[tracing::instrument(skip(self, servers), level = "debug")] - pub fn flush_servers>(&self, servers: I) -> Result<()> { - let requests = servers.into_iter().map(Destination::Normal); - for dest in requests { - self.dispatch(Msg { - dest, - event: SendingEvent::Flush, - queue_id: Vec::::new(), - })?; - } - - Ok(()) + pub async fn flush_servers<'a, S>(&self, servers: S) -> Result<()> + where + S: Stream + Send + 'a, + { + servers + .map(ToOwned::to_owned) + .map(Destination::Normal) + .map(Ok) + .try_for_each(|dest| { + ready(self.dispatch(Msg { + dest, + event: SendingEvent::Flush, + queue_id: Vec::::new(), + })) + }) + .await } #[tracing::instrument(skip_all, name = "request")] @@ -263,11 +269,10 @@ impl Service { /// Cleanup event data /// Used for instance after we remove an appservice registration #[tracing::instrument(skip(self), level = "debug")] - pub fn cleanup_events(&self, appservice_id: String) -> Result<()> { + pub async fn cleanup_events(&self, appservice_id: String) { self.db - .delete_all_requests_for(&Destination::Appservice(appservice_id))?; - - Ok(()) + .delete_all_requests_for(&Destination::Appservice(appservice_id)) + .await; } fn dispatch(&self, msg: Msg) -> Result<()> { diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 206bf92bb..4db9922ae 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -7,18 +7,15 @@ use std::{ use 
base64::{engine::general_purpose, Engine as _}; use conduit::{ - debug, debug_warn, error, trace, - utils::{calculate_hash, math::continue_exponential_backoff_secs}, + debug, debug_warn, err, trace, + utils::{calculate_hash, math::continue_exponential_backoff_secs, ReadyExt}, warn, Error, Result, }; -use federation::transactions::send_transaction_message; -use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; +use futures::{future::BoxFuture, pin_mut, stream::FuturesUnordered, FutureExt, StreamExt}; use ruma::{ - api::federation::{ - self, - transactions::edu::{ - DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap, - }, + api::federation::transactions::{ + edu::{DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap}, + send_transaction_message, }, device_id, events::{push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent, GlobalAccountDataEventType}, @@ -28,7 +25,7 @@ use ruma::{ use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::time::sleep_until; -use super::{appservice, Destination, Msg, SendingEvent, Service}; +use super::{appservice, data::QueueItem, Destination, Msg, SendingEvent, Service}; #[derive(Debug)] enum TransactionStatus { @@ -50,20 +47,20 @@ const CLEANUP_TIMEOUT_MS: u64 = 3500; impl Service { #[tracing::instrument(skip_all, name = "sender")] pub(super) async fn sender(&self) -> Result<()> { - let receiver = self.receiver.lock().await; - let mut futures: SendingFutures<'_> = FuturesUnordered::new(); let mut statuses: CurTransactionStatus = CurTransactionStatus::new(); + let mut futures: SendingFutures<'_> = FuturesUnordered::new(); + let receiver = self.receiver.lock().await; - self.initial_requests(&futures, &mut statuses); + self.initial_requests(&mut futures, &mut statuses).await; loop { debug_assert!(!receiver.is_closed(), "channel error"); tokio::select! 
{ request = receiver.recv_async() => match request { - Ok(request) => self.handle_request(request, &futures, &mut statuses), + Ok(request) => self.handle_request(request, &mut futures, &mut statuses).await, Err(_) => break, }, Some(response) = futures.next() => { - self.handle_response(response, &futures, &mut statuses); + self.handle_response(response, &mut futures, &mut statuses).await; }, } } @@ -72,18 +69,16 @@ impl Service { Ok(()) } - fn handle_response<'a>( - &'a self, response: SendingResult, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus, + async fn handle_response<'a>( + &'a self, response: SendingResult, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus, ) { match response { - Ok(dest) => self.handle_response_ok(&dest, futures, statuses), - Err((dest, e)) => Self::handle_response_err(dest, futures, statuses, &e), + Ok(dest) => self.handle_response_ok(&dest, futures, statuses).await, + Err((dest, e)) => Self::handle_response_err(dest, statuses, &e), }; } - fn handle_response_err( - dest: Destination, _futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, e: &Error, - ) { + fn handle_response_err(dest: Destination, statuses: &mut CurTransactionStatus, e: &Error) { debug!(dest = ?dest, "{e:?}"); statuses.entry(dest).and_modify(|e| { *e = match e { @@ -94,39 +89,40 @@ impl Service { }); } - fn handle_response_ok<'a>( - &'a self, dest: &Destination, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus, + #[allow(clippy::needless_pass_by_ref_mut)] + async fn handle_response_ok<'a>( + &'a self, dest: &Destination, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus, ) { let _cork = self.db.db.cork(); - self.db - .delete_all_active_requests_for(dest) - .expect("all active requests deleted"); + self.db.delete_all_active_requests_for(dest).await; // Find events that have been added since starting the last request let new_events = self .db .queued_requests(dest) - 
.filter_map(Result::ok) .take(DEQUEUE_LIMIT) - .collect::>(); + .collect::>() + .await; // Insert any pdus we found if !new_events.is_empty() { - self.db - .mark_as_active(&new_events) - .expect("marked as active"); - let new_events_vec = new_events.into_iter().map(|(event, _)| event).collect(); - futures.push(Box::pin(self.send_events(dest.clone(), new_events_vec))); + self.db.mark_as_active(&new_events); + + let new_events_vec = new_events.into_iter().map(|(_, event)| event).collect(); + futures.push(self.send_events(dest.clone(), new_events_vec).boxed()); } else { statuses.remove(dest); } } - fn handle_request<'a>(&'a self, msg: Msg, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) { - let iv = vec![(msg.event, msg.queue_id)]; - if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses) { + #[allow(clippy::needless_pass_by_ref_mut)] + async fn handle_request<'a>( + &'a self, msg: Msg, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus, + ) { + let iv = vec![(msg.queue_id, msg.event)]; + if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses).await { if !events.is_empty() { - futures.push(Box::pin(self.send_events(msg.dest, events))); + futures.push(self.send_events(msg.dest, events).boxed()); } else { statuses.remove(&msg.dest); } @@ -142,7 +138,7 @@ impl Service { tokio::select! 
{ () = sleep_until(deadline.into()) => break, response = futures.next() => match response { - Some(response) => self.handle_response(response, futures, statuses), + Some(response) => self.handle_response(response, futures, statuses).await, None => return, } } @@ -151,16 +147,17 @@ impl Service { debug_warn!("Leaving with {} unfinished requests...", futures.len()); } - fn initial_requests<'a>(&'a self, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) { + #[allow(clippy::needless_pass_by_ref_mut)] + async fn initial_requests<'a>(&'a self, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus) { let keep = usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX); let mut txns = HashMap::>::new(); - for (key, dest, event) in self.db.active_requests().filter_map(Result::ok) { + let mut active = self.db.active_requests().boxed(); + + while let Some((key, event, dest)) = active.next().await { let entry = txns.entry(dest.clone()).or_default(); if self.server.config.startup_netburst_keep >= 0 && entry.len() >= keep { - warn!("Dropping unsent event {:?} {:?}", dest, String::from_utf8_lossy(&key)); - self.db - .delete_active_request(&key) - .expect("active request deleted"); + warn!("Dropping unsent event {dest:?} {:?}", String::from_utf8_lossy(&key)); + self.db.delete_active_request(&key); } else { entry.push(event); } @@ -169,16 +166,16 @@ impl Service { for (dest, events) in txns { if self.server.config.startup_netburst && !events.is_empty() { statuses.insert(dest.clone(), TransactionStatus::Running); - futures.push(Box::pin(self.send_events(dest.clone(), events))); + futures.push(self.send_events(dest.clone(), events).boxed()); } } } #[tracing::instrument(skip_all, level = "debug")] - fn select_events( + async fn select_events( &self, dest: &Destination, - new_events: Vec<(SendingEvent, Vec)>, // Events we want to send: event and full key + new_events: Vec, // Events we want to send: event and full key statuses: 
&mut CurTransactionStatus, ) -> Result>> { let (allow, retry) = self.select_events_current(dest.clone(), statuses)?; @@ -195,8 +192,8 @@ impl Service { if retry { self.db .active_requests_for(dest) - .filter_map(Result::ok) - .for_each(|(_, e)| events.push(e)); + .ready_for_each(|(_, e)| events.push(e)) + .await; return Ok(Some(events)); } @@ -204,17 +201,17 @@ impl Service { // Compose the next transaction let _cork = self.db.db.cork(); if !new_events.is_empty() { - self.db.mark_as_active(&new_events)?; - for (e, _) in new_events { + self.db.mark_as_active(&new_events); + for (_, e) in new_events { events.push(e); } } // Add EDU's into the transaction if let Destination::Normal(server_name) = dest { - if let Ok((select_edus, last_count)) = self.select_edus(server_name) { + if let Ok((select_edus, last_count)) = self.select_edus(server_name).await { events.extend(select_edus.into_iter().map(SendingEvent::Edu)); - self.db.set_latest_educount(server_name, last_count)?; + self.db.set_latest_educount(server_name, last_count); } } @@ -248,26 +245,32 @@ impl Service { } #[tracing::instrument(skip_all, level = "debug")] - fn select_edus(&self, server_name: &ServerName) -> Result<(Vec>, u64)> { + async fn select_edus(&self, server_name: &ServerName) -> Result<(Vec>, u64)> { // u64: count of last edu - let since = self.db.get_latest_educount(server_name)?; + let since = self.db.get_latest_educount(server_name).await; let mut events = Vec::new(); let mut max_edu_count = since; let mut device_list_changes = HashSet::new(); - for room_id in self.services.state_cache.server_rooms(server_name) { - let room_id = room_id?; + let server_rooms = self.services.state_cache.server_rooms(server_name); + + pin_mut!(server_rooms); + while let Some(room_id) = server_rooms.next().await { // Look for device list updates in this room device_list_changes.extend( self.services .users - .keys_changed(room_id.as_ref(), since, None) - .filter_map(Result::ok) - .filter(|user_id| 
self.services.globals.user_is_local(user_id)), + .keys_changed(room_id.as_str(), since, None) + .ready_filter(|user_id| self.services.globals.user_is_local(user_id)) + .map(ToOwned::to_owned) + .collect::>() + .await, ); if self.server.config.allow_outgoing_read_receipts - && !self.select_edus_receipts(&room_id, since, &mut max_edu_count, &mut events)? + && !self + .select_edus_receipts(room_id, since, &mut max_edu_count, &mut events) + .await? { break; } @@ -290,19 +293,22 @@ impl Service { } if self.server.config.allow_outgoing_presence { - self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events)?; + self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events) + .await?; } Ok((events, max_edu_count)) } /// Look for presence - fn select_edus_presence( + async fn select_edus_presence( &self, server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec>, ) -> Result { - // Look for presence updates for this server + let presence_since = self.services.presence.presence_since(since); + + pin_mut!(presence_since); let mut presence_updates = Vec::new(); - for (user_id, count, presence_bytes) in self.services.presence.presence_since(since) { + while let Some((user_id, count, presence_bytes)) = presence_since.next().await { *max_edu_count = cmp::max(count, *max_edu_count); if !self.services.globals.user_is_local(&user_id) { @@ -312,7 +318,8 @@ impl Service { if !self .services .state_cache - .server_sees_user(server_name, &user_id)? 
+ .server_sees_user(server_name, &user_id) + .await { continue; } @@ -320,7 +327,9 @@ impl Service { let presence_event = self .services .presence - .from_json_bytes_to_event(&presence_bytes, &user_id)?; + .from_json_bytes_to_event(&presence_bytes, &user_id) + .await?; + presence_updates.push(PresenceUpdate { user_id, presence: presence_event.content.presence, @@ -346,32 +355,33 @@ impl Service { } /// Look for read receipts in this room - fn select_edus_receipts( + async fn select_edus_receipts( &self, room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec>, ) -> Result { - for r in self + let receipts = self .services .read_receipt - .readreceipts_since(room_id, since) - { - let (user_id, count, read_receipt) = r?; - *max_edu_count = cmp::max(count, *max_edu_count); + .readreceipts_since(room_id, since); + pin_mut!(receipts); + while let Some((user_id, count, read_receipt)) = receipts.next().await { + *max_edu_count = cmp::max(count, *max_edu_count); if !self.services.globals.user_is_local(&user_id) { continue; } let event = serde_json::from_str(read_receipt.json().get()) .map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?; + let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event { let mut read = BTreeMap::new(); - let (event_id, mut receipt) = r .content .0 .into_iter() .next() .expect("we only use one event per read receipt"); + let receipt = receipt .remove(&ReceiptType::Read) .expect("our read receipts always set this") @@ -427,24 +437,17 @@ impl Service { async fn send_events_dest_appservice( &self, dest: &Destination, id: &str, events: Vec, ) -> SendingResult { - let mut pdu_jsons = Vec::new(); + let Some(appservice) = self.services.appservice.get_registration(id).await else { + return Err((dest.clone(), err!(Database(warn!(?id, "Missing appservice registration"))))); + }; + let mut pdu_jsons = Vec::new(); for event in &events { match event { SendingEvent::Pdu(pdu_id) => { - pdu_jsons.push( - 
self.services - .timeline - .get_pdu_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Appservice] Event in servernameevent_data not found in db."), - ) - })? - .to_room_event(), - ); + if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await { + pdu_jsons.push(pdu.to_room_event()); + } }, SendingEvent::Edu(_) | SendingEvent::Flush => { // Appservices don't need EDUs (?) and flush only; @@ -453,32 +456,24 @@ impl Service { } } + let txn_id = &*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( + &events + .iter() + .map(|e| match e { + SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, + SendingEvent::Flush => &[], + }) + .collect::>(), + )); + //debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction"); let client = &self.services.client.appservice; match appservice::send_request( client, - self.services - .appservice - .get_registration(id) - .await - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Appservice] Could not load registration from db."), - ) - })?, + appservice, ruma::api::appservice::event::push_events::v1::Request { events: pdu_jsons, - txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( - &events - .iter() - .map(|e| match e { - SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, - SendingEvent::Flush => &[], - }) - .collect::>(), - ))) - .into(), + txn_id: txn_id.into(), ephemeral: Vec::new(), to_device: Vec::new(), }, @@ -494,23 +489,17 @@ impl Service { async fn send_events_dest_push( &self, dest: &Destination, userid: &OwnedUserId, pushkey: &str, events: Vec, ) -> SendingResult { - let mut pdus = Vec::new(); + let Ok(pusher) = self.services.pusher.get_pusher(userid, pushkey).await else { + return Err((dest.clone(), err!(Database(error!(?userid, ?pushkey, "Missing pusher"))))); + }; + let mut pdus = Vec::new(); for event in &events { match event { SendingEvent::Pdu(pdu_id) => { - pdus.push( - self.services - .timeline - 
.get_pdu_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Push] Event in servernameevent_data not found in db."), - ) - })?, - ); + if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await { + pdus.push(pdu); + } }, SendingEvent::Edu(_) | SendingEvent::Flush => { // Push gateways don't need EDUs (?) and flush only; @@ -529,28 +518,22 @@ impl Service { } } - let Some(pusher) = self - .services - .pusher - .get_pusher(userid, pushkey) - .map_err(|e| (dest.clone(), e))? - else { - continue; - }; - let rules_for_user = self .services .account_data .get(None, userid, GlobalAccountDataEventType::PushRules.to_string().into()) - .unwrap_or_default() - .and_then(|event| serde_json::from_str::(event.get()).ok()) - .map_or_else(|| push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global); + .await + .and_then(|event| serde_json::from_str::(event.get()).map_err(Into::into)) + .map_or_else( + |_| push::Ruleset::server_default(userid), + |ev: PushRulesEvent| ev.content.global, + ); let unread: UInt = self .services .user .notification_count(userid, &pdu.room_id) - .map_err(|e| (dest.clone(), e))? + .await .try_into() .expect("notification count can't go that high"); @@ -559,7 +542,6 @@ impl Service { .pusher .send_push_notice(userid, unread, &pusher, rules_for_user, &pdu) .await - .map(|_response| dest.clone()) .map_err(|e| (dest.clone(), e)); } @@ -586,21 +568,11 @@ impl Service { for event in &events { match event { // TODO: check room version and remove event_id if needed - SendingEvent::Pdu(pdu_id) => pdu_jsons.push( - self.convert_to_outgoing_federation_event( - self.services - .timeline - .get_pdu_json_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? 
- .ok_or_else(|| { - error!(?dest, ?server, ?pdu_id, "event not found"); - ( - dest.clone(), - Error::bad_database("[Normal] Event in servernameevent_data not found in db."), - ) - })?, - ), - ), + SendingEvent::Pdu(pdu_id) => { + if let Ok(pdu) = self.services.timeline.get_pdu_json_from_id(pdu_id).await { + pdu_jsons.push(self.convert_to_outgoing_federation_event(pdu).await); + } + }, SendingEvent::Edu(edu) => { if let Ok(raw) = serde_json::from_slice(edu) { edu_jsons.push(raw); @@ -647,7 +619,7 @@ impl Service { } /// This does not return a full `Pdu` it is only to satisfy ruma's types. - pub fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box { + pub async fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box { if let Some(unsigned) = pdu_json .get_mut("unsigned") .and_then(|val| val.as_object_mut()) @@ -660,7 +632,7 @@ impl Service { .get("room_id") .and_then(|val| RoomId::parse(val.as_str()?).ok()) { - match self.services.state.get_room_version(&room_id) { + match self.services.state.get_room_version(&room_id).await { Ok(room_version_id) => match room_version_id { RoomVersionId::V1 | RoomVersionId::V2 => {}, _ => _ = pdu_json.remove("event_id"), diff --git a/src/service/server_keys/mod.rs b/src/service/server_keys/mod.rs index a565e5009..ae2b8c3cb 100644 --- a/src/service/server_keys/mod.rs +++ b/src/service/server_keys/mod.rs @@ -5,7 +5,7 @@ use std::{ }; use conduit::{debug, debug_error, debug_warn, err, error, info, trace, warn, Err, Result}; -use futures_util::{stream::FuturesUnordered, StreamExt}; +use futures::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::federation::{ discovery::{ @@ -179,7 +179,8 @@ impl Service { let result: BTreeMap<_, _> = self .services .globals - .verify_keys_for(origin)? + .verify_keys_for(origin) + .await? 
.into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect(); @@ -236,7 +237,8 @@ impl Service { .services .globals .db - .add_signing_key(&k.server_name, k.clone())? + .add_signing_key(&k.server_name, k.clone()) + .await .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect::>(); @@ -283,7 +285,8 @@ impl Service { .services .globals .db - .add_signing_key(&origin, key)? + .add_signing_key(&origin, key) + .await .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect(); @@ -384,7 +387,8 @@ impl Service { let mut result: BTreeMap<_, _> = self .services .globals - .verify_keys_for(origin)? + .verify_keys_for(origin) + .await? .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect(); @@ -431,7 +435,8 @@ impl Service { self.services .globals .db - .add_signing_key(origin, k.clone())?; + .add_signing_key(origin, k.clone()) + .await; result.extend( k.verify_keys .into_iter() @@ -462,7 +467,8 @@ impl Service { self.services .globals .db - .add_signing_key(origin, server_key.clone())?; + .add_signing_key(origin, server_key.clone()) + .await; result.extend( server_key @@ -495,7 +501,8 @@ impl Service { self.services .globals .db - .add_signing_key(origin, server_key.clone())?; + .add_signing_key(origin, server_key.clone()) + .await; result.extend( server_key @@ -545,7 +552,8 @@ impl Service { self.services .globals .db - .add_signing_key(origin, k.clone())?; + .add_signing_key(origin, k.clone()) + .await; result.extend( k.verify_keys .into_iter() diff --git a/src/service/services.rs b/src/service/services.rs index 3aa095b85..da22fb2d4 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -14,7 +14,7 @@ use crate::{ manager::Manager, media, presence, pusher, resolver, rooms, sending, server_keys, service, service::{Args, Map, Service}, - transaction_ids, uiaa, updates, users, + sync, transaction_ids, uiaa, updates, users, }; pub struct Services { @@ -32,6 +32,7 @@ pub struct Services { pub rooms: rooms::Service, pub sending: Arc, pub 
server_keys: Arc, + pub sync: Arc, pub transaction_ids: Arc, pub uiaa: Arc, pub updates: Arc, @@ -96,6 +97,7 @@ impl Services { }, sending: build!(sending::Service), server_keys: build!(server_keys::Service), + sync: build!(sync::Service), transaction_ids: build!(transaction_ids::Service), uiaa: build!(uiaa::Service), updates: build!(updates::Service), diff --git a/src/service/sync/mod.rs b/src/service/sync/mod.rs new file mode 100644 index 000000000..1bf4610ff --- /dev/null +++ b/src/service/sync/mod.rs @@ -0,0 +1,233 @@ +use std::{ + collections::{BTreeMap, BTreeSet}, + sync::{Arc, Mutex, Mutex as StdMutex}, +}; + +use conduit::Result; +use ruma::{ + api::client::sync::sync_events::{ + self, + v4::{ExtensionsConfig, SyncRequestList}, + }, + OwnedDeviceId, OwnedRoomId, OwnedUserId, +}; + +pub struct Service { + connections: DbConnections, +} + +struct SlidingSyncCache { + lists: BTreeMap, + subscriptions: BTreeMap, + known_rooms: BTreeMap>, // For every room, the roomsince number + extensions: ExtensionsConfig, +} + +type DbConnections = Mutex>; +type DbConnectionsKey = (OwnedUserId, OwnedDeviceId, String); +type DbConnectionsVal = Arc>; + +impl crate::Service for Service { + fn build(_args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + connections: StdMutex::new(BTreeMap::new()), + })) + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + pub fn remembered(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) -> bool { + self.connections + .lock() + .unwrap() + .contains_key(&(user_id, device_id, conn_id)) + } + + pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) { + self.connections + .lock() + .expect("locked") + .remove(&(user_id, device_id, conn_id)); + } + + pub fn update_sync_request_with_cache( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, request: &mut sync_events::v4::Request, + ) -> BTreeMap> { + let 
Some(conn_id) = request.conn_id.clone() else { + return BTreeMap::new(); + }; + + let mut cache = self.connections.lock().expect("locked"); + let cached = Arc::clone( + cache + .entry((user_id, device_id, conn_id)) + .or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), + ); + let cached = &mut cached.lock().expect("locked"); + drop(cache); + + for (list_id, list) in &mut request.lists { + if let Some(cached_list) = cached.lists.get(list_id) { + if list.sort.is_empty() { + list.sort.clone_from(&cached_list.sort); + }; + if list.room_details.required_state.is_empty() { + list.room_details + .required_state + .clone_from(&cached_list.room_details.required_state); + }; + list.room_details.timeline_limit = list + .room_details + .timeline_limit + .or(cached_list.room_details.timeline_limit); + list.include_old_rooms = list + .include_old_rooms + .clone() + .or_else(|| cached_list.include_old_rooms.clone()); + match (&mut list.filters, cached_list.filters.clone()) { + (Some(list_filters), Some(cached_filters)) => { + list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm); + if list_filters.spaces.is_empty() { + list_filters.spaces = cached_filters.spaces; + } + list_filters.is_encrypted = list_filters.is_encrypted.or(cached_filters.is_encrypted); + list_filters.is_invite = list_filters.is_invite.or(cached_filters.is_invite); + if list_filters.room_types.is_empty() { + list_filters.room_types = cached_filters.room_types; + } + if list_filters.not_room_types.is_empty() { + list_filters.not_room_types = cached_filters.not_room_types; + } + list_filters.room_name_like = list_filters + .room_name_like + .clone() + .or(cached_filters.room_name_like); + if list_filters.tags.is_empty() { + list_filters.tags = cached_filters.tags; + } + if list_filters.not_tags.is_empty() { + list_filters.not_tags = 
cached_filters.not_tags; + } + }, + (_, Some(cached_filters)) => list.filters = Some(cached_filters), + (Some(list_filters), _) => list.filters = Some(list_filters.clone()), + (..) => {}, + } + if list.bump_event_types.is_empty() { + list.bump_event_types + .clone_from(&cached_list.bump_event_types); + }; + } + cached.lists.insert(list_id.clone(), list.clone()); + } + + cached + .subscriptions + .extend(request.room_subscriptions.clone()); + request + .room_subscriptions + .extend(cached.subscriptions.clone()); + + request.extensions.e2ee.enabled = request + .extensions + .e2ee + .enabled + .or(cached.extensions.e2ee.enabled); + + request.extensions.to_device.enabled = request + .extensions + .to_device + .enabled + .or(cached.extensions.to_device.enabled); + + request.extensions.account_data.enabled = request + .extensions + .account_data + .enabled + .or(cached.extensions.account_data.enabled); + request.extensions.account_data.lists = request + .extensions + .account_data + .lists + .clone() + .or_else(|| cached.extensions.account_data.lists.clone()); + request.extensions.account_data.rooms = request + .extensions + .account_data + .rooms + .clone() + .or_else(|| cached.extensions.account_data.rooms.clone()); + + cached.extensions = request.extensions.clone(); + + cached.known_rooms.clone() + } + + pub fn update_sync_subscriptions( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, + subscriptions: BTreeMap, + ) { + let mut cache = self.connections.lock().expect("locked"); + let cached = Arc::clone( + cache + .entry((user_id, device_id, conn_id)) + .or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), + ); + let cached = &mut cached.lock().expect("locked"); + drop(cache); + + cached.subscriptions = subscriptions; + } + + pub fn update_sync_known_rooms( + &self, user_id: OwnedUserId, 
device_id: OwnedDeviceId, conn_id: String, list_id: String, + new_cached_rooms: BTreeSet, globalsince: u64, + ) { + let mut cache = self.connections.lock().expect("locked"); + let cached = Arc::clone( + cache + .entry((user_id, device_id, conn_id)) + .or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), + ); + let cached = &mut cached.lock().expect("locked"); + drop(cache); + + for (roomid, lastsince) in cached + .known_rooms + .entry(list_id.clone()) + .or_default() + .iter_mut() + { + if !new_cached_rooms.contains(roomid) { + *lastsince = 0; + } + } + let list = cached.known_rooms.entry(list_id).or_default(); + for roomid in new_cached_rooms { + list.insert(roomid, globalsince); + } + } +} diff --git a/src/service/transaction_ids/data.rs b/src/service/transaction_ids/data.rs deleted file mode 100644 index 791b46f01..000000000 --- a/src/service/transaction_ids/data.rs +++ /dev/null @@ -1,44 +0,0 @@ -use std::sync::Arc; - -use conduit::Result; -use database::{Database, Map}; -use ruma::{DeviceId, TransactionId, UserId}; - -pub struct Data { - userdevicetxnid_response: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - userdevicetxnid_response: db["userdevicetxnid_response"].clone(), - } - } - - pub(super) fn add_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); - key.push(0xFF); - key.extend_from_slice(txn_id.as_bytes()); - - self.userdevicetxnid_response.insert(&key, data)?; - - Ok(()) - } - - pub(super) fn existing_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, - ) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - 
key.push(0xFF); - key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); - key.push(0xFF); - key.extend_from_slice(txn_id.as_bytes()); - - // If there's no entry, this is a new transaction - self.userdevicetxnid_response.get(&key) - } -} diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs index 78e6337f2..72f60adb1 100644 --- a/src/service/transaction_ids/mod.rs +++ b/src/service/transaction_ids/mod.rs @@ -1,35 +1,45 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use data::Data; +use conduit::{implement, Result}; +use database::{Handle, Map}; use ruma::{DeviceId, TransactionId, UserId}; pub struct Service { - pub db: Data, + db: Data, +} + +struct Data { + userdevicetxnid_response: Arc, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data { + userdevicetxnid_response: args.db["userdevicetxnid_response"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - pub fn add_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], - ) -> Result<()> { - self.db.add_txnid(user_id, device_id, txn_id, data) - } +#[implement(Service)] +pub fn add_txnid(&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8]) { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); + key.push(0xFF); + key.extend_from_slice(txn_id.as_bytes()); - pub fn existing_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, - ) -> Result>> { - self.db.existing_txnid(user_id, device_id, txn_id) - } + self.db.userdevicetxnid_response.insert(&key, data); +} + +// If there's no entry, this is a new transaction +#[implement(Service)] +pub async fn existing_txnid( + &self, user_id: &UserId, device_id: 
Option<&DeviceId>, txn_id: &TransactionId, +) -> Result> { + let key = (user_id, device_id, txn_id); + self.db.userdevicetxnid_response.qry(&key).await } diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs deleted file mode 100644 index ce071da09..000000000 --- a/src/service/uiaa/data.rs +++ /dev/null @@ -1,87 +0,0 @@ -use std::{ - collections::BTreeMap, - sync::{Arc, RwLock}, -}; - -use conduit::{Error, Result}; -use database::{Database, Map}; -use ruma::{ - api::client::{error::ErrorKind, uiaa::UiaaInfo}, - CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId, -}; - -pub struct Data { - userdevicesessionid_uiaarequest: RwLock>, - userdevicesessionid_uiaainfo: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), - userdevicesessionid_uiaainfo: db["userdevicesessionid_uiaainfo"].clone(), - } - } - - pub(super) fn set_uiaa_request( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, - ) -> Result<()> { - self.userdevicesessionid_uiaarequest - .write() - .unwrap() - .insert( - (user_id.to_owned(), device_id.to_owned(), session.to_owned()), - request.to_owned(), - ); - - Ok(()) - } - - pub(super) fn get_uiaa_request( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, - ) -> Option { - self.userdevicesessionid_uiaarequest - .read() - .unwrap() - .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) - .map(ToOwned::to_owned) - } - - pub(super) fn update_uiaa_session( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>, - ) -> Result<()> { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - if let Some(uiaainfo) = uiaainfo { - 
self.userdevicesessionid_uiaainfo.insert( - &userdevicesessionid, - &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), - )?; - } else { - self.userdevicesessionid_uiaainfo - .remove(&userdevicesessionid)?; - } - - Ok(()) - } - - pub(super) fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - serde_json::from_slice( - &self - .userdevicesessionid_uiaainfo - .get(&userdevicesessionid)? - .ok_or(Error::BadRequest(ErrorKind::forbidden(), "UIAA session does not exist."))?, - ) - .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) - } -} diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 6041bbd34..7e2315142 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -1,174 +1,243 @@ -mod data; - -use std::sync::Arc; +use std::{ + collections::BTreeMap, + sync::{Arc, RwLock}, +}; -use conduit::{error, utils, utils::hash, Error, Result, Server}; -use data::Data; +use conduit::{ + err, error, implement, utils, + utils::{hash, string::EMPTY}, + Error, Result, Server, +}; +use database::{Deserialized, Map}; use ruma::{ api::client::{ error::ErrorKind, uiaa::{AuthData, AuthType, Password, UiaaInfo, UserIdentifier}, }, - CanonicalJsonValue, DeviceId, UserId, + CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId, }; use crate::{globals, users, Dep}; -pub const SESSION_ID_LENGTH: usize = 32; - pub struct Service { - server: Arc, + userdevicesessionid_uiaarequest: RwLock, + db: Data, services: Services, - pub db: Data, } struct Services { + server: Arc, globals: Dep, users: Dep, } +struct Data { + userdevicesessionid_uiaainfo: Arc, +} + +type RequestMap = BTreeMap; +type RequestKey = 
(OwnedUserId, OwnedDeviceId, String); + +pub const SESSION_ID_LENGTH: usize = 32; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - server: args.server.clone(), + userdevicesessionid_uiaarequest: RwLock::new(RequestMap::new()), + db: Data { + userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(), + }, services: Services { + server: args.server.clone(), globals: args.depend::("globals"), users: args.depend::("users"), }, - db: Data::new(args.db), })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Creates a new Uiaa session. Make sure the session token is unique. - pub fn create( - &self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue, - ) -> Result<()> { - self.db.set_uiaa_request( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), /* TODO: better session error handling (why - * is it optional in ruma?) */ - json_body, - )?; - self.db.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), - Some(uiaainfo), - ) - } - - pub fn try_auth( - &self, user_id: &UserId, device_id: &DeviceId, auth: &AuthData, uiaainfo: &UiaaInfo, - ) -> Result<(bool, UiaaInfo)> { - let mut uiaainfo = auth.session().map_or_else( - || Ok(uiaainfo.clone()), - |session| self.db.get_uiaa_session(user_id, device_id, session), - )?; - - if uiaainfo.session.is_none() { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - } +/// Creates a new Uiaa session. Make sure the session token is unique. +#[implement(Service)] +pub fn create(&self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue) { + // TODO: better session error handling (why is uiaainfo.session optional in + // ruma?) 
+ self.set_uiaa_request( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), + json_body, + ); + + self.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), + Some(uiaainfo), + ); +} - match auth { - // Find out what the user completed - AuthData::Password(Password { - identifier, - password, - #[cfg(feature = "element_hacks")] - user, - .. - }) => { - #[cfg(feature = "element_hacks")] - let username = if let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier { - username - } else if let Some(username) = user { - username - } else { - return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); - }; - - #[cfg(not(feature = "element_hacks"))] - let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier - else { - return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); - }; - - let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name()) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; - - // Check if password is correct - if let Some(hash) = self.services.users.password_hash(&user_id)? { - let hash_matches = hash::verify_password(password, &hash).is_ok(); - if !hash_matches { - uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { - kind: ErrorKind::forbidden(), - message: "Invalid username or password.".to_owned(), - }); - return Ok((false, uiaainfo)); - } - } +#[implement(Service)] +pub async fn try_auth( + &self, user_id: &UserId, device_id: &DeviceId, auth: &AuthData, uiaainfo: &UiaaInfo, +) -> Result<(bool, UiaaInfo)> { + let mut uiaainfo = if let Some(session) = auth.session() { + self.get_uiaa_session(user_id, device_id, session).await? 
+ } else { + uiaainfo.clone() + }; + + if uiaainfo.session.is_none() { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + } - // Password was correct! Let's add it to `completed` - uiaainfo.completed.push(AuthType::Password); - }, - AuthData::RegistrationToken(t) => { - if Some(t.token.trim()) == self.server.config.registration_token.as_deref() { - uiaainfo.completed.push(AuthType::RegistrationToken); - } else { + match auth { + // Find out what the user completed + AuthData::Password(Password { + identifier, + password, + #[cfg(feature = "element_hacks")] + user, + .. + }) => { + #[cfg(feature = "element_hacks")] + let username = if let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier { + username + } else if let Some(username) = user { + username + } else { + return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); + }; + + #[cfg(not(feature = "element_hacks"))] + let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier + else { + return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); + }; + + let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name()) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; + + // Check if password is correct + if let Ok(hash) = self.services.users.password_hash(&user_id).await { + let hash_matches = hash::verify_password(password, &hash).is_ok(); + if !hash_matches { uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { kind: ErrorKind::forbidden(), - message: "Invalid registration token.".to_owned(), + message: "Invalid username or password.".to_owned(), }); return Ok((false, uiaainfo)); } - }, - AuthData::Dummy(_) => { - uiaainfo.completed.push(AuthType::Dummy); - }, - k => error!("type not supported: {:?}", k), - } + } - // Check if a flow now succeeds - let mut completed = false; - 'flows: for flow in &mut uiaainfo.flows { - for 
stage in &flow.stages { - if !uiaainfo.completed.contains(stage) { - continue 'flows; - } + // Password was correct! Let's add it to `completed` + uiaainfo.completed.push(AuthType::Password); + }, + AuthData::RegistrationToken(t) => { + if Some(t.token.trim()) == self.services.server.config.registration_token.as_deref() { + uiaainfo.completed.push(AuthType::RegistrationToken); + } else { + uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { + kind: ErrorKind::forbidden(), + message: "Invalid registration token.".to_owned(), + }); + return Ok((false, uiaainfo)); } - // We didn't break, so this flow succeeded! - completed = true; - } + }, + AuthData::Dummy(_) => { + uiaainfo.completed.push(AuthType::Dummy); + }, + k => error!("type not supported: {:?}", k), + } - if !completed { - self.db.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session is always set"), - Some(&uiaainfo), - )?; - return Ok((false, uiaainfo)); + // Check if a flow now succeeds + let mut completed = false; + 'flows: for flow in &mut uiaainfo.flows { + for stage in &flow.stages { + if !uiaainfo.completed.contains(stage) { + continue 'flows; + } } + // We didn't break, so this flow succeeded! + completed = true; + } - // UIAA was successful! Remove this session and return true - self.db.update_uiaa_session( + if !completed { + self.update_uiaa_session( user_id, device_id, uiaainfo.session.as_ref().expect("session is always set"), - None, - )?; - Ok((true, uiaainfo)) + Some(&uiaainfo), + ); + + return Ok((false, uiaainfo)); } - #[must_use] - pub fn get_uiaa_request( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, - ) -> Option { - self.db.get_uiaa_request(user_id, device_id, session) + // UIAA was successful! 
Remove this session and return true + self.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session is always set"), + None, + ); + + Ok((true, uiaainfo)) +} + +#[implement(Service)] +fn set_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue) { + let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned()); + self.userdevicesessionid_uiaarequest + .write() + .expect("locked for writing") + .insert(key, request.to_owned()); +} + +#[implement(Service)] +pub fn get_uiaa_request( + &self, user_id: &UserId, device_id: Option<&DeviceId>, session: &str, +) -> Option { + let key = ( + user_id.to_owned(), + device_id.unwrap_or_else(|| EMPTY.into()).to_owned(), + session.to_owned(), + ); + + self.userdevicesessionid_uiaarequest + .read() + .expect("locked for reading") + .get(&key) + .cloned() +} + +#[implement(Service)] +fn update_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>) { + let mut userdevicesessionid = user_id.as_bytes().to_vec(); + userdevicesessionid.push(0xFF); + userdevicesessionid.extend_from_slice(device_id.as_bytes()); + userdevicesessionid.push(0xFF); + userdevicesessionid.extend_from_slice(session.as_bytes()); + + if let Some(uiaainfo) = uiaainfo { + self.db.userdevicesessionid_uiaainfo.insert( + &userdevicesessionid, + &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), + ); + } else { + self.db + .userdevicesessionid_uiaainfo + .remove(&userdevicesessionid); } } + +#[implement(Service)] +async fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result { + let key = (user_id, device_id, session); + self.db + .userdevicesessionid_uiaainfo + .qry(&key) + .await + .deserialized_json() + .map_err(|_| err!(Request(Forbidden("UIAA session does not exist.")))) +} diff --git a/src/service/updates/mod.rs b/src/service/updates/mod.rs index 3c69b2430..4e16e22b0 
100644 --- a/src/service/updates/mod.rs +++ b/src/service/updates/mod.rs @@ -1,19 +1,22 @@ use std::{sync::Arc, time::Duration}; use async_trait::async_trait; -use conduit::{debug, err, info, utils, warn, Error, Result}; -use database::Map; +use conduit::{debug, info, warn, Result}; +use database::{Deserialized, Map}; use ruma::events::room::message::RoomMessageEventContent; use serde::Deserialize; -use tokio::{sync::Notify, time::interval}; +use tokio::{ + sync::Notify, + time::{interval, MissedTickBehavior}, +}; use crate::{admin, client, globals, Dep}; pub struct Service { - services: Services, - db: Arc, - interrupt: Notify, interval: Duration, + interrupt: Notify, + db: Arc, + services: Services, } struct Services { @@ -22,12 +25,12 @@ struct Services { globals: Dep, } -#[derive(Deserialize)] +#[derive(Debug, Deserialize)] struct CheckForUpdatesResponse { updates: Vec, } -#[derive(Deserialize)] +#[derive(Debug, Deserialize)] struct CheckForUpdatesResponseEntry { id: u64, date: String, @@ -42,33 +45,38 @@ const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL), + interrupt: Notify::new(), + db: args.db["global"].clone(), services: Services { globals: args.depend::("globals"), admin: args.depend::("admin"), client: args.depend::("client"), }, - db: args.db["global"].clone(), - interrupt: Notify::new(), - interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL), })) } + #[tracing::instrument(skip_all, name = "updates", level = "trace")] async fn worker(self: Arc) -> Result<()> { if !self.services.globals.allow_check_for_updates() { debug!("Disabling update check"); return Ok(()); } + let mut i = interval(self.interval); + i.set_missed_tick_behavior(MissedTickBehavior::Delay); loop { tokio::select! 
{ - () = self.interrupt.notified() => return Ok(()), + () = self.interrupt.notified() => break, _ = i.tick() => (), } - if let Err(e) = self.handle_updates().await { + if let Err(e) = self.check().await { warn!(%e, "Failed to check for updates"); } } + + Ok(()) } fn interrupt(&self) { self.interrupt.notify_waiters(); } @@ -77,52 +85,52 @@ impl crate::Service for Service { } impl Service { - #[tracing::instrument(skip_all)] - async fn handle_updates(&self) -> Result<()> { + #[tracing::instrument(skip_all, level = "trace")] + async fn check(&self) -> Result<()> { let response = self .services .client .default .get(CHECK_FOR_UPDATES_URL) .send() + .await? + .text() .await?; - let response = serde_json::from_str::(&response.text().await?) - .map_err(|e| err!("Bad check for updates response: {e}"))?; - - let mut last_update_id = self.last_check_for_updates_id()?; - for update in response.updates { - last_update_id = last_update_id.max(update.id); - if update.id > self.last_check_for_updates_id()? 
{ - info!("{:#}", update.message); - self.services - .admin - .send_message(RoomMessageEventContent::text_markdown(format!( - "### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}", - update.date, update.message - ))) - .await; + let response = serde_json::from_str::(&response)?; + for update in &response.updates { + if update.id > self.last_check_for_updates_id().await { + self.handle(update).await; + self.update_check_for_updates_id(update.id); } } - self.update_check_for_updates_id(last_update_id)?; Ok(()) } + async fn handle(&self, update: &CheckForUpdatesResponseEntry) { + info!("{} {:#}", update.date, update.message); + self.services + .admin + .send_message(RoomMessageEventContent::text_markdown(format!( + "### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}", + update.date, update.message + ))) + .await + .ok(); + } + #[inline] - pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { + pub fn update_check_for_updates_id(&self, id: u64) { self.db - .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; - - Ok(()) + .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes()); } - pub fn last_check_for_updates_id(&self) -> Result { + pub async fn last_check_for_updates_id(&self) -> u64 { self.db - .get(LAST_CHECK_FOR_UPDATES_COUNT)? 
- .map_or(Ok(0_u64), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("last check for updates count has invalid bytes.")) - }) + .qry(LAST_CHECK_FOR_UPDATES_COUNT) + .await + .deserialized() + .unwrap_or(0_u64) } } diff --git a/src/service/users/data.rs b/src/service/users/data.rs deleted file mode 100644 index 70ff12e3f..000000000 --- a/src/service/users/data.rs +++ /dev/null @@ -1,1098 +0,0 @@ -use std::{collections::BTreeMap, mem::size_of, sync::Arc}; - -use conduit::{debug_info, err, utils, warn, Err, Error, Result, Server}; -use database::Map; -use ruma::{ - api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, - encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::{AnyToDeviceEvent, StateEventType}, - serde::Raw, - uint, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId, - OwnedMxcUri, OwnedUserId, UInt, UserId, -}; - -use crate::{globals, rooms, users::clean_signatures, Dep}; - -pub struct Data { - keychangeid_userid: Arc, - keyid_key: Arc, - onetimekeyid_onetimekeys: Arc, - openidtoken_expiresatuserid: Arc, - todeviceid_events: Arc, - token_userdeviceid: Arc, - userdeviceid_metadata: Arc, - userdeviceid_token: Arc, - userfilterid_filter: Arc, - userid_avatarurl: Arc, - userid_blurhash: Arc, - userid_devicelistversion: Arc, - userid_displayname: Arc, - userid_lastonetimekeyupdate: Arc, - userid_masterkeyid: Arc, - userid_password: Arc, - userid_selfsigningkeyid: Arc, - userid_usersigningkeyid: Arc, - useridprofilekey_value: Arc, - services: Services, -} - -struct Services { - server: Arc, - globals: Dep, - state_cache: Dep, - state_accessor: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - keychangeid_userid: db["keychangeid_userid"].clone(), - keyid_key: db["keyid_key"].clone(), - onetimekeyid_onetimekeys: db["onetimekeyid_onetimekeys"].clone(), - openidtoken_expiresatuserid: 
db["openidtoken_expiresatuserid"].clone(), - todeviceid_events: db["todeviceid_events"].clone(), - token_userdeviceid: db["token_userdeviceid"].clone(), - userdeviceid_metadata: db["userdeviceid_metadata"].clone(), - userdeviceid_token: db["userdeviceid_token"].clone(), - userfilterid_filter: db["userfilterid_filter"].clone(), - userid_avatarurl: db["userid_avatarurl"].clone(), - userid_blurhash: db["userid_blurhash"].clone(), - userid_devicelistversion: db["userid_devicelistversion"].clone(), - userid_displayname: db["userid_displayname"].clone(), - userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(), - userid_masterkeyid: db["userid_masterkeyid"].clone(), - userid_password: db["userid_password"].clone(), - userid_selfsigningkeyid: db["userid_selfsigningkeyid"].clone(), - userid_usersigningkeyid: db["userid_usersigningkeyid"].clone(), - useridprofilekey_value: db["useridprofilekey_value"].clone(), - services: Services { - server: args.server.clone(), - globals: args.depend::("globals"), - state_cache: args.depend::("rooms::state_cache"), - state_accessor: args.depend::("rooms::state_accessor"), - }, - } - } - - /// Check if a user has an account on this homeserver. - #[inline] - pub(super) fn exists(&self, user_id: &UserId) -> Result { - Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) - } - - /// Check if account is deactivated - pub(super) fn is_deactivated(&self, user_id: &UserId) -> Result { - Ok(self - .userid_password - .get(user_id.as_bytes())? - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."))? - .is_empty()) - } - - /// Returns the number of users registered on this server. - #[inline] - pub(super) fn count(&self) -> Result { Ok(self.userid_password.iter().count()) } - - /// Find out which user an access token belongs to. - pub(super) fn find_from_token(&self, token: &str) -> Result> { - self.token_userdeviceid - .get(token.as_bytes())? 
- .map_or(Ok(None), |bytes| { - let mut parts = bytes.split(|&b| b == 0xFF); - let user_bytes = parts - .next() - .ok_or_else(|| err!(Database("User ID in token_userdeviceid is invalid.")))?; - let device_bytes = parts - .next() - .ok_or_else(|| err!(Database("Device ID in token_userdeviceid is invalid.")))?; - - Ok(Some(( - UserId::parse( - utils::string_from_bytes(user_bytes) - .map_err(|e| err!(Database("User ID in token_userdeviceid is invalid unicode. {e}")))?, - ) - .map_err(|e| err!(Database("User ID in token_userdeviceid is invalid. {e}")))?, - utils::string_from_bytes(device_bytes) - .map_err(|e| err!(Database("Device ID in token_userdeviceid is invalid. {e}")))?, - ))) - }) - } - - /// Returns an iterator over all users on this homeserver. - pub fn iter<'a>(&'a self) -> Box> + 'a> { - Box::new(self.userid_password.iter().map(|(bytes, _)| { - UserId::parse( - utils::string_from_bytes(&bytes) - .map_err(|e| err!(Database("User ID in userid_password is invalid unicode. {e}")))?, - ) - .map_err(|e| err!(Database("User ID in userid_password is invalid. {e}"))) - })) - } - - /// Returns a list of local users as list of usernames. - /// - /// A user account is considered `local` if the length of it's password is - /// greater then zero. - pub(super) fn list_local_users(&self) -> Result> { - let users: Vec = self - .userid_password - .iter() - .filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw)) - .collect(); - Ok(users) - } - - /// Returns the password hash for the given user. - pub(super) fn password_hash(&self, user_id: &UserId) -> Result> { - self.userid_password - .get(user_id.as_bytes())? 
- .map_or(Ok(None), |bytes| { - Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Password hash in db is not valid string.") - })?)) - }) - } - - /// Hash and set the user's password to the Argon2 hash - pub(super) fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - if let Some(password) = password { - if let Ok(hash) = utils::hash::password(password) { - self.userid_password - .insert(user_id.as_bytes(), hash.as_bytes())?; - Ok(()) - } else { - Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Password does not meet the requirements.", - )) - } - } else { - self.userid_password.insert(user_id.as_bytes(), b"")?; - Ok(()) - } - } - - /// Returns the displayname of a user on this homeserver. - pub(super) fn displayname(&self, user_id: &UserId) -> Result> { - self.userid_displayname - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some( - utils::string_from_bytes(&bytes) - .map_err(|e| err!(Database("Displayname in db is invalid. {e}")))?, - )) - }) - } - - /// Sets a new displayname or removes it if displayname is None. You still - /// need to nofify all rooms of this change. - pub(super) fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { - if let Some(displayname) = displayname { - self.userid_displayname - .insert(user_id.as_bytes(), displayname.as_bytes())?; - } else { - self.userid_displayname.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Get the `avatar_url` of a user. - pub(super) fn avatar_url(&self, user_id: &UserId) -> Result> { - self.userid_avatarurl - .get(user_id.as_bytes())? - .map(|bytes| { - let s_bytes = utils::string_from_bytes(&bytes) - .map_err(|e| err!(Database(warn!("Avatar URL in db is invalid: {e}"))))?; - let mxc_uri: OwnedMxcUri = s_bytes.into(); - Ok(mxc_uri) - }) - .transpose() - } - - /// Sets a new avatar_url or removes it if avatar_url is None. 
- pub(super) fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { - if let Some(avatar_url) = avatar_url { - self.userid_avatarurl - .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; - } else { - self.userid_avatarurl.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Get the blurhash of a user. - pub(super) fn blurhash(&self, user_id: &UserId) -> Result> { - self.userid_blurhash - .get(user_id.as_bytes())? - .map(|bytes| { - utils::string_from_bytes(&bytes).map_err(|e| err!(Database("Avatar URL in db is invalid. {e}"))) - }) - .transpose() - } - - /// Gets a specific user profile key - pub(super) fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(profile_key.as_bytes()); - - self.useridprofilekey_value - .get(&key)? - .map_or(Ok(None), |bytes| Ok(Some(serde_json::from_slice(&bytes).unwrap()))) - } - - /// Gets all the user's profile keys and values in an iterator - pub(super) fn all_profile_keys<'a>( - &'a self, user_id: &UserId, - ) -> Box> + 'a + Send> { - let prefix = user_id.as_bytes().to_vec(); - - Box::new( - self.useridprofilekey_value - .scan_prefix(prefix) - .map(|(key, value)| { - let profile_key_name = utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("Profile key in db is invalid")))?, - ) - .map_err(|e| err!(Database("Profile key in db is invalid. {e}")))?; - - let profile_key_value = serde_json::from_slice(&value) - .map_err(|e| err!(Database("Profile key in db is invalid. 
{e}")))?; - - Ok((profile_key_name, profile_key_value)) - }), - ) - } - - /// Sets a new profile key value, removes the key if value is None - pub(super) fn set_profile_key( - &self, user_id: &UserId, profile_key: &str, profile_key_value: Option, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(profile_key.as_bytes()); - - // TODO: insert to the stable MSC4175 key when it's stable - if let Some(value) = profile_key_value { - let value = serde_json::to_vec(&value).unwrap(); - - self.useridprofilekey_value.insert(&key, &value) - } else { - self.useridprofilekey_value.remove(&key) - } - } - - /// Get the timezone of a user. - pub(super) fn timezone(&self, user_id: &UserId) -> Result> { - // first check the unstable prefix - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(b"us.cloke.msc4175.tz"); - - let value = self - .useridprofilekey_value - .get(&key)? - .map(|bytes| utils::string_from_bytes(&bytes).map_err(|e| err!(Database("Timezone in db is invalid. {e}")))) - .transpose() - .unwrap(); - - // TODO: transparently migrate unstable key usage to the stable key once MSC4133 - // and MSC4175 are stable, likely a remove/insert in this block - if value.is_none() || value.as_ref().is_some_and(String::is_empty) { - // check the stable prefix - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(b"m.tz"); - - return self - .useridprofilekey_value - .get(&key)? - .map(|bytes| { - utils::string_from_bytes(&bytes).map_err(|e| err!(Database("Timezone in db is invalid. {e}"))) - }) - .transpose(); - } - - Ok(value) - } - - /// Sets a new timezone or removes it if timezone is None. 
- pub(super) fn set_timezone(&self, user_id: &UserId, timezone: Option) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(b"us.cloke.msc4175.tz"); - - // TODO: insert to the stable MSC4175 key when it's stable - if let Some(timezone) = timezone { - self.useridprofilekey_value - .insert(&key, timezone.as_bytes())?; - } else { - self.useridprofilekey_value.remove(&key)?; - } - - Ok(()) - } - - /// Sets a new avatar_url or removes it if avatar_url is None. - pub(super) fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { - if let Some(blurhash) = blurhash { - self.userid_blurhash - .insert(user_id.as_bytes(), blurhash.as_bytes())?; - } else { - self.userid_blurhash.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Adds a new device to a user. - pub(super) fn create_device( - &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option, - client_ip: Option, - ) -> Result<()> { - // This method should never be called for nonexistent users. We shouldn't assert - // though... - if !self.exists(user_id)? { - warn!("Called create_device for non-existent user {} in database", user_id); - return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist.")); - } - - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(&Device { - device_id: device_id.into(), - display_name: initial_device_display_name, - last_seen_ip: client_ip, - last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), - }) - .expect("Device::to_string never fails."), - )?; - - self.set_token(user_id, device_id, token)?; - - Ok(()) - } - - /// Removes a device from a user. 
- pub(super) fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Remove tokens - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.userdeviceid_token.remove(&userdeviceid)?; - self.token_userdeviceid.remove(&old_token)?; - } - - // Remove todevice events - let mut prefix = userdeviceid.clone(); - prefix.push(0xFF); - - for (key, _) in self.todeviceid_events.scan_prefix(prefix) { - self.todeviceid_events.remove(&key)?; - } - - // TODO: Remove onetimekeys - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.remove(&userdeviceid)?; - - Ok(()) - } - - /// Returns an iterator over all device ids of this user. - pub(super) fn all_device_ids<'a>( - &'a self, user_id: &UserId, - ) -> Box> + 'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - // All devices have metadata - Box::new( - self.userdeviceid_metadata - .scan_prefix(prefix) - .map(|(bytes, _)| { - Ok(utils::string_from_bytes( - bytes - .rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("UserDevice ID in db is invalid.")))?, - ) - .map_err(|e| err!(Database("Device ID in userdeviceid_metadata is invalid. {e}")))? - .into()) - }), - ) - } - - /// Replaces the access token of one device. - pub(super) fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // should not be None, but we shouldn't assert either lol... - if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { - return Err!(Database(error!( - "User {user_id:?} does not exist or device ID {device_id:?} has no metadata." 
- ))); - } - - // Remove old token - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.token_userdeviceid.remove(&old_token)?; - // It will be removed from userdeviceid_token by the insert later - } - - // Assign token to user device combination - self.userdeviceid_token - .insert(&userdeviceid, token.as_bytes())?; - self.token_userdeviceid - .insert(token.as_bytes(), &userdeviceid)?; - - Ok(()) - } - - pub(super) fn add_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, - one_time_key_value: &Raw, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.as_bytes()); - - // All devices have metadata - // Only existing devices should be able to call this, but we shouldn't assert - // either... - if self.userdeviceid_metadata.get(&key)?.is_none() { - return Err!(Database(error!( - "User {user_id:?} does not exist or device ID {device_id:?} has no metadata." - ))); - } - - key.push(0xFF); - // TODO: Use DeviceKeyId::to_string when it's available (and update everything, - // because there are no wrapping quotation marks anymore) - key.extend_from_slice( - serde_json::to_string(one_time_key_key) - .expect("DeviceKeyId::to_string always works") - .as_bytes(), - ); - - self.onetimekeyid_onetimekeys.insert( - &key, - &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), - )?; - - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?; - - Ok(()) - } - - pub(super) fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { - self.userid_lastonetimekeyupdate - .get(user_id.as_bytes())? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|e| err!(Database("Count in roomid_lastroomactiveupdate is invalid. 
{e}"))) - }) - } - - pub(super) fn take_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, - ) -> Result)>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - prefix.push(b'"'); // Annoying quotation mark - prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); - prefix.push(b':'); - - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?; - - self.onetimekeyid_onetimekeys - .scan_prefix(prefix) - .next() - .map(|(key, value)| { - self.onetimekeyid_onetimekeys.remove(&key)?; - - Ok(( - serde_json::from_slice( - key.rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("OneTimeKeyId in db is invalid.")))?, - ) - .map_err(|e| err!(Database("OneTimeKeyId in db is invalid. {e}")))?, - serde_json::from_slice(&value).map_err(|e| err!(Database("OneTimeKeys in db are invalid. {e}")))?, - )) - }) - .transpose() - } - - pub(super) fn count_one_time_keys( - &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - let mut counts = BTreeMap::new(); - - for algorithm in self - .onetimekeyid_onetimekeys - .scan_prefix(userdeviceid) - .map(|(bytes, _)| { - Ok::<_, Error>( - serde_json::from_slice::( - bytes - .rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("OneTimeKey ID in db is invalid.")))?, - ) - .map_err(|e| err!(Database("DeviceKeyId in db is invalid. {e}")))? 
- .algorithm(), - ) - }) { - let count: &mut UInt = counts.entry(algorithm?).or_default(); - *count = count.saturating_add(uint!(1)); - } - - Ok(counts) - } - - pub(super) fn add_device_keys( - &self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw, - ) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.keyid_key.insert( - &userdeviceid, - &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), - )?; - - self.mark_device_key_update(user_id)?; - - Ok(()) - } - - pub(super) fn add_cross_signing_keys( - &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, - user_signing_key: &Option>, notify: bool, - ) -> Result<()> { - // TODO: Check signatures - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let (master_key_key, _) = Self::parse_master_key(user_id, master_key)?; - - self.keyid_key - .insert(&master_key_key, master_key.json().get().as_bytes())?; - - self.userid_masterkeyid - .insert(user_id.as_bytes(), &master_key_key)?; - - // Self-signing key - if let Some(self_signing_key) = self_signing_key { - let mut self_signing_key_ids = self_signing_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key"))? 
- .keys - .into_values(); - - let self_signing_key_id = self_signing_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?; - - if self_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Self signing key contained more than one key.", - )); - } - - let mut self_signing_key_key = prefix.clone(); - self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); - - self.keyid_key - .insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?; - - self.userid_selfsigningkeyid - .insert(user_id.as_bytes(), &self_signing_key_key)?; - } - - // User-signing key - if let Some(user_signing_key) = user_signing_key { - let mut user_signing_key_ids = user_signing_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key"))? - .keys - .into_values(); - - let user_signing_key_id = user_signing_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User signing key contained no key."))?; - - if user_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "User signing key contained more than one key.", - )); - } - - let mut user_signing_key_key = prefix; - user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); - - self.keyid_key - .insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?; - - self.userid_usersigningkeyid - .insert(user_id.as_bytes(), &user_signing_key_key)?; - } - - if notify { - self.mark_device_key_update(user_id)?; - } - - Ok(()) - } - - pub(super) fn sign_key( - &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, - ) -> Result<()> { - let mut key = target_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(key_id.as_bytes()); - - let mut cross_signing_key: serde_json::Value = serde_json::from_slice( - &self - .keyid_key - .get(&key)? 
- .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Tried to sign nonexistent key."))?, - ) - .map_err(|e| err!(Database("key in keyid_key is invalid. {e}")))?; - - let signatures = cross_signing_key - .get_mut("signatures") - .ok_or_else(|| err!(Database("key in keyid_key has no signatures field.")))? - .as_object_mut() - .ok_or_else(|| err!(Database("key in keyid_key has invalid signatures field.")))? - .entry(sender_id.to_string()) - .or_insert_with(|| serde_json::Map::new().into()); - - signatures - .as_object_mut() - .ok_or_else(|| err!(Database("signatures in keyid_key for a user is invalid.")))? - .insert(signature.0, signature.1.into()); - - self.keyid_key.insert( - &key, - &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), - )?; - - self.mark_device_key_update(target_id)?; - - Ok(()) - } - - pub(super) fn keys_changed<'a>( - &'a self, user_or_room_id: &str, from: u64, to: Option, - ) -> Box> + 'a> { - let mut prefix = user_or_room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let mut start = prefix.clone(); - start.extend_from_slice(&(from.saturating_add(1)).to_be_bytes()); - - let to = to.unwrap_or(u64::MAX); - - Box::new( - self.keychangeid_userid - .iter_from(&start, false) - .take_while(move |(k, _)| { - k.starts_with(&prefix) - && if let Some(current) = k.splitn(2, |&b| b == 0xFF).nth(1) { - if let Ok(c) = utils::u64_from_bytes(current) { - c <= to - } else { - warn!("BadDatabase: Could not parse keychangeid_userid bytes"); - false - } - } else { - warn!("BadDatabase: Could not parse keychangeid_userid"); - false - } - }) - .map(|(_, bytes)| { - UserId::parse( - utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.") - })?, - ) - .map_err(|e| err!(Database("User ID in devicekeychangeid_userid is invalid. 
{e}"))) - }), - ) - } - - pub(super) fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { - let count = self.services.globals.next_count()?.to_be_bytes(); - for room_id in self - .services - .state_cache - .rooms_joined(user_id) - .filter_map(Result::ok) - { - // Don't send key updates to unencrypted rooms - if self - .services - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? - .is_none() - { - continue; - } - - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(&count); - - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - } - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(&count); - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - - Ok(()) - } - - pub(super) fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.as_bytes()); - - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some( - serde_json::from_slice(&bytes).map_err(|e| err!(Database("DeviceKeys in db are invalid. 
{e}")))?, - )) - }) - } - - pub(super) fn parse_master_key( - user_id: &UserId, master_key: &Raw, - ) -> Result<(Vec, CrossSigningKey)> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let master_key = master_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; - let mut master_key_ids = master_key.keys.values(); - let master_key_id = master_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?; - if master_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Master key contained more than one key.", - )); - } - let mut master_key_key = prefix.clone(); - master_key_key.extend_from_slice(master_key_id.as_bytes()); - Ok((master_key_key, master_key)) - } - - pub(super) fn get_key( - &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { - let mut cross_signing_key = serde_json::from_slice::(&bytes) - .map_err(|e| err!(Database("CrossSigningKey in db is invalid. {e}")))?; - clean_signatures(&mut cross_signing_key, sender_user, user_id, allowed_signatures)?; - - Ok(Some(Raw::from_json( - serde_json::value::to_raw_value(&cross_signing_key).expect("Value to RawValue serialization"), - ))) - }) - } - - pub(super) fn get_master_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.userid_masterkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) - } - - pub(super) fn get_self_signing_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.userid_selfsigningkeyid - .get(user_id.as_bytes())? 
- .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) - } - - pub(super) fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { - self.userid_usersigningkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| { - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some( - serde_json::from_slice(&bytes) - .map_err(|e| err!(Database("CrossSigningKey in db is invalid. {e}")))?, - )) - }) - }) - } - - pub(super) fn add_to_device_event( - &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, - content: serde_json::Value, - ) -> Result<()> { - let mut key = target_user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(target_device_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); - - let mut json = serde_json::Map::new(); - json.insert("type".to_owned(), event_type.to_owned().into()); - json.insert("sender".to_owned(), sender.to_string().into()); - json.insert("content".to_owned(), content); - - let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); - - self.todeviceid_events.insert(&key, &value)?; - - Ok(()) - } - - pub(super) fn get_to_device_events( - &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result>> { - let mut events = Vec::new(); - - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - - for (_, value) in self.todeviceid_events.scan_prefix(prefix) { - events.push( - serde_json::from_slice(&value) - .map_err(|e| err!(Database("Event in todeviceid_events is invalid. 
{e}")))?, - ); - } - - Ok(events) - } - - pub(super) fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - - let mut last = prefix.clone(); - last.extend_from_slice(&until.to_be_bytes()); - - for (key, _) in self - .todeviceid_events - .iter_from(&last, true) // this includes last - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(|(key, _)| { - Ok::<_, Error>(( - key.clone(), - utils::u64_from_bytes(&key[key.len().saturating_sub(size_of::())..key.len()]) - .map_err(|e| err!(Database("ToDeviceId has invalid count bytes. {e}")))?, - )) - }) - .filter_map(Result::ok) - .take_while(|&(_, count)| count <= until) - { - self.todeviceid_events.remove(&key)?; - } - - Ok(()) - } - - pub(super) fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Only existing devices should be able to call this, but we shouldn't assert - // either... - if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { - warn!( - "Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no \ - metadata in database", - user_id, device_id - ); - return Err(Error::bad_database( - "User does not exist or device ID has no metadata in database.", - )); - } - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(device).expect("Device::to_string always works"), - )?; - - Ok(()) - } - - /// Get device metadata. 
- pub(super) fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.userdeviceid_metadata - .get(&userdeviceid)? - .map_or(Ok(None), |bytes| { - Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("Metadata in userdeviceid_metadata is invalid.") - })?)) - }) - } - - pub(super) fn get_devicelist_version(&self, user_id: &UserId) -> Result> { - self.userid_devicelistversion - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|e| err!(Database("Invalid devicelistversion in db. {e}"))) - .map(Some) - }) - } - - pub(super) fn all_devices_metadata<'a>( - &'a self, user_id: &UserId, - ) -> Box> + 'a> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - - Box::new( - self.userdeviceid_metadata - .scan_prefix(key) - .map(|(_, bytes)| { - serde_json::from_slice::(&bytes) - .map_err(|e| err!(Database("Device in userdeviceid_metadata is invalid. {e}"))) - }), - ) - } - - /// Creates a new sync filter. Returns the filter id. - pub(super) fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { - let filter_id = utils::random_string(4); - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(filter_id.as_bytes()); - - self.userfilterid_filter - .insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?; - - Ok(filter_id) - } - - pub(super) fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(filter_id.as_bytes()); - - let raw = self.userfilterid_filter.get(&key)?; - - if let Some(raw) = raw { - serde_json::from_slice(&raw).map_err(|e| err!(Database("Invalid filter event in db. 
{e}"))) - } else { - Ok(None) - } - } - - /// Creates an OpenID token, which can be used to prove that a user has - /// access to an account (primarily for integrations) - pub(super) fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result { - use std::num::Saturating as Sat; - - let expires_in = self.services.server.config.openid_token_ttl; - let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000); - - let mut value = expires_at.0.to_be_bytes().to_vec(); - value.extend_from_slice(user_id.as_bytes()); - - self.openidtoken_expiresatuserid - .insert(token.as_bytes(), value.as_slice())?; - - Ok(expires_in) - } - - /// Find out which user an OpenID access token belongs to. - pub(super) fn find_from_openid_token(&self, token: &str) -> Result { - let Some(value) = self.openidtoken_expiresatuserid.get(token.as_bytes())? else { - return Err(Error::BadRequest(ErrorKind::Unauthorized, "OpenID token is unrecognised")); - }; - - let (expires_at_bytes, user_bytes) = value.split_at(0_u64.to_be_bytes().len()); - - let expires_at = u64::from_be_bytes( - expires_at_bytes - .try_into() - .map_err(|e| err!(Database("expires_at in openid_userid is invalid u64. {e}")))?, - ); - - if expires_at < utils::millis_since_unix_epoch() { - debug_info!("OpenID token is expired, removing"); - self.openidtoken_expiresatuserid.remove(token.as_bytes())?; - - return Err(Error::BadRequest(ErrorKind::Unauthorized, "OpenID token is expired")); - } - - UserId::parse( - utils::string_from_bytes(user_bytes) - .map_err(|e| err!(Database("User ID in openid_userid is invalid unicode. {e}")))?, - ) - .map_err(|e| err!(Database("User ID in openid_userid is invalid. {e}"))) - } -} - -/// Will only return with Some(username) if the password was not empty and the -/// username could be successfully parsed. -/// If `utils::string_from_bytes`(...) returns an error that username will be -/// skipped and the error will be logged. 
-pub(super) fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option { - // A valid password is not empty - if password.is_empty() { - None - } else { - match utils::string_from_bytes(username) { - Ok(u) => Some(u), - Err(e) => { - warn!("Failed to parse username while calling get_local_users(): {}", e.to_string()); - None - }, - } - } -} diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 80897b5ff..9a058ba9d 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -1,552 +1,984 @@ -mod data; +use std::{collections::BTreeMap, mem, mem::size_of, sync::Arc}; -use std::{ - collections::{BTreeMap, BTreeSet}, - mem, - sync::{Arc, Mutex, Mutex as StdMutex}, +use conduit::{ + debug_warn, err, utils, + utils::{stream::TryIgnore, string::Unquoted, ReadyExt, TryReadyExt}, + warn, Err, Error, Result, Server, }; - -use conduit::{Error, Result}; +use database::{Deserialized, Ignore, Interfix, Map}; +use futures::{pin_mut, FutureExt, Stream, StreamExt, TryFutureExt}; use ruma::{ - api::client::{ - device::Device, - filter::FilterDefinition, - sync::sync_events::{ - self, - v4::{ExtensionsConfig, SyncRequestList}, - }, - }, + api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::AnyToDeviceEvent, + events::{AnyToDeviceEvent, StateEventType}, serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, OwnedRoomId, OwnedUserId, - UInt, UserId, + DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId, + OwnedMxcUri, OwnedUserId, UInt, UserId, }; -use self::data::Data; -use crate::{admin, rooms, Dep}; +use crate::{admin, globals, rooms, Dep}; pub struct Service { - connections: DbConnections, - pub db: Data, services: Services, + db: Data, } struct Services { + server: Arc, admin: Dep, + globals: Dep, + state_accessor: Dep, state_cache: Dep, } +struct 
Data { + keychangeid_userid: Arc, + keyid_key: Arc, + onetimekeyid_onetimekeys: Arc, + openidtoken_expiresatuserid: Arc, + todeviceid_events: Arc, + token_userdeviceid: Arc, + userdeviceid_metadata: Arc, + userdeviceid_token: Arc, + userfilterid_filter: Arc, + userid_avatarurl: Arc, + userid_blurhash: Arc, + userid_devicelistversion: Arc, + userid_displayname: Arc, + userid_lastonetimekeyupdate: Arc, + userid_masterkeyid: Arc, + userid_password: Arc, + userid_selfsigningkeyid: Arc, + userid_usersigningkeyid: Arc, + useridprofilekey_value: Arc, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - connections: StdMutex::new(BTreeMap::new()), - db: Data::new(&args), services: Services { + server: args.server.clone(), admin: args.depend::("admin"), + globals: args.depend::("globals"), + state_accessor: args.depend::("rooms::state_accessor"), state_cache: args.depend::("rooms::state_cache"), }, + db: Data { + keychangeid_userid: args.db["keychangeid_userid"].clone(), + keyid_key: args.db["keyid_key"].clone(), + onetimekeyid_onetimekeys: args.db["onetimekeyid_onetimekeys"].clone(), + openidtoken_expiresatuserid: args.db["openidtoken_expiresatuserid"].clone(), + todeviceid_events: args.db["todeviceid_events"].clone(), + token_userdeviceid: args.db["token_userdeviceid"].clone(), + userdeviceid_metadata: args.db["userdeviceid_metadata"].clone(), + userdeviceid_token: args.db["userdeviceid_token"].clone(), + userfilterid_filter: args.db["userfilterid_filter"].clone(), + userid_avatarurl: args.db["userid_avatarurl"].clone(), + userid_blurhash: args.db["userid_blurhash"].clone(), + userid_devicelistversion: args.db["userid_devicelistversion"].clone(), + userid_displayname: args.db["userid_displayname"].clone(), + userid_lastonetimekeyupdate: args.db["userid_lastonetimekeyupdate"].clone(), + userid_masterkeyid: args.db["userid_masterkeyid"].clone(), + userid_password: args.db["userid_password"].clone(), + 
userid_selfsigningkeyid: args.db["userid_selfsigningkeyid"].clone(), + userid_usersigningkeyid: args.db["userid_usersigningkeyid"].clone(), + useridprofilekey_value: args.db["useridprofilekey_value"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -type DbConnections = Mutex>; -type DbConnectionsKey = (OwnedUserId, OwnedDeviceId, String); -type DbConnectionsVal = Arc>; - -struct SlidingSyncCache { - lists: BTreeMap, - subscriptions: BTreeMap, - known_rooms: BTreeMap>, // For every room, the roomsince number - extensions: ExtensionsConfig, -} - impl Service { - /// Check if a user has an account on this homeserver. + /// Check if a user is an admin #[inline] - pub fn exists(&self, user_id: &UserId) -> Result { self.db.exists(user_id) } - - pub fn remembered(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) -> bool { - self.connections - .lock() - .unwrap() - .contains_key(&(user_id, device_id, conn_id)) - } + pub async fn is_admin(&self, user_id: &UserId) -> bool { self.services.admin.user_is_admin(user_id).await } - pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) { - self.connections - .lock() - .unwrap() - .remove(&(user_id, device_id, conn_id)); + /// Create a new user account on this homeserver. 
+ #[inline] + pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + self.set_password(user_id, password) } - pub fn update_sync_request_with_cache( - &self, user_id: OwnedUserId, device_id: OwnedDeviceId, request: &mut sync_events::v4::Request, - ) -> BTreeMap> { - let Some(conn_id) = request.conn_id.clone() else { - return BTreeMap::new(); - }; - - let mut cache = self.connections.lock().unwrap(); - let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), - ); - let cached = &mut cached.lock().unwrap(); - drop(cache); - - for (list_id, list) in &mut request.lists { - if let Some(cached_list) = cached.lists.get(list_id) { - if list.sort.is_empty() { - list.sort.clone_from(&cached_list.sort); - }; - if list.room_details.required_state.is_empty() { - list.room_details - .required_state - .clone_from(&cached_list.room_details.required_state); - }; - list.room_details.timeline_limit = list - .room_details - .timeline_limit - .or(cached_list.room_details.timeline_limit); - list.include_old_rooms = list - .include_old_rooms - .clone() - .or_else(|| cached_list.include_old_rooms.clone()); - match (&mut list.filters, cached_list.filters.clone()) { - (Some(list_filters), Some(cached_filters)) => { - list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm); - if list_filters.spaces.is_empty() { - list_filters.spaces = cached_filters.spaces; - } - list_filters.is_encrypted = list_filters.is_encrypted.or(cached_filters.is_encrypted); - list_filters.is_invite = list_filters.is_invite.or(cached_filters.is_invite); - if list_filters.room_types.is_empty() { - list_filters.room_types = cached_filters.room_types; - } - if list_filters.not_room_types.is_empty() { - list_filters.not_room_types = cached_filters.not_room_types; - 
} - list_filters.room_name_like = list_filters - .room_name_like - .clone() - .or(cached_filters.room_name_like); - if list_filters.tags.is_empty() { - list_filters.tags = cached_filters.tags; - } - if list_filters.not_tags.is_empty() { - list_filters.not_tags = cached_filters.not_tags; - } - }, - (_, Some(cached_filters)) => list.filters = Some(cached_filters), - (Some(list_filters), _) => list.filters = Some(list_filters.clone()), - (..) => {}, - } - if list.bump_event_types.is_empty() { - list.bump_event_types - .clone_from(&cached_list.bump_event_types); - }; - } - cached.lists.insert(list_id.clone(), list.clone()); - } + /// Deactivate account + pub async fn deactivate_account(&self, user_id: &UserId) -> Result<()> { + // Remove all associated devices + self.all_device_ids(user_id) + .for_each(|device_id| self.remove_device(user_id, device_id)) + .await; - cached - .subscriptions - .extend(request.room_subscriptions.clone()); - request - .room_subscriptions - .extend(cached.subscriptions.clone()); - - request.extensions.e2ee.enabled = request - .extensions - .e2ee - .enabled - .or(cached.extensions.e2ee.enabled); - - request.extensions.to_device.enabled = request - .extensions - .to_device - .enabled - .or(cached.extensions.to_device.enabled); - - request.extensions.account_data.enabled = request - .extensions - .account_data - .enabled - .or(cached.extensions.account_data.enabled); - request.extensions.account_data.lists = request - .extensions - .account_data - .lists - .clone() - .or_else(|| cached.extensions.account_data.lists.clone()); - request.extensions.account_data.rooms = request - .extensions - .account_data - .rooms - .clone() - .or_else(|| cached.extensions.account_data.rooms.clone()); - - cached.extensions = request.extensions.clone(); - - cached.known_rooms.clone() - } - - pub fn update_sync_subscriptions( - &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, - subscriptions: BTreeMap, - ) { - let mut cache = 
self.connections.lock().unwrap(); - let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), - ); - let cached = &mut cached.lock().unwrap(); - drop(cache); + // Set the password to "" to indicate a deactivated account. Hashes will never + // result in an empty string, so the user will not be able to log in again. + // Systems like changing the password without logging in should check if the + // account is deactivated. + self.set_password(user_id, None)?; - cached.subscriptions = subscriptions; + // TODO: Unhook 3PID + Ok(()) } - pub fn update_sync_known_rooms( - &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, list_id: String, - new_cached_rooms: BTreeSet, globalsince: u64, - ) { - let mut cache = self.connections.lock().unwrap(); - let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), - ); - let cached = &mut cached.lock().unwrap(); - drop(cache); - - for (roomid, lastsince) in cached - .known_rooms - .entry(list_id.clone()) - .or_default() - .iter_mut() - { - if !new_cached_rooms.contains(roomid) { - *lastsince = 0; - } - } - let list = cached.known_rooms.entry(list_id).or_default(); - for roomid in new_cached_rooms { - list.insert(roomid, globalsince); - } - } + /// Check if a user has an account on this homeserver. 
+ #[inline] + pub async fn exists(&self, user_id: &UserId) -> bool { self.db.userid_password.qry(user_id).await.is_ok() } /// Check if account is deactivated - pub fn is_deactivated(&self, user_id: &UserId) -> Result { self.db.is_deactivated(user_id) } - - /// Check if a user is an admin - pub fn is_admin(&self, user_id: &UserId) -> Result { - if let Some(admin_room_id) = self.services.admin.get_admin_room()? { - self.services.state_cache.is_joined(user_id, &admin_room_id) - } else { - Ok(false) - } + pub async fn is_deactivated(&self, user_id: &UserId) -> Result { + self.db + .userid_password + .qry(user_id) + .map_ok(|val| val.is_empty()) + .map_err(|_| err!(Request(NotFound("User does not exist.")))) + .await } - /// Create a new user account on this homeserver. - #[inline] - pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - self.db.set_password(user_id, password)?; - Ok(()) + /// Check if account is active, infallible + pub async fn is_active(&self, user_id: &UserId) -> bool { !self.is_deactivated(user_id).await.unwrap_or(true) } + + /// Check if account is active, infallible + pub async fn is_active_local(&self, user_id: &UserId) -> bool { + self.services.globals.user_is_local(user_id) && self.is_active(user_id).await } /// Returns the number of users registered on this server. #[inline] - pub fn count(&self) -> Result { self.db.count() } + pub async fn count(&self) -> usize { self.db.userid_password.count().await } /// Find out which user an access token belongs to. 
- pub fn find_from_token(&self, token: &str) -> Result> { - self.db.find_from_token(token) + pub async fn find_from_token(&self, token: &str) -> Result<(OwnedUserId, OwnedDeviceId)> { + self.db.token_userdeviceid.qry(token).await.deserialized() } + /// Returns an iterator over all users on this homeserver (offered for + /// compatibility) + #[allow(clippy::iter_without_into_iter, clippy::iter_not_returning_iterator)] + pub fn iter(&self) -> impl Stream + Send + '_ { self.stream().map(ToOwned::to_owned) } + /// Returns an iterator over all users on this homeserver. - pub fn iter(&self) -> impl Iterator> + '_ { self.db.iter() } + pub fn stream(&self) -> impl Stream + Send { self.db.userid_password.keys().ignore_err() } /// Returns a list of local users as list of usernames. /// /// A user account is considered `local` if the length of it's password is /// greater then zero. - pub fn list_local_users(&self) -> Result> { self.db.list_local_users() } + pub fn list_local_users(&self) -> impl Stream + Send + '_ { + self.db + .userid_password + .stream() + .ignore_err() + .ready_filter_map(|(u, p): (&UserId, &[u8])| (!p.is_empty()).then_some(u)) + } /// Returns the password hash for the given user. 
- pub fn password_hash(&self, user_id: &UserId) -> Result> { self.db.password_hash(user_id) } + pub async fn password_hash(&self, user_id: &UserId) -> Result { + self.db.userid_password.qry(user_id).await.deserialized() + } /// Hash and set the user's password to the Argon2 hash - #[inline] pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - self.db.set_password(user_id, password) + if let Some(password) = password { + if let Ok(hash) = utils::hash::password(password) { + self.db + .userid_password + .insert(user_id.as_bytes(), hash.as_bytes()); + Ok(()) + } else { + Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Password does not meet the requirements.", + )) + } + } else { + self.db.userid_password.insert(user_id.as_bytes(), b""); + Ok(()) + } } /// Returns the displayname of a user on this homeserver. - pub fn displayname(&self, user_id: &UserId) -> Result> { self.db.displayname(user_id) } + pub async fn displayname(&self, user_id: &UserId) -> Result { + self.db.userid_displayname.qry(user_id).await.deserialized() + } /// Sets a new displayname or removes it if displayname is None. You still /// need to nofify all rooms of this change. - pub async fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { - self.db.set_displayname(user_id, displayname) + pub fn set_displayname(&self, user_id: &UserId, displayname: Option) { + if let Some(displayname) = displayname { + self.db + .userid_displayname + .insert(user_id.as_bytes(), displayname.as_bytes()); + } else { + self.db.userid_displayname.remove(user_id.as_bytes()); + } } - /// Get the avatar_url of a user. - pub fn avatar_url(&self, user_id: &UserId) -> Result> { self.db.avatar_url(user_id) } + /// Get the `avatar_url` of a user. + pub async fn avatar_url(&self, user_id: &UserId) -> Result { + self.db.userid_avatarurl.qry(user_id).await.deserialized() + } /// Sets a new avatar_url or removes it if avatar_url is None. 
- pub async fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { - self.db.set_avatar_url(user_id, avatar_url) + pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) { + if let Some(avatar_url) = avatar_url { + self.db + .userid_avatarurl + .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes()); + } else { + self.db.userid_avatarurl.remove(user_id.as_bytes()); + } } /// Get the blurhash of a user. - pub fn blurhash(&self, user_id: &UserId) -> Result> { self.db.blurhash(user_id) } - - pub fn timezone(&self, user_id: &UserId) -> Result> { self.db.timezone(user_id) } - - /// Gets a specific user profile key - pub fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result> { - self.db.profile_key(user_id, profile_key) + pub async fn blurhash(&self, user_id: &UserId) -> Result { + self.db.userid_blurhash.qry(user_id).await.deserialized() } - /// Gets all the user's profile keys and values in an iterator - pub fn all_profile_keys<'a>( - &'a self, user_id: &UserId, - ) -> Box> + 'a + Send> { - self.db.all_profile_keys(user_id) - } - - /// Sets a new profile key value, removes the key if value is None - pub fn set_profile_key( - &self, user_id: &UserId, profile_key: &str, profile_key_value: Option, - ) -> Result<()> { - self.db - .set_profile_key(user_id, profile_key, profile_key_value) - } - - /// Sets a new tz or removes it if tz is None. - pub async fn set_timezone(&self, user_id: &UserId, tz: Option) -> Result<()> { - self.db.set_timezone(user_id, tz) - } - - /// Sets a new blurhash or removes it if blurhash is None. - pub async fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { - self.db.set_blurhash(user_id, blurhash) + /// Sets a new avatar_url or removes it if avatar_url is None. 
+ pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option) { + if let Some(blurhash) = blurhash { + self.db + .userid_blurhash + .insert(user_id.as_bytes(), blurhash.as_bytes()); + } else { + self.db.userid_blurhash.remove(user_id.as_bytes()); + } } /// Adds a new device to a user. - pub fn create_device( + pub async fn create_device( &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option, client_ip: Option, ) -> Result<()> { - self.db - .create_device(user_id, device_id, token, initial_device_display_name, client_ip) + // This method should never be called for nonexistent users. We shouldn't assert + // though... + if !self.exists(user_id).await { + warn!("Called create_device for non-existent user {} in database", user_id); + return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist.")); + } + + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + increment(&self.db.userid_devicelistversion, user_id.as_bytes()); + + self.db.userdeviceid_metadata.insert( + &userdeviceid, + &serde_json::to_vec(&Device { + device_id: device_id.into(), + display_name: initial_device_display_name, + last_seen_ip: client_ip, + last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), + }) + .expect("Device::to_string never fails."), + ); + + self.set_token(user_id, device_id, token).await?; + + Ok(()) } /// Removes a device from a user. 
- pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - self.db.remove_device(user_id, device_id) + pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + // Remove tokens + if let Ok(old_token) = self.db.userdeviceid_token.qry(&userdeviceid).await { + self.db.userdeviceid_token.remove(&userdeviceid); + self.db.token_userdeviceid.remove(&old_token); + } + + // Remove todevice events + let prefix = (user_id, device_id, Interfix); + self.db + .todeviceid_events + .keys_raw_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.db.todeviceid_events.remove(key)) + .await; + + // TODO: Remove onetimekeys + + increment(&self.db.userid_devicelistversion, user_id.as_bytes()); + + self.db.userdeviceid_metadata.remove(&userdeviceid); } /// Returns an iterator over all device ids of this user. - pub fn all_device_ids<'a>(&'a self, user_id: &UserId) -> impl Iterator> + 'a { - self.db.all_device_ids(user_id) + pub fn all_device_ids<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { + let prefix = (user_id, Interfix); + self.db + .userdeviceid_metadata + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, device_id): (Ignore, &DeviceId)| device_id) } /// Replaces the access token of one device. - #[inline] - pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { - self.db.set_token(user_id, device_id, token) + pub async fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { + let key = (user_id, device_id); + // should not be None, but we shouldn't assert either lol... + if self.db.userdeviceid_metadata.qry(&key).await.is_err() { + return Err!(Database(error!( + ?user_id, + ?device_id, + "User does not exist or device has no metadata." 
+ ))); + } + + // Remove old token + if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await { + self.db.token_userdeviceid.remove(&old_token); + // It will be removed from userdeviceid_token by the insert later + } + + // Assign token to user device combination + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + self.db + .userdeviceid_token + .insert(&userdeviceid, token.as_bytes()); + self.db + .token_userdeviceid + .insert(token.as_bytes(), &userdeviceid); + + Ok(()) } - pub fn add_one_time_key( + pub async fn add_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, one_time_key_value: &Raw, ) -> Result<()> { + // All devices have metadata + // Only existing devices should be able to call this, but we shouldn't assert + // either... + let key = (user_id, device_id); + if self.db.userdeviceid_metadata.qry(&key).await.is_err() { + return Err!(Database(error!( + ?user_id, + ?device_id, + "User does not exist or device has no metadata." + ))); + } + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.as_bytes()); + key.push(0xFF); + // TODO: Use DeviceKeyId::to_string when it's available (and update everything, + // because there are no wrapping quotation marks anymore) + key.extend_from_slice( + serde_json::to_string(one_time_key_key) + .expect("DeviceKeyId::to_string always works") + .as_bytes(), + ); + + self.db.onetimekeyid_onetimekeys.insert( + &key, + &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), + ); + self.db - .add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value) - } + .userid_lastonetimekeyupdate + .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes()); - // TODO: use this ? 
- #[allow(dead_code)] - pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { - self.db.last_one_time_keys_update(user_id) + Ok(()) } - pub fn take_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, - ) -> Result)>> { - self.db.take_one_time_key(user_id, device_id, key_algorithm) + pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 { + self.db + .userid_lastonetimekeyupdate + .qry(user_id) + .await + .deserialized() + .unwrap_or(0) } - pub fn count_one_time_keys( + pub async fn take_one_time_key( + &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, + ) -> Result<(OwnedDeviceKeyId, Raw)> { + self.db + .userid_lastonetimekeyupdate + .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes()); + + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + prefix.push(b'"'); // Annoying quotation mark + prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); + prefix.push(b':'); + + let one_time_key = self + .db + .onetimekeyid_onetimekeys + .raw_stream_prefix(&prefix) + .ignore_err() + .map(|(key, val)| { + self.db.onetimekeyid_onetimekeys.remove(key); + + let key = key + .rsplit(|&b| b == 0xFF) + .next() + .ok_or_else(|| err!(Database("OneTimeKeyId in db is invalid."))) + .unwrap(); + + let key = serde_json::from_slice(key) + .map_err(|e| err!(Database("OneTimeKeyId in db is invalid. {e}"))) + .unwrap(); + + let val = serde_json::from_slice(val) + .map_err(|e| err!(Database("OneTimeKeys in db are invalid. 
{e}"))) + .unwrap(); + + (key, val) + }) + .next() + .await; + + one_time_key.ok_or_else(|| err!(Request(NotFound("No one-time-key found")))) + } + + pub async fn count_one_time_keys( &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result> { - self.db.count_one_time_keys(user_id, device_id) - } + ) -> BTreeMap { + type KeyVal<'a> = ((Ignore, Ignore, &'a Unquoted), Ignore); + + let mut algorithm_counts = BTreeMap::::new(); + let query = (user_id, device_id); + self.db + .onetimekeyid_onetimekeys + .stream_prefix(&query) + .ignore_err() + .ready_for_each(|((Ignore, Ignore, device_key_id), Ignore): KeyVal<'_>| { + let device_key_id: &DeviceKeyId = device_key_id + .as_str() + .try_into() + .expect("Invalid DeviceKeyID in database"); + + let count: &mut UInt = algorithm_counts + .entry(device_key_id.algorithm()) + .or_default(); + + *count = count.saturating_add(1_u32.into()); + }) + .await; + + algorithm_counts + } + + pub async fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw) { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + self.db.keyid_key.insert( + &userdeviceid, + &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), + ); - pub fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw) -> Result<()> { - self.db.add_device_keys(user_id, device_id, device_keys) + self.mark_device_key_update(user_id).await; } - pub fn add_cross_signing_keys( + pub async fn add_cross_signing_keys( &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, user_signing_key: &Option>, notify: bool, ) -> Result<()> { + // TODO: Check signatures + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + + let (master_key_key, _) = parse_master_key(user_id, master_key)?; + + self.db + .keyid_key + .insert(&master_key_key, master_key.json().get().as_bytes()); + self.db - 
.add_cross_signing_keys(user_id, master_key, self_signing_key, user_signing_key, notify) + .userid_masterkeyid + .insert(user_id.as_bytes(), &master_key_key); + + // Self-signing key + if let Some(self_signing_key) = self_signing_key { + let mut self_signing_key_ids = self_signing_key + .deserialize() + .map_err(|e| err!(Request(InvalidParam("Invalid self signing key: {e:?}"))))? + .keys + .into_values(); + + let self_signing_key_id = self_signing_key_ids + .next() + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?; + + if self_signing_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Self signing key contained more than one key.", + )); + } + + let mut self_signing_key_key = prefix.clone(); + self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); + + self.db + .keyid_key + .insert(&self_signing_key_key, self_signing_key.json().get().as_bytes()); + + self.db + .userid_selfsigningkeyid + .insert(user_id.as_bytes(), &self_signing_key_key); + } + + // User-signing key + if let Some(user_signing_key) = user_signing_key { + let mut user_signing_key_ids = user_signing_key + .deserialize() + .map_err(|_| err!(Request(InvalidParam("Invalid user signing key"))))? 
+ .keys + .into_values(); + + let user_signing_key_id = user_signing_key_ids + .next() + .ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?; + + if user_signing_key_ids.next().is_some() { + return Err!(Request(InvalidParam("User signing key contained more than one key."))); + } + + let mut user_signing_key_key = prefix; + user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); + + self.db + .keyid_key + .insert(&user_signing_key_key, user_signing_key.json().get().as_bytes()); + + self.db + .userid_usersigningkeyid + .insert(user_id.as_bytes(), &user_signing_key_key); + } + + if notify { + self.mark_device_key_update(user_id).await; + } + + Ok(()) } - pub fn sign_key( + pub async fn sign_key( &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, ) -> Result<()> { - self.db.sign_key(target_id, key_id, signature, sender_id) + let key = (target_id, key_id); + + let mut cross_signing_key: serde_json::Value = self + .db + .keyid_key + .qry(&key) + .await + .map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key."))))? + .deserialized_json() + .map_err(|e| err!(Database("key in keyid_key is invalid. {e:?}")))?; + + let signatures = cross_signing_key + .get_mut("signatures") + .ok_or_else(|| err!(Database("key in keyid_key has no signatures field.")))? + .as_object_mut() + .ok_or_else(|| err!(Database("key in keyid_key has invalid signatures field.")))? + .entry(sender_id.to_string()) + .or_insert_with(|| serde_json::Map::new().into()); + + signatures + .as_object_mut() + .ok_or_else(|| err!(Database("signatures in keyid_key for a user is invalid.")))? 
+ .insert(signature.0, signature.1.into()); + + let mut key = target_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(key_id.as_bytes()); + self.db.keyid_key.insert( + &key, + &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), + ); + + self.mark_device_key_update(target_id).await; + + Ok(()) } pub fn keys_changed<'a>( - &'a self, user_or_room_id: &str, from: u64, to: Option, - ) -> impl Iterator> + 'a { - self.db.keys_changed(user_or_room_id, from, to) - } + &'a self, user_or_room_id: &'a str, from: u64, to: Option, + ) -> impl Stream + Send + 'a { + type KeyVal<'a> = ((&'a str, u64), &'a UserId); - #[inline] - pub fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { self.db.mark_device_key_update(user_id) } + let to = to.unwrap_or(u64::MAX); + let start = (user_or_room_id, from.saturating_add(1)); + self.db + .keychangeid_userid + .stream_from(&start) + .ignore_err() + .ready_take_while(move |((prefix, count), _): &KeyVal<'_>| *prefix == user_or_room_id && *count <= to) + .map(|((..), user_id): KeyVal<'_>| user_id) + } + + pub async fn mark_device_key_update(&self, user_id: &UserId) { + let count = self.services.globals.next_count().unwrap().to_be_bytes(); + let rooms_joined = self.services.state_cache.rooms_joined(user_id); + pin_mut!(rooms_joined); + while let Some(room_id) = rooms_joined.next().await { + // Don't send key updates to unencrypted rooms + if self + .services + .state_accessor + .room_state_get(room_id, &StateEventType::RoomEncryption, "") + .await + .is_err() + { + continue; + } - pub fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { - self.db.get_device_keys(user_id, device_id) - } + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(&count); - #[inline] - pub fn parse_master_key( - &self, user_id: &UserId, master_key: &Raw, - ) -> Result<(Vec, CrossSigningKey)> { - Data::parse_master_key(user_id, master_key) + 
self.db.keychangeid_userid.insert(&key, user_id.as_bytes()); + } + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(&count); + self.db.keychangeid_userid.insert(&key, user_id.as_bytes()); } - #[inline] - pub fn get_key( - &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.db - .get_key(key, sender_user, user_id, allowed_signatures) + pub async fn get_device_keys<'a>(&'a self, user_id: &'a UserId, device_id: &DeviceId) -> Result> { + let key_id = (user_id, device_id); + self.db.keyid_key.qry(&key_id).await.deserialized_json() } - pub fn get_master_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.db - .get_master_key(sender_user, user_id, allowed_signatures) + pub async fn get_key( + &self, key_id: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F, + ) -> Result> + where + F: Fn(&UserId) -> bool + Send + Sync, + { + let key = self + .db + .keyid_key + .qry(key_id) + .await + .deserialized_json::()?; + + let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?; + let raw_value = serde_json::value::to_raw_value(&cleaned)?; + Ok(Raw::from_json(raw_value)) + } + + pub async fn get_master_key( + &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F, + ) -> Result> + where + F: Fn(&UserId) -> bool + Send + Sync, + { + let key_id = self.db.userid_masterkeyid.qry(user_id).await?; + + self.get_key(&key_id, sender_user, user_id, allowed_signatures) + .await } - pub fn get_self_signing_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.db - .get_self_signing_key(sender_user, user_id, allowed_signatures) + pub async fn get_self_signing_key( + &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F, 
+ ) -> Result> + where + F: Fn(&UserId) -> bool + Send + Sync, + { + let key_id = self.db.userid_selfsigningkeyid.qry(user_id).await?; + + self.get_key(&key_id, sender_user, user_id, allowed_signatures) + .await } - pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { - self.db.get_user_signing_key(user_id) + pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result> { + let key_id = self.db.userid_usersigningkeyid.qry(user_id).await?; + + self.db.keyid_key.qry(&*key_id).await.deserialized_json() } - pub fn add_to_device_event( + pub async fn add_to_device_event( &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, content: serde_json::Value, - ) -> Result<()> { - self.db - .add_to_device_event(sender, target_user_id, target_device_id, event_type, content) - } + ) { + let mut key = target_user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(target_device_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes()); + + let mut json = serde_json::Map::new(); + json.insert("type".to_owned(), event_type.to_owned().into()); + json.insert("sender".to_owned(), sender.to_string().into()); + json.insert("content".to_owned(), content); + + let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); - pub fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { - self.db.get_to_device_events(user_id, device_id) + self.db.todeviceid_events.insert(&key, &value); } - pub fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> { - self.db.remove_to_device_events(user_id, device_id, until) + pub fn get_to_device_events<'a>( + &'a self, user_id: &'a UserId, device_id: &'a DeviceId, + ) -> impl Stream> + Send + 'a { + let prefix = (user_id, device_id, Interfix); + self.db + .todeviceid_events + .stream_raw_prefix(&prefix) + .ready_and_then(|(_, val)| 
serde_json::from_slice(val).map_err(Into::into)) + .ignore_err() } - pub fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { - self.db.update_device_metadata(user_id, device_id, device) + pub async fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + + let mut last = prefix.clone(); + last.extend_from_slice(&until.to_be_bytes()); + + self.db + .todeviceid_events + .rev_raw_keys_from(&last) // this includes last + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix)) + .map(|key| { + let len = key.len(); + let start = len.saturating_sub(size_of::()); + let count = utils::u64_from_u8(&key[start..len]); + (key, count) + }) + .ready_take_while(move |(_, count)| *count <= until) + .ready_for_each(|(key, _)| self.db.todeviceid_events.remove(&key)) + .boxed() + .await; + } + + pub async fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { + // Only existing devices should be able to call this, but we shouldn't assert + // either... + let key = (user_id, device_id); + if self.db.userdeviceid_metadata.qry(&key).await.is_err() { + return Err!(Database(error!( + ?user_id, + ?device_id, + "Called update_device_metadata for a non-existent user and/or device" + ))); + } + + increment(&self.db.userid_devicelistversion, user_id.as_bytes()); + + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + self.db.userdeviceid_metadata.insert( + &userdeviceid, + &serde_json::to_vec(device).expect("Device::to_string always works"), + ); + + Ok(()) } /// Get device metadata. 
- pub fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { - self.db.get_device_metadata(user_id, device_id) + pub async fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result { + self.db + .userdeviceid_metadata + .qry(&(user_id, device_id)) + .await + .deserialized_json() } - pub fn get_devicelist_version(&self, user_id: &UserId) -> Result> { - self.db.get_devicelist_version(user_id) + pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result { + self.db + .userid_devicelistversion + .qry(user_id) + .await + .deserialized() } - pub fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> impl Iterator> + 'a { - self.db.all_devices_metadata(user_id) + pub fn all_devices_metadata<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { + self.db + .userdeviceid_metadata + .stream_raw_prefix(&(user_id, Interfix)) + .ready_and_then(|(_, val)| serde_json::from_slice::(val).map_err(Into::into)) + .ignore_err() } - /// Deactivate account - pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { - // Remove all associated devices - for device_id in self.all_device_ids(user_id) { - self.remove_device(user_id, &device_id?)?; - } + /// Creates a new sync filter. Returns the filter id. + pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> String { + let filter_id = utils::random_string(4); - // Set the password to "" to indicate a deactivated account. Hashes will never - // result in an empty string, so the user will not be able to log in again. - // Systems like changing the password without logging in should check if the - // account is deactivated. 
- self.db.set_password(user_id, None)?; + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(filter_id.as_bytes()); - // TODO: Unhook 3PID - Ok(()) - } + self.db + .userfilterid_filter + .insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json")); - /// Creates a new sync filter. Returns the filter id. - pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { - self.db.create_filter(user_id, filter) + filter_id } - pub fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result> { - self.db.get_filter(user_id, filter_id) + pub async fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result { + self.db + .userfilterid_filter + .qry(&(user_id, filter_id)) + .await + .deserialized_json() } /// Creates an OpenID token, which can be used to prove that a user has /// access to an account (primarily for integrations) pub fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result { - self.db.create_openid_token(user_id, token) + use std::num::Saturating as Sat; + + let expires_in = self.services.server.config.openid_token_ttl; + let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000); + + let mut value = expires_at.0.to_be_bytes().to_vec(); + value.extend_from_slice(user_id.as_bytes()); + + self.db + .openidtoken_expiresatuserid + .insert(token.as_bytes(), value.as_slice()); + + Ok(expires_in) } /// Find out which user an OpenID access token belongs to. 
- pub fn find_from_openid_token(&self, token: &str) -> Result { self.db.find_from_openid_token(token) } + pub async fn find_from_openid_token(&self, token: &str) -> Result { + let Ok(value) = self.db.openidtoken_expiresatuserid.qry(token).await else { + return Err!(Request(Unauthorized("OpenID token is unrecognised"))); + }; + + let (expires_at_bytes, user_bytes) = value.split_at(0_u64.to_be_bytes().len()); + let expires_at = u64::from_be_bytes( + expires_at_bytes + .try_into() + .map_err(|e| err!(Database("expires_at in openid_userid is invalid u64. {e}")))?, + ); + + if expires_at < utils::millis_since_unix_epoch() { + debug_warn!("OpenID token is expired, removing"); + self.db.openidtoken_expiresatuserid.remove(token.as_bytes()); + + return Err!(Request(Unauthorized("OpenID token is expired"))); + } + + let user_string = utils::string_from_bytes(user_bytes) + .map_err(|e| err!(Database("User ID in openid_userid is invalid unicode. {e}")))?; + + UserId::parse(user_string).map_err(|e| err!(Database("User ID in openid_userid is invalid. 
{e}"))) + } + + /// Gets a specific user profile key + pub async fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result { + let key = (user_id, profile_key); + self.db + .useridprofilekey_value + .qry(&key) + .await + .deserialized() + } + + /// Gets all the user's profile keys and values in an iterator + pub fn all_profile_keys<'a>( + &'a self, user_id: &'a UserId, + ) -> impl Stream + 'a + Send { + type KeyVal = ((Ignore, String), serde_json::Value); + + let prefix = (user_id, Interfix); + self.db + .useridprofilekey_value + .stream_prefix(&prefix) + .ignore_err() + .map(|((_, key), val): KeyVal| (key, val)) + } + + /// Sets a new profile key value, removes the key if value is None + pub fn set_profile_key(&self, user_id: &UserId, profile_key: &str, profile_key_value: Option) { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(profile_key.as_bytes()); + + // TODO: insert to the stable MSC4175 key when it's stable + if let Some(value) = profile_key_value { + let value = serde_json::to_vec(&value).unwrap(); + + self.db.useridprofilekey_value.insert(&key, &value); + } else { + self.db.useridprofilekey_value.remove(&key); + } + } + + /// Get the timezone of a user. + pub async fn timezone(&self, user_id: &UserId) -> Result { + // TODO: transparently migrate unstable key usage to the stable key once MSC4133 + // and MSC4175 are stable, likely a remove/insert in this block. + + // first check the unstable prefix then check the stable prefix + let unstable_key = (user_id, "us.cloke.msc4175.tz"); + let stable_key = (user_id, "m.tz"); + self.db + .useridprofilekey_value + .qry(&unstable_key) + .or_else(|_| self.db.useridprofilekey_value.qry(&stable_key)) + .await + .deserialized() + } + + /// Sets a new timezone or removes it if timezone is None. 
+ pub fn set_timezone(&self, user_id: &UserId, timezone: Option) { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(b"us.cloke.msc4175.tz"); + + // TODO: insert to the stable MSC4175 key when it's stable + if let Some(timezone) = timezone { + self.db + .useridprofilekey_value + .insert(&key, timezone.as_bytes()); + } else { + self.db.useridprofilekey_value.remove(&key); + } + } +} + +pub fn parse_master_key(user_id: &UserId, master_key: &Raw) -> Result<(Vec, CrossSigningKey)> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + + let master_key = master_key + .deserialize() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; + let mut master_key_ids = master_key.keys.values(); + let master_key_id = master_key_ids + .next() + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?; + if master_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Master key contained more than one key.", + )); + } + let mut master_key_key = prefix.clone(); + master_key_key.extend_from_slice(master_key_id.as_bytes()); + Ok((master_key_key, master_key)) } /// Ensure that a user only sees signatures from themselves and the target user -pub fn clean_signatures bool>( - cross_signing_key: &mut serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: F, -) -> Result<(), Error> { +fn clean_signatures( + mut cross_signing_key: serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F, +) -> Result +where + F: Fn(&UserId) -> bool + Send + Sync, +{ if let Some(signatures) = cross_signing_key .get_mut("signatures") .and_then(|v| v.as_object_mut()) @@ -563,5 +995,12 @@ pub fn clean_signatures bool>( } } - Ok(()) + Ok(cross_signing_key) +} + +//TODO: this is an ABA +fn increment(db: &Arc, key: &[u8]) { + let old = db.get(key); + let new = utils::increment(old.ok().as_deref()); + 
db.insert(key, &new); } From 4776fe66c4a9d5cbb0153e8ff23009d21ed5010e Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 28 Sep 2024 15:14:48 +0000 Subject: [PATCH 013/245] handle serde_json for deserialized() Signed-off-by: Jason Volk --- src/database/de.rs | 92 +++++++++++++++++++++------ src/database/deserialized.rs | 14 ---- src/database/handle.rs | 28 -------- src/service/account_data/mod.rs | 2 +- src/service/appservice/data.rs | 2 +- src/service/globals/data.rs | 5 +- src/service/key_backups/mod.rs | 12 +--- src/service/pusher/mod.rs | 2 +- src/service/rooms/outlier/mod.rs | 4 +- src/service/rooms/state_cache/data.rs | 10 +-- src/service/rooms/timeline/data.rs | 15 ++--- src/service/uiaa/mod.rs | 2 +- src/service/users/mod.rs | 12 ++-- 13 files changed, 94 insertions(+), 106 deletions(-) diff --git a/src/database/de.rs b/src/database/de.rs index 8ce25aa31..a5d2c1272 100644 --- a/src/database/de.rs +++ b/src/database/de.rs @@ -58,10 +58,15 @@ impl<'de> Deserializer<'de> { } #[inline] - fn record_trail(&mut self) -> &'de [u8] { - let record = &self.buf[self.pos..]; - self.inc_pos(record.len()); - record + fn record_next_peek_byte(&self) -> Option { + let started = self.pos != 0; + let buf = &self.buf[self.pos..]; + debug_assert!( + !started || buf[0] == Self::SEP, + "Missing expected record separator at current position" + ); + + buf.get::(started.into()).copied() } #[inline] @@ -75,6 +80,13 @@ impl<'de> Deserializer<'de> { self.inc_pos(started.into()); } + #[inline] + fn record_trail(&mut self) -> &'de [u8] { + let record = &self.buf[self.pos..]; + self.inc_pos(record.len()); + record + } + #[inline] fn inc_pos(&mut self, n: usize) { self.pos = self.pos.saturating_add(n); @@ -85,41 +97,44 @@ impl<'de> Deserializer<'de> { impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { type Error = Error; - fn deserialize_map(self, _visitor: V) -> Result + fn deserialize_seq(self, visitor: V) -> Result where V: Visitor<'de>, { - 
unimplemented!("deserialize Map not implemented") + visitor.visit_seq(self) } - fn deserialize_seq(self, visitor: V) -> Result + fn deserialize_tuple(self, _len: usize, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_seq(self) } - fn deserialize_tuple(self, _len: usize, visitor: V) -> Result + fn deserialize_tuple_struct(self, _name: &'static str, _len: usize, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_seq(self) } - fn deserialize_tuple_struct(self, _name: &'static str, _len: usize, visitor: V) -> Result + fn deserialize_map(self, visitor: V) -> Result where V: Visitor<'de>, { - visitor.visit_seq(self) + let input = self.record_next(); + let mut d = serde_json::Deserializer::from_slice(input); + d.deserialize_map(visitor).map_err(Into::into) } - fn deserialize_struct( - self, _name: &'static str, _fields: &'static [&'static str], _visitor: V, - ) -> Result + fn deserialize_struct(self, name: &'static str, fields: &'static [&'static str], visitor: V) -> Result where V: Visitor<'de>, { - unimplemented!("deserialize Struct not implemented") + let input = self.record_next(); + let mut d = serde_json::Deserializer::from_slice(input); + d.deserialize_struct(name, fields, visitor) + .map_err(Into::into) } fn deserialize_unit_struct(self, name: &'static str, visitor: V) -> Result @@ -134,11 +149,14 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { visitor.visit_unit() } - fn deserialize_newtype_struct(self, _name: &'static str, _visitor: V) -> Result + fn deserialize_newtype_struct(self, name: &'static str, visitor: V) -> Result where V: Visitor<'de>, { - unimplemented!("deserialize Newtype Struct not implemented") + match name { + "$serde_json::private::RawValue" => visitor.visit_map(self), + _ => visitor.visit_newtype_struct(self), + } } fn deserialize_enum( @@ -228,19 +246,31 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } fn deserialize_unit>(self, _visitor: V) -> Result { - 
unimplemented!("deserialize Unit Struct not implemented") + unimplemented!("deserialize Unit not implemented") } - fn deserialize_identifier>(self, _visitor: V) -> Result { - unimplemented!("deserialize Identifier not implemented") + // this only used for $serde_json::private::RawValue at this time; see MapAccess + fn deserialize_identifier>(self, visitor: V) -> Result { + let input = "$serde_json::private::RawValue"; + visitor.visit_borrowed_str(input) } fn deserialize_ignored_any>(self, _visitor: V) -> Result { unimplemented!("deserialize Ignored Any not implemented") } - fn deserialize_any>(self, _visitor: V) -> Result { - unimplemented!("deserialize any not implemented") + fn deserialize_any>(self, visitor: V) -> Result { + debug_assert_eq!( + conduit::debug::type_name::(), + "serde_json::value::de::::deserialize::ValueVisitor", + "deserialize_any: type not expected" + ); + + match self.record_next_peek_byte() { + Some(b'{') => self.deserialize_map(visitor), + _ => self.deserialize_str(visitor), + } } } @@ -259,3 +289,23 @@ impl<'a, 'de: 'a> de::SeqAccess<'de> for &'a mut Deserializer<'de> { seed.deserialize(&mut **self).map(Some) } } + +// this only used for $serde_json::private::RawValue at this time. 
our db +// schema doesn't have its own map format; we use json for that anyway +impl<'a, 'de: 'a> de::MapAccess<'de> for &'a mut Deserializer<'de> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: DeserializeSeed<'de>, + { + seed.deserialize(&mut **self).map(Some) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + seed.deserialize(&mut **self) + } +} diff --git a/src/database/deserialized.rs b/src/database/deserialized.rs index 7da112d5f..a59b2ce54 100644 --- a/src/database/deserialized.rs +++ b/src/database/deserialized.rs @@ -9,11 +9,6 @@ pub trait Deserialized { F: FnOnce(T) -> U, T: for<'de> Deserialize<'de>; - fn map_json(self, f: F) -> Result - where - F: FnOnce(T) -> U, - T: for<'de> Deserialize<'de>; - #[inline] fn deserialized(self) -> Result where @@ -22,13 +17,4 @@ pub trait Deserialized { { self.map_de(identity::) } - - #[inline] - fn deserialized_json(self) -> Result - where - T: for<'de> Deserialize<'de>, - Self: Sized, - { - self.map_json(identity::) - } } diff --git a/src/database/handle.rs b/src/database/handle.rs index 89d87137a..0d4bd02ea 100644 --- a/src/database/handle.rs +++ b/src/database/handle.rs @@ -48,15 +48,6 @@ impl AsRef for Handle<'_> { } impl Deserialized for Result> { - #[inline] - fn map_json(self, f: F) -> Result - where - F: FnOnce(T) -> U, - T: for<'de> Deserialize<'de>, - { - self?.map_json(f) - } - #[inline] fn map_de(self, f: F) -> Result where @@ -68,15 +59,6 @@ impl Deserialized for Result> { } impl<'a> Deserialized for Result<&'a Handle<'a>> { - #[inline] - fn map_json(self, f: F) -> Result - where - F: FnOnce(T) -> U, - T: for<'de> Deserialize<'de>, - { - self.and_then(|handle| handle.map_json(f)) - } - #[inline] fn map_de(self, f: F) -> Result where @@ -88,16 +70,6 @@ impl<'a> Deserialized for Result<&'a Handle<'a>> { } impl<'a> Deserialized for &'a Handle<'a> { - fn map_json(self, f: F) -> Result - where - F: FnOnce(T) -> U, - T: 
for<'de> Deserialize<'de>, - { - serde_json::from_slice::(self.as_ref()) - .map_err(Into::into) - .map(f) - } - fn map_de(self, f: F) -> Result where F: FnOnce(T) -> U, diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index b4eb143d4..4f00cff1c 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -108,7 +108,7 @@ pub async fn get( .qry(&key) .and_then(|roomuserdataid| self.db.roomuserdataid_accountdata.qry(&roomuserdataid)) .await - .deserialized_json() + .deserialized() } /// Returns all changes to the account data that happened after `since`. diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index d5fa5476f..f31c5e636 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -40,7 +40,7 @@ impl Data { self.id_appserviceregistrations .qry(id) .await - .deserialized_json() + .deserialized() .map_err(|e| err!(Database("Invalid appservice {id:?} registration: {e:?}"))) } diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 3286e40c5..76f979441 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -305,10 +305,7 @@ impl Data { } pub async fn signing_keys_for(&self, origin: &ServerName) -> Result { - self.server_signingkeys - .qry(origin) - .await - .deserialized_json() + self.server_signingkeys.qry(origin).await.deserialized() } pub async fn database_version(&self) -> u64 { self.global.qry("version").await.deserialized().unwrap_or(0) } diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 12712e793..decf32f7f 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -166,11 +166,7 @@ pub async fn get_latest_backup(&self, user_id: &UserId) -> Result<(String, Raw Result> { let key = (user_id, version); - self.db - .backupid_algorithm - .qry(&key) - .await - .deserialized_json() + self.db.backupid_algorithm.qry(&key).await.deserialized() } 
#[implement(Service)] @@ -278,11 +274,7 @@ pub async fn get_session( ) -> Result> { let key = (user_id, version, room_id, session_id); - self.db - .backupkeyid_backup - .qry(&key) - .await - .deserialized_json() + self.db.backupkeyid_backup.qry(&key).await.deserialized() } #[implement(Service)] diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 44ff1945c..8d8b553fe 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -90,7 +90,7 @@ impl Service { .senderkey_pusher .qry(&senderkey) .await - .deserialized_json() + .deserialized() } pub async fn get_pushers(&self, sender: &UserId) -> Vec { diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index 277b59826..4c9225ae8 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -33,7 +33,7 @@ pub async fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result Result { .eventid_outlierpdu .qry(event_id) .await - .deserialized_json() + .deserialized() } /// Append the PDU as an outlier. diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 38e504f6b..f3ccaf102 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -156,10 +156,7 @@ impl Data { &self, user_id: &UserId, room_id: &RoomId, ) -> Result>> { let key = (user_id, room_id); - self.userroomid_invitestate - .qry(&key) - .await - .deserialized_json() + self.userroomid_invitestate.qry(&key).await.deserialized() } #[tracing::instrument(skip(self), level = "debug")] @@ -167,10 +164,7 @@ impl Data { &self, user_id: &UserId, room_id: &RoomId, ) -> Result>> { let key = (user_id, room_id); - self.userroomid_leftstate - .qry(&key) - .await - .deserialized_json() + self.userroomid_leftstate.qry(&key).await.deserialized() } /// Returns an iterator over all rooms a user left. 
diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index cd746be43..314dcb9fd 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -90,17 +90,14 @@ impl Data { return Ok(pdu); } - self.eventid_outlierpdu - .qry(event_id) - .await - .deserialized_json() + self.eventid_outlierpdu.qry(event_id).await.deserialized() } /// Returns the json of a pdu. pub(super) async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result { let pduid = self.get_pdu_id(event_id).await?; - self.pduid_pdu.qry(&pduid).await.deserialized_json() + self.pduid_pdu.qry(&pduid).await.deserialized() } /// Returns the pdu's id. @@ -113,7 +110,7 @@ impl Data { pub(super) async fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result { let pduid = self.get_pdu_id(event_id).await?; - self.pduid_pdu.qry(&pduid).await.deserialized_json() + self.pduid_pdu.qry(&pduid).await.deserialized() } /// Like get_non_outlier_pdu(), but without the expense of fetching and @@ -137,7 +134,7 @@ impl Data { self.eventid_outlierpdu .qry(event_id) .await - .deserialized_json() + .deserialized() .map(Arc::new) } @@ -162,12 +159,12 @@ impl Data { /// /// This does __NOT__ check the outliers `Tree`. pub(super) async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result { - self.pduid_pdu.qry(pdu_id).await.deserialized_json() + self.pduid_pdu.qry(pdu_id).await.deserialized() } /// Returns the pdu as a `BTreeMap`. 
pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result { - self.pduid_pdu.qry(pdu_id).await.deserialized_json() + self.pduid_pdu.qry(pdu_id).await.deserialized() } pub(super) async fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) { diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 7e2315142..0415bfc23 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -238,6 +238,6 @@ async fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session .userdevicesessionid_uiaainfo .qry(&key) .await - .deserialized_json() + .deserialized() .map_err(|_| err!(Request(Forbidden("UIAA session does not exist.")))) } diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 9a058ba9d..ca37ed9dc 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -577,7 +577,7 @@ impl Service { .qry(&key) .await .map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key."))))? - .deserialized_json() + .deserialized() .map_err(|e| err!(Database("key in keyid_key is invalid. 
{e:?}")))?; let signatures = cross_signing_key @@ -652,7 +652,7 @@ impl Service { pub async fn get_device_keys<'a>(&'a self, user_id: &'a UserId, device_id: &DeviceId) -> Result> { let key_id = (user_id, device_id); - self.db.keyid_key.qry(&key_id).await.deserialized_json() + self.db.keyid_key.qry(&key_id).await.deserialized() } pub async fn get_key( @@ -666,7 +666,7 @@ impl Service { .keyid_key .qry(key_id) .await - .deserialized_json::()?; + .deserialized::()?; let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?; let raw_value = serde_json::value::to_raw_value(&cleaned)?; @@ -700,7 +700,7 @@ impl Service { pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result> { let key_id = self.db.userid_usersigningkeyid.qry(user_id).await?; - self.db.keyid_key.qry(&*key_id).await.deserialized_json() + self.db.keyid_key.qry(&*key_id).await.deserialized() } pub async fn add_to_device_event( @@ -791,7 +791,7 @@ impl Service { .userdeviceid_metadata .qry(&(user_id, device_id)) .await - .deserialized_json() + .deserialized() } pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result { @@ -830,7 +830,7 @@ impl Service { .userfilterid_filter .qry(&(user_id, filter_id)) .await - .deserialized_json() + .deserialized() } /// Creates an OpenID token, which can be used to prove that a user has From 3f7ec4221d89767e5bf0ff3e2a64c847a8dce264 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 25 Sep 2024 03:52:28 +0000 Subject: [PATCH 014/245] minor auth_chain optimizations/cleanup Signed-off-by: Jason Volk --- src/admin/debug/commands.rs | 53 +++++++++--------- src/api/server/event_auth.rs | 4 +- src/api/server/send_join.rs | 7 +-- src/api/server/state.rs | 4 +- src/api/server/state_ids.rs | 4 +- src/service/rooms/auth_chain/data.rs | 75 ++++++++++++++------------ src/service/rooms/auth_chain/mod.rs | 58 +++++++++----------- src/service/rooms/event_handler/mod.rs | 44 ++++++++------- 8 files changed, 128 insertions(+), 121 
deletions(-) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 65c9bc712..350e08c6a 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -27,33 +27,32 @@ pub(super) async fn echo(&self, message: Vec) -> Result) -> Result { - let event_id = Arc::::from(event_id); - if let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await { - let room_id_str = event - .get("room_id") - .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database"))?; - - let room_id = <&RoomId>::try_from(room_id_str) - .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - - let start = Instant::now(); - let count = self - .services - .rooms - .auth_chain - .event_ids_iter(room_id, vec![event_id]) - .await? - .count() - .await; - - let elapsed = start.elapsed(); - Ok(RoomMessageEventContent::text_plain(format!( - "Loaded auth chain with length {count} in {elapsed:?}" - ))) - } else { - Ok(RoomMessageEventContent::text_plain("Event not found.")) - } + let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await else { + return Ok(RoomMessageEventContent::notice_plain("Event not found.")); + }; + + let room_id_str = event + .get("room_id") + .and_then(|val| val.as_str()) + .ok_or_else(|| Error::bad_database("Invalid event in database"))?; + + let room_id = <&RoomId>::try_from(room_id_str) + .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; + + let start = Instant::now(); + let count = self + .services + .rooms + .auth_chain + .event_ids_iter(room_id, &[&event_id]) + .await? 
+ .count() + .await; + + let elapsed = start.elapsed(); + Ok(RoomMessageEventContent::text_plain(format!( + "Loaded auth chain with length {count} in {elapsed:?}" + ))) } #[admin_command] diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index 6ec00b501..8307a4ad3 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::borrow::Borrow; use axum::extract::State; use conduit::{Error, Result}; @@ -57,7 +57,7 @@ pub(crate) async fn get_event_authorization_route( let auth_chain = services .rooms .auth_chain - .event_ids_iter(room_id, vec![Arc::from(&*body.event_id)]) + .event_ids_iter(room_id, &[body.event_id.borrow()]) .await? .filter_map(|id| async move { services.rooms.timeline.get_pdu_json(&id).await.ok() }) .then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index 639fcafd0..f92576904 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -1,6 +1,6 @@ #![allow(deprecated)] -use std::collections::BTreeMap; +use std::{borrow::Borrow, collections::BTreeMap}; use axum::extract::State; use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result}; @@ -11,7 +11,7 @@ use ruma::{ room::member::{MembershipState, RoomMemberEventContent}, StateEventType, }, - CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName, + CanonicalJsonValue, EventId, OwnedServerName, OwnedUserId, RoomId, ServerName, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use service::Services; @@ -196,10 +196,11 @@ async fn create_join_event( .try_collect() .await?; + let starting_events: Vec<&EventId> = state_ids.values().map(Borrow::borrow).collect(); let auth_chain = services .rooms .auth_chain - .event_ids_iter(room_id, state_ids.values().cloned().collect()) + .event_ids_iter(room_id, &starting_events) .await? 
.map(Ok) .and_then(|event_id| async move { services.rooms.timeline.get_pdu_json(&event_id).await }) diff --git a/src/api/server/state.rs b/src/api/server/state.rs index 37a14a3f3..3a27cd0a3 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::borrow::Borrow; use axum::extract::State; use conduit::{err, result::LogErr, utils::IterStream, Err, Result}; @@ -63,7 +63,7 @@ pub(crate) async fn get_room_state_route( let auth_chain = services .rooms .auth_chain - .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) + .event_ids_iter(&body.room_id, &[body.event_id.borrow()]) .await? .map(Ok) .and_then(|id| async move { services.rooms.timeline.get_pdu_json(&id).await }) diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index 95ca65aa7..b026abf1d 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::borrow::Borrow; use axum::extract::State; use conduit::{err, Err}; @@ -55,7 +55,7 @@ pub(crate) async fn get_room_state_ids_route( let auth_chain_ids = services .rooms .auth_chain - .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) + .event_ids_iter(&body.room_id, &[body.event_id.borrow()]) .await? 
.map(|id| (*id).to_owned()) .collect() diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index 3d00374e7..5c9dbda83 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -3,7 +3,7 @@ use std::{ sync::{Arc, Mutex}, }; -use conduit::{utils, utils::math::usize_from_f64, Result}; +use conduit::{err, utils, utils::math::usize_from_f64, Err, Result}; use database::Map; use lru_cache::LruCache; @@ -24,54 +24,63 @@ impl Data { } } - pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { + pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result> { + debug_assert!(!key.is_empty(), "auth_chain key must not be empty"); + // Check RAM cache - if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { - return Ok(Some(Arc::clone(result))); + if let Some(result) = self + .auth_chain_cache + .lock() + .expect("cache locked") + .get_mut(key) + { + return Ok(Arc::clone(result)); } // We only save auth chains for single events in the db - if key.len() == 1 { - // Check DB cache - let chain = self.shorteventid_authchain.qry(&key[0]).await.map(|chain| { - chain - .chunks_exact(size_of::()) - .map(utils::u64_from_u8) - .collect::>() - }); + if key.len() != 1 { + return Err!(Request(NotFound("auth_chain not cached"))); + } - if let Ok(chain) = chain { - // Cache in RAM - self.auth_chain_cache - .lock() - .expect("locked") - .insert(vec![key[0]], Arc::clone(&chain)); + // Check database + let chain = self + .shorteventid_authchain + .qry(&key[0]) + .await + .map_err(|_| err!(Request(NotFound("auth_chain not found"))))?; - return Ok(Some(chain)); - } - } + let chain = chain + .chunks_exact(size_of::()) + .map(utils::u64_from_u8) + .collect::>(); + + // Cache in RAM + self.auth_chain_cache + .lock() + .expect("cache locked") + .insert(vec![key[0]], Arc::clone(&chain)); - Ok(None) + Ok(chain) } - pub(super) fn cache_auth_chain(&self, key: Vec, 
auth_chain: Arc<[u64]>) -> Result<()> { + pub(super) fn cache_auth_chain(&self, key: Vec, auth_chain: Arc<[u64]>) { + debug_assert!(!key.is_empty(), "auth_chain key must not be empty"); + // Only persist single events in db if key.len() == 1 { - self.shorteventid_authchain.insert( - &key[0].to_be_bytes(), - &auth_chain - .iter() - .flat_map(|s| s.to_be_bytes().to_vec()) - .collect::>(), - ); + let key = key[0].to_be_bytes(); + let val = auth_chain + .iter() + .flat_map(|s| s.to_be_bytes().to_vec()) + .collect::>(); + + self.shorteventid_authchain.insert(&key, &val); } // Cache in RAM self.auth_chain_cache .lock() - .expect("locked") + .expect("cache locked") .insert(key, auth_chain); - - Ok(()) } } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 7bc239d7b..eae13b74a 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -37,25 +37,18 @@ impl crate::Service for Service { } impl Service { - pub async fn event_ids_iter<'a>( - &'a self, room_id: &RoomId, starting_events_: Vec>, - ) -> Result> + Send + 'a> { - let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len()); - for starting_event in &starting_events_ { - starting_events.push(starting_event); - } - - Ok(self - .get_auth_chain(room_id, &starting_events) - .await? 
- .into_iter() - .stream() - .filter_map(|sid| { - self.services - .short - .get_eventid_from_short(sid) - .map(Result::ok) - })) + pub async fn event_ids_iter( + &self, room_id: &RoomId, starting_events: &[&EventId], + ) -> Result> + Send + '_> { + let chain = self.get_auth_chain(room_id, starting_events).await?; + let iter = chain.into_iter().stream().filter_map(|sid| { + self.services + .short + .get_eventid_from_short(sid) + .map(Result::ok) + }); + + Ok(iter) } #[tracing::instrument(skip_all, name = "auth_chain")] @@ -93,7 +86,7 @@ impl Service { } let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); - if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key).await? { + if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await { trace!("Found cache entry for whole chunk"); full_auth_chain.extend(cached.iter().copied()); hits = hits.saturating_add(1); @@ -104,13 +97,13 @@ impl Service { let mut misses2: usize = 0; let mut chunk_cache = Vec::with_capacity(chunk.len()); for (sevent_id, event_id) in chunk { - if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await? 
{ + if let Ok(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await { trace!(?event_id, "Found cache entry for event"); chunk_cache.extend(cached.iter().copied()); hits2 = hits2.saturating_add(1); } else { let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?; - self.cache_auth_chain(vec![sevent_id], &auth_chain)?; + self.cache_auth_chain(vec![sevent_id], &auth_chain); chunk_cache.extend(auth_chain.iter()); misses2 = misses2.saturating_add(1); debug!( @@ -125,7 +118,7 @@ impl Service { chunk_cache.sort_unstable(); chunk_cache.dedup(); - self.cache_auth_chain_vec(chunk_key, &chunk_cache)?; + self.cache_auth_chain_vec(chunk_key, &chunk_cache); full_auth_chain.extend(chunk_cache.iter()); misses = misses.saturating_add(1); debug!( @@ -163,11 +156,11 @@ impl Service { Ok(pdu) => { if pdu.room_id != room_id { return Err!(Request(Forbidden( - "auth event {event_id:?} for incorrect room {} which is not {}", + "auth event {event_id:?} for incorrect room {} which is not {room_id}", pdu.room_id, - room_id ))); } + for auth_event in &pdu.auth_events { let sauthevent = self .services @@ -187,20 +180,21 @@ impl Service { Ok(found) } - pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { + #[inline] + pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result> { self.db.get_cached_eventid_authchain(key).await } #[tracing::instrument(skip(self), level = "debug")] - pub fn cache_auth_chain(&self, key: Vec, auth_chain: &HashSet) -> Result<()> { - self.db - .cache_auth_chain(key, auth_chain.iter().copied().collect::>()) + pub fn cache_auth_chain(&self, key: Vec, auth_chain: &HashSet) { + let val = auth_chain.iter().copied().collect::>(); + self.db.cache_auth_chain(key, val); } #[tracing::instrument(skip(self), level = "debug")] - pub fn cache_auth_chain_vec(&self, key: Vec, auth_chain: &Vec) -> Result<()> { - self.db - .cache_auth_chain(key, auth_chain.iter().copied().collect::>()) + pub fn cache_auth_chain_vec(&self, key: 
Vec, auth_chain: &Vec) { + let val = auth_chain.iter().copied().collect::>(); + self.db.cache_auth_chain(key, val); } pub fn get_cache_usage(&self) -> (usize, usize) { diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 07d6e4db9..57b877064 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1,6 +1,7 @@ mod parse_incoming_pdu; use std::{ + borrow::Borrow, collections::{hash_map, BTreeMap, HashMap, HashSet}, fmt::Write, sync::{Arc, RwLock as StdRwLock}, @@ -773,6 +774,7 @@ impl Service { Ok(pdu_id) } + #[tracing::instrument(skip_all, name = "resolve")] pub async fn resolve_state( &self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap>, ) -> Result>> { @@ -793,14 +795,17 @@ impl Service { let fork_states = [current_state_ids, incoming_state]; let mut auth_chain_sets = Vec::with_capacity(fork_states.len()); for state in &fork_states { - auth_chain_sets.push( - self.services - .auth_chain - .event_ids_iter(room_id, state.iter().map(|(_, id)| id.clone()).collect()) - .await? - .collect::>>() - .await, - ); + let starting_events: Vec<&EventId> = state.values().map(Borrow::borrow).collect(); + + let auth_chain = self + .services + .auth_chain + .event_ids_iter(room_id, &starting_events) + .await? 
+ .collect::>>() + .await; + + auth_chain_sets.push(auth_chain); } debug!("Loading fork states"); @@ -962,12 +967,11 @@ impl Service { let mut state = StateMap::with_capacity(leaf_state.len()); let mut starting_events = Vec::with_capacity(leaf_state.len()); - - for (k, id) in leaf_state { + for (k, id) in &leaf_state { if let Ok((ty, st_key)) = self .services .short - .get_statekey_from_short(k) + .get_statekey_from_short(*k) .await .log_err() { @@ -976,18 +980,18 @@ impl Service { state.insert((ty.to_string().into(), st_key), id.clone()); } - starting_events.push(id); + starting_events.push(id.borrow()); } - auth_chain_sets.push( - self.services - .auth_chain - .event_ids_iter(room_id, starting_events) - .await? - .collect() - .await, - ); + let auth_chain = self + .services + .auth_chain + .event_ids_iter(room_id, &starting_events) + .await? + .collect() + .await; + auth_chain_sets.push(auth_chain); fork_states.push(state); } From 4496cf2d5b08780fd2d2b32c31c2c0b38bf010e7 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 21 Sep 2024 16:28:46 -0400 Subject: [PATCH 015/245] add missing await to first admin room creation Signed-off-by: strawberry --- src/service/admin/create.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/service/admin/create.rs b/src/service/admin/create.rs index 7b090aa0b..3dd5aea35 100644 --- a/src/service/admin/create.rs +++ b/src/service/admin/create.rs @@ -30,7 +30,11 @@ use crate::Services; pub async fn create_admin_room(services: &Services) -> Result<()> { let room_id = RoomId::new(services.globals.server_name()); - let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id); + let _short_id = services + .rooms + .short + .get_or_create_shortroomid(&room_id) + .await; let state_lock = services.rooms.state.mutex.lock(&room_id).await; From 5192927a5342cffd9a7284bad3eb2c4b4819c674 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 29 Sep 2024 07:37:43 +0000 Subject: [PATCH 016/245] split remaining 
map suites Signed-off-by: Jason Volk --- src/core/utils/mod.rs | 10 -- src/database/map.rs | 163 ++--------------------------- src/database/map/get.rs | 82 +++++++++++++++ src/database/map/insert.rs | 52 +++++++++ src/database/map/remove.rs | 44 ++++++++ src/service/globals/data.rs | 28 +++-- src/service/rooms/short/data.rs | 2 +- src/service/rooms/timeline/data.rs | 2 +- src/service/users/mod.rs | 2 +- 9 files changed, 205 insertions(+), 180 deletions(-) create mode 100644 src/database/map/get.rs create mode 100644 src/database/map/insert.rs create mode 100644 src/database/map/remove.rs diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index b1ea3709d..fef833954 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -35,13 +35,3 @@ pub use self::{ #[inline] pub fn exchange(state: &mut T, source: T) -> T { std::mem::replace(state, source) } - -#[must_use] -pub fn generate_keypair() -> Vec { - let mut value = rand::string(8).as_bytes().to_vec(); - value.push(0xFF); - value.extend_from_slice( - &ruma::signatures::Ed25519KeyPair::generate().expect("Ed25519KeyPair generation always works (?)"), - ); - value -} diff --git a/src/database/map.rs b/src/database/map.rs index a3cf32d4e..cac20d6a6 100644 --- a/src/database/map.rs +++ b/src/database/map.rs @@ -1,7 +1,10 @@ mod count; +mod get; +mod insert; mod keys; mod keys_from; mod keys_prefix; +mod remove; mod rev_keys; mod rev_keys_from; mod rev_keys_prefix; @@ -18,23 +21,14 @@ use std::{ fmt, fmt::{Debug, Display}, future::Future, - io::Write, pin::Pin, sync::Arc, }; -use conduit::{err, Result}; -use futures::future; -use rocksdb::{AsColumnFamilyRef, ColumnFamily, ReadOptions, WriteBatchWithTransaction, WriteOptions}; -use serde::Serialize; +use conduit::Result; +use rocksdb::{AsColumnFamilyRef, ColumnFamily, ReadOptions, WriteOptions}; -use crate::{ - keyval::{OwnedKey, OwnedVal}, - ser, - util::{map_err, or_else}, - watchers::Watchers, - Engine, Handle, -}; +use crate::{watchers::Watchers, Engine}; 
pub struct Map { name: String, @@ -57,146 +51,6 @@ impl Map { })) } - #[tracing::instrument(skip(self), fields(%self), level = "trace")] - pub fn del(&self, key: &K) - where - K: Serialize + ?Sized + Debug, - { - let mut buf = Vec::::with_capacity(64); - self.bdel(key, &mut buf); - } - - #[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] - pub fn bdel(&self, key: &K, buf: &mut B) - where - K: Serialize + ?Sized + Debug, - B: Write + AsRef<[u8]>, - { - let key = ser::serialize(buf, key).expect("failed to serialize deletion key"); - self.remove(&key); - } - - #[tracing::instrument(level = "trace")] - pub fn remove(&self, key: &K) - where - K: AsRef<[u8]> + ?Sized + Debug, - { - let write_options = &self.write_options; - self.db - .db - .delete_cf_opt(&self.cf(), key, write_options) - .or_else(or_else) - .expect("database remove error"); - - if !self.db.corked() { - self.db.flush().expect("database flush error"); - } - } - - #[tracing::instrument(skip(self, value), fields(%self), level = "trace")] - pub fn insert(&self, key: &K, value: &V) - where - K: AsRef<[u8]> + ?Sized + Debug, - V: AsRef<[u8]> + ?Sized, - { - let write_options = &self.write_options; - self.db - .db - .put_cf_opt(&self.cf(), key, value, write_options) - .or_else(or_else) - .expect("database insert error"); - - if !self.db.corked() { - self.db.flush().expect("database flush error"); - } - - self.watchers.wake(key.as_ref()); - } - - #[tracing::instrument(skip(self), fields(%self), level = "trace")] - pub fn insert_batch<'a, I, K, V>(&'a self, iter: I) - where - I: Iterator + Send + Debug, - K: AsRef<[u8]> + Sized + Debug + 'a, - V: AsRef<[u8]> + Sized + 'a, - { - let mut batch = WriteBatchWithTransaction::::default(); - for (key, val) in iter { - batch.put_cf(&self.cf(), key.as_ref(), val.as_ref()); - } - - let write_options = &self.write_options; - self.db - .db - .write_opt(batch, write_options) - .or_else(or_else) - .expect("database insert batch error"); - - if 
!self.db.corked() { - self.db.flush().expect("database flush error"); - } - } - - #[tracing::instrument(skip(self), fields(%self), level = "trace")] - pub fn qry(&self, key: &K) -> impl Future>> + Send - where - K: Serialize + ?Sized + Debug, - { - let mut buf = Vec::::with_capacity(64); - self.bqry(key, &mut buf) - } - - #[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] - pub fn bqry(&self, key: &K, buf: &mut B) -> impl Future>> + Send - where - K: Serialize + ?Sized + Debug, - B: Write + AsRef<[u8]>, - { - let key = ser::serialize(buf, key).expect("failed to serialize query key"); - let val = self.get(key); - future::ready(val) - } - - #[tracing::instrument(skip(self), fields(%self), level = "trace")] - pub fn get(&self, key: &K) -> Result> - where - K: AsRef<[u8]> + ?Sized + Debug, - { - self.db - .db - .get_pinned_cf_opt(&self.cf(), key, &self.read_options) - .map_err(map_err)? - .map(Handle::from) - .ok_or(err!(Request(NotFound("Not found in database")))) - } - - #[tracing::instrument(skip(self), fields(%self), level = "trace")] - pub fn multi_get<'a, I, K>(&self, keys: I) -> Vec> - where - I: Iterator + ExactSizeIterator + Send + Debug, - K: AsRef<[u8]> + Sized + Debug + 'a, - { - // Optimization can be `true` if key vector is pre-sorted **by the column - // comparator**. 
- const SORTED: bool = false; - - let mut ret: Vec> = Vec::with_capacity(keys.len()); - let read_options = &self.read_options; - for res in self - .db - .db - .batched_multi_get_cf_opt(&self.cf(), keys, SORTED, read_options) - { - match res { - Ok(Some(res)) => ret.push(Some((*res).to_vec())), - Ok(None) => ret.push(None), - Err(e) => or_else(e).expect("database multiget error"), - } - } - - ret - } - #[inline] pub fn watch_prefix<'a, K>(&'a self, prefix: &K) -> Pin + Send + 'a>> where @@ -230,10 +84,7 @@ fn open(db: &Arc, name: &str) -> Result> { let bounded_ptr = Arc::into_raw(bounded_arc); let cf_ptr = bounded_ptr.cast::(); - // SAFETY: After thorough contemplation this appears to be the best solution, - // even by a significant margin. - // - // BACKGROUND: Column family handles out of RocksDB are basic pointers and can + // SAFETY: Column family handles out of RocksDB are basic pointers and can // be invalidated: 1. when the database closes. 2. when the column is dropped or // closed. 
rust_rocksdb wraps this for us by storing handles in their own // `RwLock` map and returning an Arc>` to diff --git a/src/database/map/get.rs b/src/database/map/get.rs new file mode 100644 index 000000000..b4d6a6ea8 --- /dev/null +++ b/src/database/map/get.rs @@ -0,0 +1,82 @@ +use std::{convert::AsRef, fmt::Debug, future::Future, io::Write}; + +use conduit::{err, implement, Result}; +use futures::future::ready; +use serde::Serialize; + +use crate::{ + keyval::{OwnedKey, OwnedVal}, + ser, + util::{map_err, or_else}, + Handle, +}; + +#[implement(super::Map)] +pub fn qry(&self, key: &K) -> impl Future>> + Send +where + K: Serialize + ?Sized + Debug, +{ + let mut buf = Vec::::with_capacity(64); + self.bqry(key, &mut buf) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] +pub fn bqry(&self, key: &K, buf: &mut B) -> impl Future>> + Send +where + K: Serialize + ?Sized + Debug, + B: Write + AsRef<[u8]>, +{ + let key = ser::serialize(buf, key).expect("failed to serialize query key"); + self.get(key) +} + +#[implement(super::Map)] +pub fn get(&self, key: &K) -> impl Future>> + Send +where + K: AsRef<[u8]> + ?Sized + Debug, +{ + ready(self.get_blocking(key)) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self, key), fields(%self), level = "trace")] +pub fn get_blocking(&self, key: &K) -> Result> +where + K: AsRef<[u8]> + ?Sized + Debug, +{ + self.db + .db + .get_pinned_cf_opt(&self.cf(), key, &self.read_options) + .map_err(map_err)? + .map(Handle::from) + .ok_or(err!(Request(NotFound("Not found in database")))) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self, keys), fields(%self), level = "trace")] +pub fn get_batch_blocking<'a, I, K>(&self, keys: I) -> Vec> +where + I: Iterator + ExactSizeIterator + Send + Debug, + K: AsRef<[u8]> + Sized + Debug + 'a, +{ + // Optimization can be `true` if key vector is pre-sorted **by the column + // comparator**. 
+ const SORTED: bool = false; + + let mut ret: Vec> = Vec::with_capacity(keys.len()); + let read_options = &self.read_options; + for res in self + .db + .db + .batched_multi_get_cf_opt(&self.cf(), keys, SORTED, read_options) + { + match res { + Ok(Some(res)) => ret.push(Some((*res).to_vec())), + Ok(None) => ret.push(None), + Err(e) => or_else(e).expect("database multiget error"), + } + } + + ret +} diff --git a/src/database/map/insert.rs b/src/database/map/insert.rs new file mode 100644 index 000000000..953c9c94c --- /dev/null +++ b/src/database/map/insert.rs @@ -0,0 +1,52 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::implement; +use rocksdb::WriteBatchWithTransaction; + +use crate::util::or_else; + +#[implement(super::Map)] +#[tracing::instrument(skip(self, value), fields(%self), level = "trace")] +pub fn insert(&self, key: &K, value: &V) +where + K: AsRef<[u8]> + ?Sized + Debug, + V: AsRef<[u8]> + ?Sized, +{ + let write_options = &self.write_options; + self.db + .db + .put_cf_opt(&self.cf(), key, value, write_options) + .or_else(or_else) + .expect("database insert error"); + + if !self.db.corked() { + self.db.flush().expect("database flush error"); + } + + self.watchers.wake(key.as_ref()); +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self, iter), fields(%self), level = "trace")] +pub fn insert_batch<'a, I, K, V>(&'a self, iter: I) +where + I: Iterator + Send + Debug, + K: AsRef<[u8]> + Sized + Debug + 'a, + V: AsRef<[u8]> + Sized + 'a, +{ + let mut batch = WriteBatchWithTransaction::::default(); + for (key, val) in iter { + batch.put_cf(&self.cf(), key.as_ref(), val.as_ref()); + } + + let write_options = &self.write_options; + self.db + .db + .write_opt(batch, write_options) + .or_else(or_else) + .expect("database insert batch error"); + + if !self.db.corked() { + self.db.flush().expect("database flush error"); + } +} diff --git a/src/database/map/remove.rs b/src/database/map/remove.rs new file mode 100644 index 000000000..fcf7587e0 --- 
/dev/null +++ b/src/database/map/remove.rs @@ -0,0 +1,44 @@ +use std::{convert::AsRef, fmt::Debug, io::Write}; + +use conduit::implement; +use serde::Serialize; + +use crate::{ser, util::or_else}; + +#[implement(super::Map)] +pub fn del(&self, key: &K) +where + K: Serialize + ?Sized + Debug, +{ + let mut buf = Vec::::with_capacity(64); + self.bdel(key, &mut buf); +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] +pub fn bdel(&self, key: &K, buf: &mut B) +where + K: Serialize + ?Sized + Debug, + B: Write + AsRef<[u8]>, +{ + let key = ser::serialize(buf, key).expect("failed to serialize deletion key"); + self.remove(key); +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self, key), fields(%self), level = "trace")] +pub fn remove(&self, key: &K) +where + K: AsRef<[u8]> + ?Sized + Debug, +{ + let write_options = &self.write_options; + self.db + .db + .delete_cf_opt(&self.cf(), key, write_options) + .or_else(or_else) + .expect("database remove error"); + + if !self.db.corked() { + self.db.flush().expect("database flush error"); + } +} diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 76f979441..5332f07d1 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -3,7 +3,7 @@ use std::{ sync::{Arc, RwLock}, }; -use conduit::{trace, utils, Error, Result, Server}; +use conduit::{trace, utils, utils::rand, Error, Result, Server}; use database::{Database, Deserialized, Map}; use futures::{pin_mut, stream::FuturesUnordered, FutureExt, StreamExt}; use ruma::{ @@ -102,7 +102,7 @@ impl Data { fn stored_count(global: &Arc) -> Result { global - .get(COUNTER) + .get_blocking(COUNTER) .as_deref() .map_or(Ok(0_u64), utils::u64_from_bytes) } @@ -206,17 +206,23 @@ impl Data { } pub fn load_keypair(&self) -> Result { - let keypair_bytes = self.global.get(b"keypair").map_or_else( - |_| { - let keypair = utils::generate_keypair(); - self.global.insert(b"keypair", &keypair); - 
Ok::<_, Error>(keypair) - }, - |val| Ok(val.to_vec()), - )?; + let generate = |_| { + let keypair = Ed25519KeyPair::generate().expect("Ed25519KeyPair generation always works (?)"); - let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF); + let mut value = rand::string(8).as_bytes().to_vec(); + value.push(0xFF); + value.extend_from_slice(&keypair); + + self.global.insert(b"keypair", &value); + value + }; + + let keypair_bytes: Vec = self + .global + .get_blocking(b"keypair") + .map_or_else(generate, Into::into); + let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF); utils::string_from_bytes( // 1. version parts diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs index f6a824883..fff3f2d62 100644 --- a/src/service/rooms/short/data.rs +++ b/src/service/rooms/short/data.rs @@ -59,7 +59,7 @@ impl Data { for (i, short) in self .eventid_shorteventid - .multi_get(keys.iter()) + .get_batch_blocking(keys.iter()) .iter() .enumerate() { diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 314dcb9fd..1f9dad1dc 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -326,7 +326,7 @@ pub(super) fn pdu_count(pdu_id: &[u8]) -> PduCount { //TODO: this is an ABA fn increment(db: &Arc, key: &[u8]) { - let old = db.get(key); + let old = db.get_blocking(key); let new = utils::increment(old.ok().as_deref()); db.insert(key, &new); } diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index ca37ed9dc..fa8c41b6b 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -1000,7 +1000,7 @@ where //TODO: this is an ABA fn increment(db: &Arc, key: &[u8]) { - let old = db.get(key); + let old = db.get_blocking(key); let new = utils::increment(old.ok().as_deref()); db.insert(key, &new); } From 0e8ae1e13e601c572336e38d1b020eb6a6aafe0d Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 29 Sep 2024 12:49:24 +0000 Subject: [PATCH 017/245] add ArrayVec-backed 
serialized query overload; doc comments Signed-off-by: Jason Volk --- Cargo.lock | 1 + src/database/Cargo.toml | 1 + src/database/de.rs | 2 +- src/database/handle.rs | 28 ++++++++++++++++------------ src/database/map/get.rs | 23 +++++++++++++++++++++++ src/database/map/remove.rs | 10 ++++++++++ 6 files changed, 52 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 08e0498aa..043d9704b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -726,6 +726,7 @@ dependencies = [ name = "conduit_database" version = "0.4.7" dependencies = [ + "arrayvec", "conduit_core", "const-str", "futures", diff --git a/src/database/Cargo.toml b/src/database/Cargo.toml index b5eb76126..0e718aa71 100644 --- a/src/database/Cargo.toml +++ b/src/database/Cargo.toml @@ -35,6 +35,7 @@ zstd_compression = [ ] [dependencies] +arrayvec.workspace = true conduit-core.workspace = true const-str.workspace = true futures.workspace = true diff --git a/src/database/de.rs b/src/database/de.rs index a5d2c1272..fc36560d6 100644 --- a/src/database/de.rs +++ b/src/database/de.rs @@ -195,7 +195,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } fn deserialize_u8>(self, _visitor: V) -> Result { - unimplemented!("deserialize u8 not implemented") + unimplemented!("deserialize u8 not implemented; try dereferencing the Handle for [u8] access instead") } fn deserialize_u16>(self, _visitor: V) -> Result { diff --git a/src/database/handle.rs b/src/database/handle.rs index 0d4bd02ea..daee224d4 100644 --- a/src/database/handle.rs +++ b/src/database/handle.rs @@ -35,18 +35,6 @@ impl Serialize for Handle<'_> { } } -impl Deref for Handle<'_> { - type Target = Slice; - - #[inline] - fn deref(&self) -> &Self::Target { &self.val } -} - -impl AsRef for Handle<'_> { - #[inline] - fn as_ref(&self) -> &Slice { &self.val } -} - impl Deserialized for Result> { #[inline] fn map_de(self, f: F) -> Result @@ -78,3 +66,19 @@ impl<'a> Deserialized for &'a Handle<'a> { 
deserialize_val(self.as_ref()).map(f) } } + +impl From> for Vec { + fn from(handle: Handle<'_>) -> Self { handle.deref().to_vec() } +} + +impl Deref for Handle<'_> { + type Target = Slice; + + #[inline] + fn deref(&self) -> &Self::Target { &self.val } +} + +impl AsRef for Handle<'_> { + #[inline] + fn as_ref(&self) -> &Slice { &self.val } +} diff --git a/src/database/map/get.rs b/src/database/map/get.rs index b4d6a6ea8..71489402c 100644 --- a/src/database/map/get.rs +++ b/src/database/map/get.rs @@ -1,5 +1,6 @@ use std::{convert::AsRef, fmt::Debug, future::Future, io::Write}; +use arrayvec::ArrayVec; use conduit::{err, implement, Result}; use futures::future::ready; use serde::Serialize; @@ -11,6 +12,9 @@ use crate::{ Handle, }; +/// Fetch a value from the database into cache, returning a reference-handle +/// asynchronously. The key is serialized into an allocated buffer to perform +/// the query. #[implement(super::Map)] pub fn qry(&self, key: &K) -> impl Future>> + Send where @@ -20,6 +24,20 @@ where self.bqry(key, &mut buf) } +/// Fetch a value from the database into cache, returning a reference-handle +/// asynchronously. The key is serialized into a fixed-sized buffer to perform +/// the query. The maximum size is supplied as const generic parameter. +#[implement(super::Map)] +pub fn aqry(&self, key: &K) -> impl Future>> + Send +where + K: Serialize + ?Sized + Debug, +{ + let mut buf = ArrayVec::::new(); + self.bqry(key, &mut buf) +} + +/// Fetch a value from the database into cache, returning a reference-handle +/// asynchronously. The key is serialized into a user-supplied Writer. #[implement(super::Map)] #[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] pub fn bqry(&self, key: &K, buf: &mut B) -> impl Future>> + Send @@ -31,6 +49,8 @@ where self.get(key) } +/// Fetch a value from the database into cache, returning a reference-handle +/// asynchronously. The key is referenced directly to perform the query. 
#[implement(super::Map)] pub fn get(&self, key: &K) -> impl Future>> + Send where @@ -39,6 +59,9 @@ where ready(self.get_blocking(key)) } +/// Fetch a value from the database into cache, returning a reference-handle. +/// The key is referenced directly to perform the query. This is a thread- +/// blocking call. #[implement(super::Map)] #[tracing::instrument(skip(self, key), fields(%self), level = "trace")] pub fn get_blocking(&self, key: &K) -> Result> diff --git a/src/database/map/remove.rs b/src/database/map/remove.rs index fcf7587e0..10bb2ff01 100644 --- a/src/database/map/remove.rs +++ b/src/database/map/remove.rs @@ -1,5 +1,6 @@ use std::{convert::AsRef, fmt::Debug, io::Write}; +use arrayvec::ArrayVec; use conduit::implement; use serde::Serialize; @@ -14,6 +15,15 @@ where self.bdel(key, &mut buf); } +#[implement(super::Map)] +pub fn adel(&self, key: &K) +where + K: Serialize + ?Sized + Debug, +{ + let mut buf = ArrayVec::::new(); + self.bdel(key, &mut buf); +} + #[implement(super::Map)] #[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] pub fn bdel(&self, key: &K, buf: &mut B) From c569881b0853245dea0f8704342d6cfa6c465edb Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 29 Sep 2024 13:13:09 +0000 Subject: [PATCH 018/245] merge rooms/short Data w/ Service; optimize queries Signed-off-by: Jason Volk --- src/service/account_data/mod.rs | 2 +- src/service/appservice/data.rs | 2 +- src/service/globals/data.rs | 12 +- src/service/globals/migrations.rs | 6 +- src/service/media/data.rs | 2 +- src/service/presence/data.rs | 6 +- src/service/rooms/alias/mod.rs | 6 +- src/service/rooms/directory/mod.rs | 2 +- src/service/rooms/outlier/mod.rs | 4 +- src/service/rooms/short/data.rs | 167 ---------------- src/service/rooms/short/mod.rs | 214 ++++++++++++++++++--- src/service/rooms/state/data.rs | 2 +- src/service/rooms/state_accessor/data.rs | 4 +- src/service/rooms/state_cache/mod.rs | 8 +- src/service/rooms/state_compressor/data.rs | 4 +- 
src/service/rooms/timeline/data.rs | 22 +-- src/service/sending/data.rs | 6 +- src/service/users/mod.rs | 32 +-- 18 files changed, 248 insertions(+), 253 deletions(-) delete mode 100644 src/service/rooms/short/data.rs diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index 4f00cff1c..482229e7f 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -106,7 +106,7 @@ pub async fn get( self.db .roomusertype_roomuserdataid .qry(&key) - .and_then(|roomuserdataid| self.db.roomuserdataid_accountdata.qry(&roomuserdataid)) + .and_then(|roomuserdataid| self.db.roomuserdataid_accountdata.get(&roomuserdataid)) .await .deserialized() } diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index f31c5e636..4eb9d09e5 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -38,7 +38,7 @@ impl Data { pub async fn get_registration(&self, id: &str) -> Result { self.id_appserviceregistrations - .qry(id) + .get(id) .await .deserialized() .map_err(|e| err!(Database("Invalid appservice {id:?} registration: {e:?}"))) diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 5332f07d1..57a295d99 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -260,7 +260,7 @@ impl Data { &self, origin: &ServerName, new_keys: ServerSigningKeys, ) -> BTreeMap { // Not atomic, but this is not critical - let signingkeys = self.server_signingkeys.qry(origin).await; + let signingkeys = self.server_signingkeys.get(origin).await; let mut keys = signingkeys .and_then(|keys| serde_json::from_slice(&keys).map_err(Into::into)) @@ -311,10 +311,16 @@ impl Data { } pub async fn signing_keys_for(&self, origin: &ServerName) -> Result { - self.server_signingkeys.qry(origin).await.deserialized() + self.server_signingkeys.get(origin).await.deserialized() } - pub async fn database_version(&self) -> u64 { self.global.qry("version").await.deserialized().unwrap_or(0) } + 
pub async fn database_version(&self) -> u64 { + self.global + .get(b"version") + .await + .deserialized() + .unwrap_or(0) + } #[inline] pub fn bump_database_version(&self, new_version: u64) -> Result<()> { diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs index c7a732309..469159fc7 100644 --- a/src/service/globals/migrations.rs +++ b/src/service/globals/migrations.rs @@ -99,14 +99,14 @@ async fn migrate(services: &Services) -> Result<()> { db_lt_13(services).await?; } - if db["global"].qry("feat_sha256_media").await.is_not_found() { + if db["global"].get(b"feat_sha256_media").await.is_not_found() { media::migrations::migrate_sha256_media(services).await?; } else if config.media_startup_check { media::migrations::checkup_sha256_media(services).await?; } if db["global"] - .qry("fix_bad_double_separator_in_state_cache") + .get(b"fix_bad_double_separator_in_state_cache") .await .is_not_found() { @@ -114,7 +114,7 @@ async fn migrate(services: &Services) -> Result<()> { } if db["global"] - .qry("retroactively_fix_bad_data_from_roomuserid_joined") + .get(b"retroactively_fix_bad_data_from_roomuserid_joined") .await .is_not_found() { diff --git a/src/service/media/data.rs b/src/service/media/data.rs index 29d562cc3..248e9e1d2 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -253,7 +253,7 @@ impl Data { } pub(super) async fn get_url_preview(&self, url: &str) -> Result { - let values = self.url_previews.qry(url).await?; + let values = self.url_previews.get(url).await?; let mut values = values.split(|&b| b == 0xFF); diff --git a/src/service/presence/data.rs b/src/service/presence/data.rs index 0c3f3d31d..9c9d0ae3f 100644 --- a/src/service/presence/data.rs +++ b/src/service/presence/data.rs @@ -39,12 +39,12 @@ impl Data { pub async fn get_presence(&self, user_id: &UserId) -> Result<(u64, PresenceEvent)> { let count = self .userid_presenceid - .qry(user_id) + .get(user_id) .await .deserialized::()?; let key = 
presenceid_key(count, user_id); - let bytes = self.presenceid_presence.qry(&key).await?; + let bytes = self.presenceid_presence.get(&key).await?; let event = Presence::from_json_bytes(&bytes)? .to_presence_event(user_id, &self.services.users) .await; @@ -127,7 +127,7 @@ impl Data { pub(super) async fn remove_presence(&self, user_id: &UserId) { let Ok(count) = self .userid_presenceid - .qry(user_id) + .get(user_id) .await .deserialized::() else { diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 6b81a221a..1d44cd2d8 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -94,7 +94,7 @@ impl Service { } let alias = alias.alias(); - let Ok(room_id) = self.db.alias_roomid.qry(&alias).await else { + let Ok(room_id) = self.db.alias_roomid.get(&alias).await else { return Err!(Request(NotFound("Alias does not exist or is invalid."))); }; @@ -151,7 +151,7 @@ impl Service { #[tracing::instrument(skip(self), level = "debug")] pub async fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result { - self.db.alias_roomid.qry(alias.alias()).await.deserialized() + self.db.alias_roomid.get(alias.alias()).await.deserialized() } #[tracing::instrument(skip(self), level = "debug")] @@ -219,7 +219,7 @@ impl Service { } async fn who_created_alias(&self, alias: &RoomAliasId) -> Result { - self.db.alias_userid.qry(alias.alias()).await.deserialized() + self.db.alias_userid.get(alias.alias()).await.deserialized() } async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result> { diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index 3585205d3..5666a91a7 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -32,7 +32,7 @@ pub fn set_public(&self, room_id: &RoomId) { self.db.publicroomids.insert(room_i pub fn set_not_public(&self, room_id: &RoomId) { self.db.publicroomids.remove(room_id.as_bytes()); } #[implement(Service)] -pub async fn 
is_public_room(&self, room_id: &RoomId) -> bool { self.db.publicroomids.qry(room_id).await.is_ok() } +pub async fn is_public_room(&self, room_id: &RoomId) -> bool { self.db.publicroomids.get(room_id).await.is_ok() } #[implement(Service)] pub fn public_rooms(&self) -> impl Stream + Send { diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index 4c9225ae8..b9d042638 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -31,7 +31,7 @@ impl crate::Service for Service { pub async fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result { self.db .eventid_outlierpdu - .qry(event_id) + .get(event_id) .await .deserialized() } @@ -41,7 +41,7 @@ pub async fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result Result { self.db .eventid_outlierpdu - .qry(event_id) + .get(event_id) .await .deserialized() } diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs deleted file mode 100644 index fff3f2d62..000000000 --- a/src/service/rooms/short/data.rs +++ /dev/null @@ -1,167 +0,0 @@ -use std::sync::Arc; - -use conduit::{err, utils, Error, Result}; -use database::{Deserialized, Map}; -use ruma::{events::StateEventType, EventId, RoomId}; - -use crate::{globals, Dep}; - -pub(super) struct Data { - eventid_shorteventid: Arc, - shorteventid_eventid: Arc, - statekey_shortstatekey: Arc, - shortstatekey_statekey: Arc, - roomid_shortroomid: Arc, - statehash_shortstatehash: Arc, - services: Services, -} - -struct Services { - globals: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - eventid_shorteventid: db["eventid_shorteventid"].clone(), - shorteventid_eventid: db["shorteventid_eventid"].clone(), - statekey_shortstatekey: db["statekey_shortstatekey"].clone(), - shortstatekey_statekey: db["shortstatekey_statekey"].clone(), - roomid_shortroomid: db["roomid_shortroomid"].clone(), - statehash_shortstatehash: 
db["statehash_shortstatehash"].clone(), - services: Services { - globals: args.depend::("globals"), - }, - } - } - - pub(super) async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 { - if let Ok(shorteventid) = self.eventid_shorteventid.qry(event_id).await.deserialized() { - return shorteventid; - } - - let shorteventid = self.services.globals.next_count().unwrap(); - self.eventid_shorteventid - .insert(event_id.as_bytes(), &shorteventid.to_be_bytes()); - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), event_id.as_bytes()); - - shorteventid - } - - pub(super) async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec { - let mut ret: Vec = Vec::with_capacity(event_ids.len()); - let keys = event_ids - .iter() - .map(|id| id.as_bytes()) - .collect::>(); - - for (i, short) in self - .eventid_shorteventid - .get_batch_blocking(keys.iter()) - .iter() - .enumerate() - { - #[allow(clippy::single_match_else)] - match short { - Some(short) => ret.push( - utils::u64_from_bytes(short) - .map_err(|_| Error::bad_database("Invalid shorteventid in db.")) - .unwrap(), - ), - None => { - let short = self.services.globals.next_count().unwrap(); - self.eventid_shorteventid - .insert(keys[i], &short.to_be_bytes()); - self.shorteventid_eventid - .insert(&short.to_be_bytes(), keys[i]); - - debug_assert!(ret.len() == i, "position of result must match input"); - ret.push(short); - }, - } - } - - ret - } - - pub(super) async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { - let key = (event_type, state_key); - self.statekey_shortstatekey.qry(&key).await.deserialized() - } - - pub(super) async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> u64 { - let key = (event_type.to_string(), state_key); - if let Ok(shortstatekey) = self.statekey_shortstatekey.qry(&key).await.deserialized() { - return shortstatekey; - } - - let mut key = 
event_type.to_string().as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(state_key.as_bytes()); - - let shortstatekey = self.services.globals.next_count().unwrap(); - self.statekey_shortstatekey - .insert(&key, &shortstatekey.to_be_bytes()); - self.shortstatekey_statekey - .insert(&shortstatekey.to_be_bytes(), &key); - - shortstatekey - } - - pub(super) async fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { - self.shorteventid_eventid - .qry(&shorteventid) - .await - .deserialized() - .map_err(|e| err!(Database("Failed to find EventId from short {shorteventid:?}: {e:?}"))) - } - - pub(super) async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - self.shortstatekey_statekey - .qry(&shortstatekey) - .await - .deserialized() - .map_err(|e| { - err!(Database( - "Failed to find (StateEventType, state_key) from short {shortstatekey:?}: {e:?}" - )) - }) - } - - /// Returns (shortstatehash, already_existed) - pub(super) async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, bool) { - if let Ok(shortstatehash) = self - .statehash_shortstatehash - .qry(state_hash) - .await - .deserialized() - { - return (shortstatehash, true); - } - - let shortstatehash = self.services.globals.next_count().unwrap(); - self.statehash_shortstatehash - .insert(state_hash, &shortstatehash.to_be_bytes()); - - (shortstatehash, false) - } - - pub(super) async fn get_shortroomid(&self, room_id: &RoomId) -> Result { - self.roomid_shortroomid.qry(room_id).await.deserialized() - } - - pub(super) async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> u64 { - self.roomid_shortroomid - .qry(room_id) - .await - .deserialized() - .unwrap_or_else(|_| { - let short = self.services.globals.next_count().unwrap(); - self.roomid_shortroomid - .insert(room_id.as_bytes(), &short.to_be_bytes()); - short - }) - } -} diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 00bb7cb13..66da39485 
100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,61 +1,215 @@ -mod data; - use std::sync::Arc; -use conduit::Result; +use conduit::{err, implement, utils, Error, Result}; +use database::{Deserialized, Map}; use ruma::{events::StateEventType, EventId, RoomId}; -use self::data::Data; +use crate::{globals, Dep}; pub struct Service { db: Data, + services: Services, +} + +struct Data { + eventid_shorteventid: Arc, + shorteventid_eventid: Arc, + statekey_shortstatekey: Arc, + shortstatekey_statekey: Arc, + roomid_shortroomid: Arc, + statehash_shortstatehash: Arc, +} + +struct Services { + globals: Dep, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + eventid_shorteventid: args.db["eventid_shorteventid"].clone(), + shorteventid_eventid: args.db["shorteventid_eventid"].clone(), + statekey_shortstatekey: args.db["statekey_shortstatekey"].clone(), + shortstatekey_statekey: args.db["shortstatekey_statekey"].clone(), + roomid_shortroomid: args.db["roomid_shortroomid"].clone(), + statehash_shortstatehash: args.db["statehash_shortstatehash"].clone(), + }, + services: Services { + globals: args.depend::("globals"), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 { - self.db.get_or_create_shorteventid(event_id).await +#[implement(Service)] +pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 { + if let Ok(shorteventid) = self + .db + .eventid_shorteventid + .get(event_id) + .await + .deserialized() + { + return shorteventid; } - pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec { - self.db.multi_get_or_create_shorteventid(event_ids).await - } + let shorteventid = self.services.globals.next_count().unwrap(); + self.db + .eventid_shorteventid + .insert(event_id.as_bytes(), 
&shorteventid.to_be_bytes()); + self.db + .shorteventid_eventid + .insert(&shorteventid.to_be_bytes(), event_id.as_bytes()); - pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { - self.db.get_shortstatekey(event_type, state_key).await - } + shorteventid +} - pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> u64 { - self.db - .get_or_create_shortstatekey(event_type, state_key) - .await - } +#[implement(Service)] +pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec { + let mut ret: Vec = Vec::with_capacity(event_ids.len()); + let keys = event_ids + .iter() + .map(|id| id.as_bytes()) + .collect::>(); - pub async fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { - self.db.get_eventid_from_short(shorteventid).await - } + for (i, short) in self + .db + .eventid_shorteventid + .get_batch_blocking(keys.iter()) + .iter() + .enumerate() + { + match short { + Some(short) => ret.push( + utils::u64_from_bytes(short) + .map_err(|_| Error::bad_database("Invalid shorteventid in db.")) + .unwrap(), + ), + None => { + let short = self.services.globals.next_count().unwrap(); + self.db + .eventid_shorteventid + .insert(keys[i], &short.to_be_bytes()); + self.db + .shorteventid_eventid + .insert(&short.to_be_bytes(), keys[i]); - pub async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - self.db.get_statekey_from_short(shortstatekey).await + debug_assert!(ret.len() == i, "position of result must match input"); + ret.push(short); + }, + } } - /// Returns (shortstatehash, already_existed) - pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, bool) { - self.db.get_or_create_shortstatehash(state_hash).await + ret +} + +#[implement(Service)] +pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { + let key = (event_type, state_key); + self.db + 
.statekey_shortstatekey + .qry(&key) + .await + .deserialized() +} + +#[implement(Service)] +pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> u64 { + let key = (event_type.to_string(), state_key); + if let Ok(shortstatekey) = self + .db + .statekey_shortstatekey + .qry(&key) + .await + .deserialized() + { + return shortstatekey; } - pub async fn get_shortroomid(&self, room_id: &RoomId) -> Result { self.db.get_shortroomid(room_id).await } + let mut key = event_type.to_string().as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(state_key.as_bytes()); + + let shortstatekey = self.services.globals.next_count().unwrap(); + self.db + .statekey_shortstatekey + .insert(&key, &shortstatekey.to_be_bytes()); + self.db + .shortstatekey_statekey + .insert(&shortstatekey.to_be_bytes(), &key); - pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> u64 { - self.db.get_or_create_shortroomid(room_id).await + shortstatekey +} + +#[implement(Service)] +pub async fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { + const BUFSIZE: usize = size_of::(); + + self.db + .shorteventid_eventid + .aqry::(&shorteventid) + .await + .deserialized() + .map_err(|e| err!(Database("Failed to find EventId from short {shorteventid:?}: {e:?}"))) +} + +#[implement(Service)] +pub async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { + const BUFSIZE: usize = size_of::(); + + self.db + .shortstatekey_statekey + .aqry::(&shortstatekey) + .await + .deserialized() + .map_err(|e| { + err!(Database( + "Failed to find (StateEventType, state_key) from short {shortstatekey:?}: {e:?}" + )) + }) +} + +/// Returns (shortstatehash, already_existed) +#[implement(Service)] +pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, bool) { + if let Ok(shortstatehash) = self + .db + .statehash_shortstatehash + .get(state_hash) + .await + .deserialized() + { + return 
(shortstatehash, true); } + + let shortstatehash = self.services.globals.next_count().unwrap(); + self.db + .statehash_shortstatehash + .insert(state_hash, &shortstatehash.to_be_bytes()); + + (shortstatehash, false) +} + +#[implement(Service)] +pub async fn get_shortroomid(&self, room_id: &RoomId) -> Result { + self.db.roomid_shortroomid.qry(room_id).await.deserialized() +} + +#[implement(Service)] +pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> u64 { + self.db + .roomid_shortroomid + .get(room_id) + .await + .deserialized() + .unwrap_or_else(|_| { + let short = self.services.globals.next_count().unwrap(); + self.db + .roomid_shortroomid + .insert(room_id.as_bytes(), &short.to_be_bytes()); + short + }) } diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index ccf7509a8..3072e3c65 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -25,7 +25,7 @@ impl Data { } pub(super) async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result { - self.roomid_shortstatehash.qry(room_id).await.deserialized() + self.roomid_shortstatehash.get(room_id).await.deserialized() } #[inline] diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index 79a983257..adc26f000 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -157,8 +157,8 @@ impl Data { /// Returns the state hash for this pdu. 
pub(super) async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result { self.eventid_shorteventid - .qry(event_id) - .and_then(|shorteventid| self.shorteventid_shortstatehash.qry(&shorteventid)) + .get(event_id) + .and_then(|shorteventid| self.shorteventid_shortstatehash.get(&shorteventid)) .await .deserialized() } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index ce5b024b7..eedff8612 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -435,10 +435,10 @@ impl Service { /// Returns an iterator over all rooms this user joined. #[tracing::instrument(skip(self), level = "debug")] - pub fn rooms_joined(&self, user_id: &UserId) -> impl Stream + Send { + pub fn rooms_joined<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { self.db .userroomid_joined - .keys_prefix(user_id) + .keys_prefix_raw(user_id) .ignore_err() .map(|(_, room_id): (Ignore, &RoomId)| room_id) } @@ -494,10 +494,10 @@ impl Service { } #[tracing::instrument(skip(self), level = "debug")] - pub fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> impl Stream + Send + 'a { + pub fn servers_invite_via<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { self.db .roomid_inviteviaservers - .stream_prefix(room_id) + .stream_prefix_raw(room_id) .ignore_err() .map(|(_, servers): (Ignore, Vec<&ServerName>)| &**(servers.last().expect("at least one servername"))) } diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index 9a9f70a28..cb0204705 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -23,9 +23,11 @@ impl Data { } pub(super) async fn get_statediff(&self, shortstatehash: u64) -> Result { + const BUFSIZE: usize = size_of::(); + let value = self .shortstatehash_statediff - .qry(&shortstatehash) + .aqry::(&shortstatehash) .await .map_err(|e| err!(Database("Failed to find StateDiff from 
short {shortstatehash:?}: {e}")))?; diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 1f9dad1dc..cb85cf19c 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -79,7 +79,7 @@ impl Data { /// Returns the `count` of this pdu's id. pub(super) async fn get_pdu_count(&self, event_id: &EventId) -> Result { self.eventid_pduid - .qry(event_id) + .get(event_id) .await .map(|pdu_id| pdu_count(&pdu_id)) } @@ -90,27 +90,27 @@ impl Data { return Ok(pdu); } - self.eventid_outlierpdu.qry(event_id).await.deserialized() + self.eventid_outlierpdu.get(event_id).await.deserialized() } /// Returns the json of a pdu. pub(super) async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result { let pduid = self.get_pdu_id(event_id).await?; - self.pduid_pdu.qry(&pduid).await.deserialized() + self.pduid_pdu.get(&pduid).await.deserialized() } /// Returns the pdu's id. #[inline] pub(super) async fn get_pdu_id(&self, event_id: &EventId) -> Result> { - self.eventid_pduid.qry(event_id).await + self.eventid_pduid.get(event_id).await } /// Returns the pdu directly from `eventid_pduid` only. 
pub(super) async fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result { let pduid = self.get_pdu_id(event_id).await?; - self.pduid_pdu.qry(&pduid).await.deserialized() + self.pduid_pdu.get(&pduid).await.deserialized() } /// Like get_non_outlier_pdu(), but without the expense of fetching and @@ -118,7 +118,7 @@ impl Data { pub(super) async fn non_outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> { let pduid = self.get_pdu_id(event_id).await?; - self.pduid_pdu.qry(&pduid).await?; + self.pduid_pdu.get(&pduid).await?; Ok(()) } @@ -132,7 +132,7 @@ impl Data { } self.eventid_outlierpdu - .qry(event_id) + .get(event_id) .await .deserialized() .map(Arc::new) @@ -141,7 +141,7 @@ impl Data { /// Like get_non_outlier_pdu(), but without the expense of fetching and /// parsing the PduEvent pub(super) async fn outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> { - self.eventid_outlierpdu.qry(event_id).await?; + self.eventid_outlierpdu.get(event_id).await?; Ok(()) } @@ -159,12 +159,12 @@ impl Data { /// /// This does __NOT__ check the outliers `Tree`. pub(super) async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result { - self.pduid_pdu.qry(pdu_id).await.deserialized() + self.pduid_pdu.get(pdu_id).await.deserialized() } /// Returns the pdu as a `BTreeMap`. 
pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result { - self.pduid_pdu.qry(pdu_id).await.deserialized() + self.pduid_pdu.get(pdu_id).await.deserialized() } pub(super) async fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) { @@ -196,7 +196,7 @@ impl Data { pub(super) async fn replace_pdu( &self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent, ) -> Result<()> { - if self.pduid_pdu.qry(pdu_id).await.is_not_found() { + if self.pduid_pdu.get(pdu_id).await.is_not_found() { return Err!(Request(NotFound("PDU does not exist."))); } diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index b96f9a03c..6f4b5b970 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -98,7 +98,7 @@ impl Data { } #[inline] - pub fn active_requests_for<'a>(&'a self, destination: &Destination) -> impl Stream + Send + 'a { + pub fn active_requests_for(&self, destination: &Destination) -> impl Stream + Send + '_ { let prefix = destination.get_prefix(); self.servercurrentevent_data .stream_raw_prefix(&prefix) @@ -133,7 +133,7 @@ impl Data { keys } - pub fn queued_requests<'a>(&'a self, destination: &Destination) -> impl Stream + Send + 'a { + pub fn queued_requests(&self, destination: &Destination) -> impl Stream + Send + '_ { let prefix = destination.get_prefix(); self.servernameevent_data .stream_raw_prefix(&prefix) @@ -152,7 +152,7 @@ impl Data { pub async fn get_latest_educount(&self, server_name: &ServerName) -> u64 { self.servername_educount - .qry(server_name) + .get(server_name) .await .deserialized() .unwrap_or(0) diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index fa8c41b6b..eb77ef357 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -120,13 +120,13 @@ impl Service { /// Check if a user has an account on this homeserver. 
#[inline] - pub async fn exists(&self, user_id: &UserId) -> bool { self.db.userid_password.qry(user_id).await.is_ok() } + pub async fn exists(&self, user_id: &UserId) -> bool { self.db.userid_password.get(user_id).await.is_ok() } /// Check if account is deactivated pub async fn is_deactivated(&self, user_id: &UserId) -> Result { self.db .userid_password - .qry(user_id) + .get(user_id) .map_ok(|val| val.is_empty()) .map_err(|_| err!(Request(NotFound("User does not exist.")))) .await @@ -146,7 +146,7 @@ impl Service { /// Find out which user an access token belongs to. pub async fn find_from_token(&self, token: &str) -> Result<(OwnedUserId, OwnedDeviceId)> { - self.db.token_userdeviceid.qry(token).await.deserialized() + self.db.token_userdeviceid.get(token).await.deserialized() } /// Returns an iterator over all users on this homeserver (offered for @@ -171,7 +171,7 @@ impl Service { /// Returns the password hash for the given user. pub async fn password_hash(&self, user_id: &UserId) -> Result { - self.db.userid_password.qry(user_id).await.deserialized() + self.db.userid_password.get(user_id).await.deserialized() } /// Hash and set the user's password to the Argon2 hash @@ -196,7 +196,7 @@ impl Service { /// Returns the displayname of a user on this homeserver. pub async fn displayname(&self, user_id: &UserId) -> Result { - self.db.userid_displayname.qry(user_id).await.deserialized() + self.db.userid_displayname.get(user_id).await.deserialized() } /// Sets a new displayname or removes it if displayname is None. You still @@ -213,7 +213,7 @@ impl Service { /// Get the `avatar_url` of a user. pub async fn avatar_url(&self, user_id: &UserId) -> Result { - self.db.userid_avatarurl.qry(user_id).await.deserialized() + self.db.userid_avatarurl.get(user_id).await.deserialized() } /// Sets a new avatar_url or removes it if avatar_url is None. @@ -229,7 +229,7 @@ impl Service { /// Get the blurhash of a user. 
pub async fn blurhash(&self, user_id: &UserId) -> Result { - self.db.userid_blurhash.qry(user_id).await.deserialized() + self.db.userid_blurhash.get(user_id).await.deserialized() } /// Sets a new avatar_url or removes it if avatar_url is None. @@ -284,7 +284,7 @@ impl Service { userdeviceid.extend_from_slice(device_id.as_bytes()); // Remove tokens - if let Ok(old_token) = self.db.userdeviceid_token.qry(&userdeviceid).await { + if let Ok(old_token) = self.db.userdeviceid_token.get(&userdeviceid).await { self.db.userdeviceid_token.remove(&userdeviceid); self.db.token_userdeviceid.remove(&old_token); } @@ -390,7 +390,7 @@ impl Service { pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 { self.db .userid_lastonetimekeyupdate - .qry(user_id) + .get(user_id) .await .deserialized() .unwrap_or(0) @@ -664,7 +664,7 @@ impl Service { let key = self .db .keyid_key - .qry(key_id) + .get(key_id) .await .deserialized::()?; @@ -679,7 +679,7 @@ impl Service { where F: Fn(&UserId) -> bool + Send + Sync, { - let key_id = self.db.userid_masterkeyid.qry(user_id).await?; + let key_id = self.db.userid_masterkeyid.get(user_id).await?; self.get_key(&key_id, sender_user, user_id, allowed_signatures) .await @@ -691,16 +691,16 @@ impl Service { where F: Fn(&UserId) -> bool + Send + Sync, { - let key_id = self.db.userid_selfsigningkeyid.qry(user_id).await?; + let key_id = self.db.userid_selfsigningkeyid.get(user_id).await?; self.get_key(&key_id, sender_user, user_id, allowed_signatures) .await } pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result> { - let key_id = self.db.userid_usersigningkeyid.qry(user_id).await?; + let key_id = self.db.userid_usersigningkeyid.get(user_id).await?; - self.db.keyid_key.qry(&*key_id).await.deserialized() + self.db.keyid_key.get(&*key_id).await.deserialized() } pub async fn add_to_device_event( @@ -797,7 +797,7 @@ impl Service { pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result { self.db 
.userid_devicelistversion - .qry(user_id) + .get(user_id) .await .deserialized() } @@ -853,7 +853,7 @@ impl Service { /// Find out which user an OpenID access token belongs to. pub async fn find_from_openid_token(&self, token: &str) -> Result { - let Ok(value) = self.db.openidtoken_expiresatuserid.qry(token).await else { + let Ok(value) = self.db.openidtoken_expiresatuserid.get(token).await else { return Err!(Request(Unauthorized("OpenID token is unrecognised"))); }; From a8d5cf96517d706cff5fe73650405edaff4e0779 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 23 Sep 2024 21:38:56 +0000 Subject: [PATCH 019/245] Add rocksdb logging integration with tracing. Signed-off-by: Jason Volk --- src/database/engine.rs | 17 ++++++++++++++++- src/database/opts.rs | 2 ++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/database/engine.rs b/src/database/engine.rs index 067232e67..edf077fc9 100644 --- a/src/database/engine.rs +++ b/src/database/engine.rs @@ -10,7 +10,7 @@ use conduit::{debug, error, info, utils::time::rfc2822_from_seconds, warn, Err, use rocksdb::{ backup::{BackupEngine, BackupEngineOptions}, perf::get_memory_usage_stats, - AsColumnFamilyRef, BoundColumnFamily, Cache, ColumnFamilyDescriptor, DBCommon, DBWithThreadMode, Env, + AsColumnFamilyRef, BoundColumnFamily, Cache, ColumnFamilyDescriptor, DBCommon, DBWithThreadMode, Env, LogLevel, MultiThreaded, Options, }; @@ -279,6 +279,21 @@ pub(crate) fn repair(db_opts: &Options, path: &PathBuf) -> Result<()> { Ok(()) } +#[tracing::instrument(skip_all, name = "rocksdb")] +pub(crate) fn handle_log(level: LogLevel, msg: &str) { + let msg = msg.trim(); + if msg.starts_with("Options") { + return; + } + + match level { + LogLevel::Header | LogLevel::Debug => debug!("{msg}"), + LogLevel::Error | LogLevel::Fatal => error!("{msg}"), + LogLevel::Info => debug!("{msg}"), + LogLevel::Warn => warn!("{msg}"), + }; +} + impl Drop for Engine { #[cold] fn drop(&mut self) { diff --git a/src/database/opts.rs 
b/src/database/opts.rs index d2ad4b95c..46fb4c542 100644 --- a/src/database/opts.rs +++ b/src/database/opts.rs @@ -191,6 +191,8 @@ fn set_logging_defaults(opts: &mut Options, config: &Config) { if config.rocksdb_log_stderr { opts.set_stderr_logger(rocksdb_log_level, "rocksdb"); + } else { + opts.set_callback_logger(rocksdb_log_level, &super::engine::handle_log); } } From 6b80361c31fc8b2eeeafbcfbf14a463c3423ee7c Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 30 Sep 2024 06:46:54 +0000 Subject: [PATCH 020/245] additional stream tools Signed-off-by: Jason Volk --- src/core/result.rs | 5 +- src/core/result/into_is_ok.rs | 10 +++ src/core/utils/bool.rs | 16 +++++ src/core/utils/future/mod.rs | 3 + src/core/utils/future/try_ext_ext.rs | 48 +++++++++++++ src/core/utils/mod.rs | 6 +- src/core/utils/stream/mod.rs | 2 + src/core/utils/stream/ready.rs | 102 ++++++++++++++++++--------- src/core/utils/stream/tools.rs | 80 +++++++++++++++++++++ src/service/rooms/state_cache/mod.rs | 16 ++--- 10 files changed, 242 insertions(+), 46 deletions(-) create mode 100644 src/core/result/into_is_ok.rs create mode 100644 src/core/utils/bool.rs create mode 100644 src/core/utils/future/mod.rs create mode 100644 src/core/utils/future/try_ext_ext.rs create mode 100644 src/core/utils/stream/tools.rs diff --git a/src/core/result.rs b/src/core/result.rs index 96a34b8a3..82d67a9c5 100644 --- a/src/core/result.rs +++ b/src/core/result.rs @@ -1,4 +1,5 @@ mod debug_inspect; +mod into_is_ok; mod log_debug_err; mod log_err; mod map_expect; @@ -6,8 +7,8 @@ mod not_found; mod unwrap_infallible; pub use self::{ - debug_inspect::DebugInspect, log_debug_err::LogDebugErr, log_err::LogErr, map_expect::MapExpect, - not_found::NotFound, unwrap_infallible::UnwrapInfallible, + debug_inspect::DebugInspect, into_is_ok::IntoIsOk, log_debug_err::LogDebugErr, log_err::LogErr, + map_expect::MapExpect, not_found::NotFound, unwrap_infallible::UnwrapInfallible, }; pub type Result = std::result::Result; diff --git 
a/src/core/result/into_is_ok.rs b/src/core/result/into_is_ok.rs new file mode 100644 index 000000000..220ce010c --- /dev/null +++ b/src/core/result/into_is_ok.rs @@ -0,0 +1,10 @@ +use super::Result; + +pub trait IntoIsOk { + fn into_is_ok(self) -> bool; +} + +impl IntoIsOk for Result { + #[inline] + fn into_is_ok(self) -> bool { self.is_ok() } +} diff --git a/src/core/utils/bool.rs b/src/core/utils/bool.rs new file mode 100644 index 000000000..d7ce78fe3 --- /dev/null +++ b/src/core/utils/bool.rs @@ -0,0 +1,16 @@ +//! Trait BoolExt + +/// Boolean extensions and chain.starters +pub trait BoolExt { + fn or T>(self, f: F) -> Option; + + fn or_some(self, t: T) -> Option; +} + +impl BoolExt for bool { + #[inline] + fn or T>(self, f: F) -> Option { (!self).then(f) } + + #[inline] + fn or_some(self, t: T) -> Option { (!self).then_some(t) } +} diff --git a/src/core/utils/future/mod.rs b/src/core/utils/future/mod.rs new file mode 100644 index 000000000..6d45b6563 --- /dev/null +++ b/src/core/utils/future/mod.rs @@ -0,0 +1,3 @@ +mod try_ext_ext; + +pub use try_ext_ext::TryExtExt; diff --git a/src/core/utils/future/try_ext_ext.rs b/src/core/utils/future/try_ext_ext.rs new file mode 100644 index 000000000..e444ad94a --- /dev/null +++ b/src/core/utils/future/try_ext_ext.rs @@ -0,0 +1,48 @@ +//! Extended external extensions to futures::TryFutureExt + +use futures::{future::MapOkOrElse, TryFuture, TryFutureExt}; + +/// This interface is not necessarily complete; feel free to add as-needed. 
+pub trait TryExtExt +where + Self: TryFuture + Send, +{ + fn map_ok_or( + self, default: U, f: F, + ) -> MapOkOrElse U, impl FnOnce(Self::Error) -> U> + where + F: FnOnce(Self::Ok) -> U, + Self: Send + Sized; + + fn ok( + self, + ) -> MapOkOrElse Option, impl FnOnce(Self::Error) -> Option> + where + Self: Sized; +} + +impl TryExtExt for Fut +where + Fut: TryFuture + Send, +{ + #[inline] + fn map_ok_or( + self, default: U, f: F, + ) -> MapOkOrElse U, impl FnOnce(Self::Error) -> U> + where + F: FnOnce(Self::Ok) -> U, + Self: Send + Sized, + { + self.map_ok_or_else(|_| default, f) + } + + #[inline] + fn ok( + self, + ) -> MapOkOrElse Option, impl FnOnce(Self::Error) -> Option> + where + Self: Sized, + { + self.map_ok_or(None, Some) + } +} diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index fef833954..c34691d2d 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -1,7 +1,9 @@ +pub mod bool; pub mod bytes; pub mod content_disposition; pub mod debug; pub mod defer; +pub mod future; pub mod hash; pub mod html; pub mod json; @@ -19,15 +21,17 @@ pub use ::conduit_macros::implement; pub use ::ctor::{ctor, dtor}; pub use self::{ + bool::BoolExt, bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8}, debug::slice_truncated as debug_slice_truncated, + future::TryExtExt as TryFutureExtExt, hash::calculate_hash, html::Escape as HtmlEscape, json::{deserialize_from_str, to_canonical_object}, math::clamp, mutex_map::{Guard as MutexMapGuard, MutexMap}, rand::string as random_string, - stream::{IterStream, ReadyExt, TryReadyExt}, + stream::{IterStream, ReadyExt, Tools as StreamTools, TryReadyExt}, string::{str_from_bytes, string_from_bytes}, sys::available_parallelism, time::now_millis as millis_since_unix_epoch, diff --git a/src/core/utils/stream/mod.rs b/src/core/utils/stream/mod.rs index 781bd5223..1111915b3 100644 --- a/src/core/utils/stream/mod.rs +++ b/src/core/utils/stream/mod.rs @@ -3,6 +3,7 @@ mod expect; mod ignore; mod iter_stream; mod 
ready; +mod tools; mod try_ready; pub use cloned::Cloned; @@ -10,4 +11,5 @@ pub use expect::TryExpect; pub use ignore::TryIgnore; pub use iter_stream::IterStream; pub use ready::ReadyExt; +pub use tools::Tools; pub use try_ready::TryReadyExt; diff --git a/src/core/utils/stream/ready.rs b/src/core/utils/stream/ready.rs index 13f730a7d..da5aec5a6 100644 --- a/src/core/utils/stream/ready.rs +++ b/src/core/utils/stream/ready.rs @@ -2,7 +2,7 @@ use futures::{ future::{ready, Ready}, - stream::{Any, Filter, FilterMap, Fold, ForEach, SkipWhile, Stream, StreamExt, TakeWhile}, + stream::{Any, Filter, FilterMap, Fold, ForEach, Scan, SkipWhile, Stream, StreamExt, TakeWhile}, }; /// Synchronous combinators to augment futures::StreamExt. Most Stream @@ -11,98 +11,130 @@ use futures::{ /// convenience to reduce boilerplate by de-cluttering non-async predicates. /// /// This interface is not necessarily complete; feel free to add as-needed. -pub trait ReadyExt +pub trait ReadyExt where - S: Stream + Send + ?Sized, - Self: Stream + Send + Sized, + Self: Stream + Send + Sized, { - fn ready_any(self, f: F) -> Any, impl FnMut(S::Item) -> Ready> + fn ready_any(self, f: F) -> Any, impl FnMut(Item) -> Ready> where - F: Fn(S::Item) -> bool; + F: Fn(Item) -> bool; - fn ready_filter<'a, F>(self, f: F) -> Filter, impl FnMut(&S::Item) -> Ready + 'a> + fn ready_filter<'a, F>(self, f: F) -> Filter, impl FnMut(&Item) -> Ready + 'a> where - F: Fn(&S::Item) -> bool + 'a; + F: Fn(&Item) -> bool + 'a; - fn ready_filter_map(self, f: F) -> FilterMap>, impl FnMut(S::Item) -> Ready>> + fn ready_filter_map(self, f: F) -> FilterMap>, impl FnMut(Item) -> Ready>> where - F: Fn(S::Item) -> Option; + F: Fn(Item) -> Option; - fn ready_fold(self, init: T, f: F) -> Fold, T, impl FnMut(T, S::Item) -> Ready> + fn ready_fold(self, init: T, f: F) -> Fold, T, impl FnMut(T, Item) -> Ready> where - F: Fn(T, S::Item) -> T; + F: Fn(T, Item) -> T; - fn ready_for_each(self, f: F) -> ForEach, impl FnMut(S::Item) -> 
Ready<()>> + fn ready_for_each(self, f: F) -> ForEach, impl FnMut(Item) -> Ready<()>> where - F: FnMut(S::Item); + F: FnMut(Item); - fn ready_take_while<'a, F>(self, f: F) -> TakeWhile, impl FnMut(&S::Item) -> Ready + 'a> + fn ready_take_while<'a, F>(self, f: F) -> TakeWhile, impl FnMut(&Item) -> Ready + 'a> where - F: Fn(&S::Item) -> bool + 'a; + F: Fn(&Item) -> bool + 'a; - fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile, impl FnMut(&S::Item) -> Ready + 'a> + fn ready_scan( + self, init: T, f: F, + ) -> Scan>, impl FnMut(&mut T, Item) -> Ready>> where - F: Fn(&S::Item) -> bool + 'a; + F: Fn(&mut T, Item) -> Option; + + fn ready_scan_each( + self, init: T, f: F, + ) -> Scan>, impl FnMut(&mut T, Item) -> Ready>> + where + F: Fn(&mut T, &Item); + + fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile, impl FnMut(&Item) -> Ready + 'a> + where + F: Fn(&Item) -> bool + 'a; } -impl ReadyExt for S +impl ReadyExt for S where - S: Stream + Send + ?Sized, - Self: Stream + Send + Sized, + S: Stream + Send + Sized, { #[inline] - fn ready_any(self, f: F) -> Any, impl FnMut(S::Item) -> Ready> + fn ready_any(self, f: F) -> Any, impl FnMut(Item) -> Ready> where - F: Fn(S::Item) -> bool, + F: Fn(Item) -> bool, { self.any(move |t| ready(f(t))) } #[inline] - fn ready_filter<'a, F>(self, f: F) -> Filter, impl FnMut(&S::Item) -> Ready + 'a> + fn ready_filter<'a, F>(self, f: F) -> Filter, impl FnMut(&Item) -> Ready + 'a> where - F: Fn(&S::Item) -> bool + 'a, + F: Fn(&Item) -> bool + 'a, { self.filter(move |t| ready(f(t))) } #[inline] - fn ready_filter_map(self, f: F) -> FilterMap>, impl FnMut(S::Item) -> Ready>> + fn ready_filter_map(self, f: F) -> FilterMap>, impl FnMut(Item) -> Ready>> where - F: Fn(S::Item) -> Option, + F: Fn(Item) -> Option, { self.filter_map(move |t| ready(f(t))) } #[inline] - fn ready_fold(self, init: T, f: F) -> Fold, T, impl FnMut(T, S::Item) -> Ready> + fn ready_fold(self, init: T, f: F) -> Fold, T, impl FnMut(T, Item) -> Ready> where - F: Fn(T, S::Item) -> 
T, + F: Fn(T, Item) -> T, { self.fold(init, move |a, t| ready(f(a, t))) } #[inline] #[allow(clippy::unit_arg)] - fn ready_for_each(self, mut f: F) -> ForEach, impl FnMut(S::Item) -> Ready<()>> + fn ready_for_each(self, mut f: F) -> ForEach, impl FnMut(Item) -> Ready<()>> where - F: FnMut(S::Item), + F: FnMut(Item), { self.for_each(move |t| ready(f(t))) } #[inline] - fn ready_take_while<'a, F>(self, f: F) -> TakeWhile, impl FnMut(&S::Item) -> Ready + 'a> + fn ready_take_while<'a, F>(self, f: F) -> TakeWhile, impl FnMut(&Item) -> Ready + 'a> where - F: Fn(&S::Item) -> bool + 'a, + F: Fn(&Item) -> bool + 'a, { self.take_while(move |t| ready(f(t))) } #[inline] - fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile, impl FnMut(&S::Item) -> Ready + 'a> + fn ready_scan( + self, init: T, f: F, + ) -> Scan>, impl FnMut(&mut T, Item) -> Ready>> + where + F: Fn(&mut T, Item) -> Option, + { + self.scan(init, move |s, t| ready(f(s, t))) + } + + fn ready_scan_each( + self, init: T, f: F, + ) -> Scan>, impl FnMut(&mut T, Item) -> Ready>> + where + F: Fn(&mut T, &Item), + { + self.ready_scan(init, move |s, t| { + f(s, &t); + Some(t) + }) + } + + #[inline] + fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile, impl FnMut(&Item) -> Ready + 'a> where - F: Fn(&S::Item) -> bool + 'a, + F: Fn(&Item) -> bool + 'a, { self.skip_while(move |t| ready(f(t))) } diff --git a/src/core/utils/stream/tools.rs b/src/core/utils/stream/tools.rs new file mode 100644 index 000000000..cc6b7ca9e --- /dev/null +++ b/src/core/utils/stream/tools.rs @@ -0,0 +1,80 @@ +//! StreamTools for futures::Stream + +use std::{collections::HashMap, hash::Hash}; + +use futures::{Future, Stream, StreamExt}; + +use super::ReadyExt; +use crate::expected; + +/// StreamTools +/// +/// This interface is not necessarily complete; feel free to add as-needed. 
+pub trait Tools +where + Self: Stream + Send + Sized, + ::Item: Send, +{ + fn counts(self) -> impl Future> + Send + where + ::Item: Eq + Hash; + + fn counts_by(self, f: F) -> impl Future> + Send + where + F: Fn(Item) -> K + Send, + K: Eq + Hash + Send; + + fn counts_by_with_cap(self, f: F) -> impl Future> + Send + where + F: Fn(Item) -> K + Send, + K: Eq + Hash + Send; + + fn counts_with_cap(self) -> impl Future> + Send + where + ::Item: Eq + Hash; +} + +impl Tools for S +where + S: Stream + Send + Sized, + ::Item: Send, +{ + #[inline] + fn counts(self) -> impl Future> + Send + where + ::Item: Eq + Hash, + { + self.counts_with_cap::<0>() + } + + #[inline] + fn counts_by(self, f: F) -> impl Future> + Send + where + F: Fn(Item) -> K + Send, + K: Eq + Hash + Send, + { + self.counts_by_with_cap::<0, K, F>(f) + } + + #[inline] + fn counts_by_with_cap(self, f: F) -> impl Future> + Send + where + F: Fn(Item) -> K + Send, + K: Eq + Hash + Send, + { + self.map(f).counts_with_cap::() + } + + #[inline] + fn counts_with_cap(self) -> impl Future> + Send + where + ::Item: Eq + Hash, + { + self.ready_fold(HashMap::with_capacity(CAP), |mut counts, item| { + let entry = counts.entry(item).or_default(); + let value = *entry; + *entry = expected!(value + 1); + counts + }) + } +} diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index eedff8612..253880849 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -4,7 +4,7 @@ use std::{collections::HashSet, sync::Arc}; use conduit::{ err, - utils::{stream::TryIgnore, ReadyExt}, + utils::{stream::TryIgnore, ReadyExt, StreamTools}, warn, Result, }; use data::Data; @@ -495,11 +495,13 @@ impl Service { #[tracing::instrument(skip(self), level = "debug")] pub fn servers_invite_via<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + type KeyVal<'a> = (Ignore, Vec<&'a ServerName>); + self.db .roomid_inviteviaservers .stream_prefix_raw(room_id) 
.ignore_err() - .map(|(_, servers): (Ignore, Vec<&ServerName>)| &**(servers.last().expect("at least one servername"))) + .map(|(_, servers): KeyVal<'_>| *servers.last().expect("at least one server")) } /// Gets up to three servers that are likely to be in the room in the @@ -525,16 +527,14 @@ impl Service { let mut servers: Vec = self .room_members(room_id) - .collect::>() - .await - .iter() .counts_by(|user| user.server_name().to_owned()) - .iter() + .await + .into_iter() .sorted_by_key(|(_, users)| *users) - .map(|(server, _)| server.to_owned()) + .map(|(server, _)| server) .rev() .take(3) - .collect_vec(); + .collect(); if let Some(server) = most_powerful_user_server { servers.insert(0, server); From 96fcf7f94d65e93fdcb23acd3f52945813dbc18e Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 1 Oct 2024 04:20:31 +0000 Subject: [PATCH 021/245] add rocksdb secondary; fix read_only mode. Signed-off-by: Jason Volk --- src/core/config/mod.rs | 3 +++ src/database/database.rs | 8 ++++++++ src/database/engine.rs | 11 +++++++++-- src/service/emergency/mod.rs | 4 ++++ src/service/globals/mod.rs | 3 +++ 5 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index d2d583a8c..d8e1c7d93 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -236,6 +236,8 @@ pub struct Config { #[serde(default)] pub rocksdb_read_only: bool, #[serde(default)] + pub rocksdb_secondary: bool, + #[serde(default)] pub rocksdb_compaction_prio_idle: bool, #[serde(default = "true_fn")] pub rocksdb_compaction_ioprio_idle: bool, @@ -752,6 +754,7 @@ impl fmt::Display for Config { line("RocksDB Recovery Mode", &self.rocksdb_recovery_mode.to_string()); line("RocksDB Repair Mode", &self.rocksdb_repair.to_string()); line("RocksDB Read-only Mode", &self.rocksdb_read_only.to_string()); + line("RocksDB Secondary Mode", &self.rocksdb_secondary.to_string()); line( "RocksDB Compaction Idle Priority", &self.rocksdb_compaction_prio_idle.to_string(), 
diff --git a/src/database/database.rs b/src/database/database.rs index ac6f62e90..4c29c840c 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -38,6 +38,14 @@ impl Database { #[inline] pub fn iter_maps(&self) -> impl Iterator + Send + '_ { self.map.iter() } + + #[inline] + #[must_use] + pub fn is_read_only(&self) -> bool { self.db.secondary || self.db.read_only } + + #[inline] + #[must_use] + pub fn is_secondary(&self) -> bool { self.db.secondary } } impl Index<&str> for Database { diff --git a/src/database/engine.rs b/src/database/engine.rs index edf077fc9..99d971ed6 100644 --- a/src/database/engine.rs +++ b/src/database/engine.rs @@ -28,6 +28,8 @@ pub struct Engine { cfs: Mutex>, pub(crate) db: Db, corks: AtomicU32, + pub(super) read_only: bool, + pub(super) secondary: bool, } pub(crate) type Db = DBWithThreadMode; @@ -80,10 +82,13 @@ impl Engine { .collect::>(); debug!("Opening database..."); + let path = &config.database_path; let res = if config.rocksdb_read_only { - Db::open_cf_for_read_only(&db_opts, &config.database_path, cfs.clone(), false) + Db::open_cf_descriptors_read_only(&db_opts, path, cfds, false) + } else if config.rocksdb_secondary { + Db::open_cf_descriptors_as_secondary(&db_opts, path, path, cfds) } else { - Db::open_cf_descriptors(&db_opts, &config.database_path, cfds) + Db::open_cf_descriptors(&db_opts, path, cfds) }; let db = res.or_else(or_else)?; @@ -103,6 +108,8 @@ impl Engine { cfs: Mutex::new(cfs), db, corks: AtomicU32::new(0), + read_only: config.rocksdb_read_only, + secondary: config.rocksdb_secondary, })) } diff --git a/src/service/emergency/mod.rs b/src/service/emergency/mod.rs index 98020bc29..c99a0891e 100644 --- a/src/service/emergency/mod.rs +++ b/src/service/emergency/mod.rs @@ -32,6 +32,10 @@ impl crate::Service for Service { } async fn worker(self: Arc) -> Result<()> { + if self.services.globals.is_read_only() { + return Ok(()); + } + self.set_emergency_access() .await .inspect_err(|e| error!("Could not 
set the configured emergency password for the conduit user: {e}"))?; diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index f777901f6..f24e8a274 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -329,4 +329,7 @@ impl Service { #[inline] pub fn server_is_ours(&self, server_name: &ServerName) -> bool { server_name == self.config.server_name } + + #[inline] + pub fn is_read_only(&self) -> bool { self.db.db.is_read_only() } } From 26dcab272d04eff968997a94f90636df389ffda6 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 1 Oct 2024 02:47:39 +0000 Subject: [PATCH 022/245] various cleanup tweaks/fixes Signed-off-by: Jason Volk --- Cargo.lock | 1 + src/admin/query/room_alias.rs | 10 ++++- src/admin/room/alias.rs | 4 +- src/admin/room/directory.rs | 8 ++-- src/admin/room/info.rs | 10 ++--- src/admin/room/moderation.rs | 4 +- src/api/client/keys.rs | 2 +- src/api/client/membership.rs | 15 ++++--- src/api/client/search.rs | 4 +- src/api/client/sync.rs | 67 +++++++++++++--------------- src/api/client/user_directory.rs | 11 ++--- src/core/Cargo.toml | 1 + src/core/error/mod.rs | 2 + src/service/appservice/data.rs | 4 +- src/service/globals/migrations.rs | 11 ++--- src/service/rooms/state_cache/mod.rs | 28 +++--------- src/service/rooms/timeline/mod.rs | 1 - src/service/users/mod.rs | 2 + 18 files changed, 86 insertions(+), 99 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 043d9704b..065aa1e4a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -709,6 +709,7 @@ dependencies = [ "serde", "serde_json", "serde_regex", + "serde_yaml", "thiserror", "tikv-jemalloc-ctl", "tikv-jemalloc-sys", diff --git a/src/admin/query/room_alias.rs b/src/admin/query/room_alias.rs index 05fac42cc..382e4a784 100644 --- a/src/admin/query/room_alias.rs +++ b/src/admin/query/room_alias.rs @@ -43,8 +43,13 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>) room_id, } => { let timer = tokio::time::Instant::now(); - let results 
= services.rooms.alias.local_aliases_for_room(&room_id); - let aliases: Vec<_> = results.collect().await; + let aliases: Vec<_> = services + .rooms + .alias + .local_aliases_for_room(&room_id) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -57,6 +62,7 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>) .rooms .alias .all_local_aliases() + .map(|(room_id, alias)| (room_id.to_owned(), alias.to_owned())) .collect::>() .await; let query_time = timer.elapsed(); diff --git a/src/admin/room/alias.rs b/src/admin/room/alias.rs index 34b6c42ec..1ccde47dc 100644 --- a/src/admin/room/alias.rs +++ b/src/admin/room/alias.rs @@ -119,12 +119,12 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> room_id, } => { if let Some(room_id) = room_id { - let aliases = services + let aliases: Vec = services .rooms .alias .local_aliases_for_room(&room_id) .map(Into::into) - .collect::>() + .collect() .await; let plain_list = aliases.iter().fold(String::new(), |mut output, alias| { diff --git a/src/admin/room/directory.rs b/src/admin/room/directory.rs index 7ccdea6f0..1080356a8 100644 --- a/src/admin/room/directory.rs +++ b/src/admin/room/directory.rs @@ -47,22 +47,22 @@ pub(super) async fn process(command: RoomDirectoryCommand, context: &Command<'_> } => { // TODO: i know there's a way to do this with clap, but i can't seem to find it let page = page.unwrap_or(1); - let mut rooms = services + let mut rooms: Vec<_> = services .rooms .directory .public_rooms() .then(|room_id| get_room_info(services, room_id)) - .collect::>() + .collect() .await; rooms.sort_by_key(|r| r.1); rooms.reverse(); - let rooms = rooms + let rooms: Vec<_> = rooms .into_iter() .skip(page.saturating_sub(1).saturating_mul(PAGE_SIZE)) .take(PAGE_SIZE) - .collect::>(); + .collect(); if rooms.is_empty() { return Ok(RoomMessageEventContent::text_plain("No more rooms.")); 
diff --git a/src/admin/room/info.rs b/src/admin/room/info.rs index fc0619e33..13a74a9d3 100644 --- a/src/admin/room/info.rs +++ b/src/admin/room/info.rs @@ -42,14 +42,12 @@ async fn list_joined_members(&self, room_id: Box, local_only: bool) -> R .state_cache .room_members(&room_id) .ready_filter(|user_id| { - if local_only { - self.services.globals.user_is_local(user_id) - } else { - true - } + local_only + .then(|| self.services.globals.user_is_local(user_id)) + .unwrap_or(true) }) + .map(ToOwned::to_owned) .filter_map(|user_id| async move { - let user_id = user_id.to_owned(); Some(( self.services .users diff --git a/src/admin/room/moderation.rs b/src/admin/room/moderation.rs index 9a772da48..cfc048bdd 100644 --- a/src/admin/room/moderation.rs +++ b/src/admin/room/moderation.rs @@ -555,13 +555,13 @@ async fn unban_room(&self, enable_federation: bool, room: Box) -> #[admin_command] async fn list_banned_rooms(&self, no_details: bool) -> Result { - let room_ids = self + let room_ids: Vec = self .services .rooms .metadata .list_banned_rooms() .map(Into::into) - .collect::>() + .collect() .await; if room_ids.is_empty() { diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index abf2a22f5..254d92ccd 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -244,7 +244,7 @@ pub(crate) async fn get_key_changes_route( device_list_updates.extend( services .users - .keys_changed(room_id.as_ref(), from, Some(to)) + .keys_changed(room_id.as_str(), from, Some(to)) .map(ToOwned::to_owned) .collect::>() .await, diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 5a5d436f1..6e3bc8940 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -167,12 +167,12 @@ pub(crate) async fn join_room_by_id_route( .await?; // There is no body.server_name for /roomId/join - let mut servers = services + let mut servers: Vec<_> = services .rooms .state_cache .servers_invite_via(&body.room_id) .map(ToOwned::to_owned) - 
.collect::>() + .collect() .await; servers.extend( @@ -641,12 +641,13 @@ pub(crate) async fn joined_members_route( .rooms .state_cache .room_members(&body.room_id) + .map(ToOwned::to_owned) .then(|user| async move { ( - user.to_owned(), + user.clone(), RoomMember { - display_name: services.users.displayname(user).await.ok(), - avatar_url: services.users.avatar_url(user).await.ok(), + display_name: services.users.displayname(&user).await.ok(), + avatar_url: services.users.avatar_url(&user).await.ok(), }, ) }) @@ -1575,7 +1576,7 @@ pub(crate) async fn invite_helper( // Make a user leave all their joined rooms, forgets all rooms, and ignores // errors pub async fn leave_all_rooms(services: &Services, user_id: &UserId) { - let all_rooms = services + let all_rooms: Vec<_> = services .rooms .state_cache .rooms_joined(user_id) @@ -1587,7 +1588,7 @@ pub async fn leave_all_rooms(services: &Services, user_id: &UserId) { .rooms_invited(user_id) .map(|(r, _)| r), ) - .collect::>() + .collect() .await; for room_id in all_rooms { diff --git a/src/api/client/search.rs b/src/api/client/search.rs index 7a061d494..b073640e8 100644 --- a/src/api/client/search.rs +++ b/src/api/client/search.rs @@ -77,14 +77,14 @@ pub(crate) async fn search_events_route( .user_can_see_state_events(sender_user, room_id) .await { - let room_state = services + let room_state: Vec<_> = services .rooms .state_accessor .room_state_full(room_id) .await? 
.values() .map(|pdu| pdu.to_state_event()) - .collect::>(); + .collect(); debug!("Room state: {:?}", room_state); diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index 53d4f3c35..adb4d8da7 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -7,13 +7,14 @@ use std::{ use axum::extract::State; use conduit::{ debug, err, error, is_equal_to, + result::IntoIsOk, utils::{ math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, - IterStream, ReadyExt, + BoolExt, IterStream, ReadyExt, TryFutureExtExt, }, warn, PduCount, }; -use futures::{pin_mut, StreamExt}; +use futures::{pin_mut, FutureExt, StreamExt, TryFutureExt}; use ruma::{ api::client::{ error::ErrorKind, @@ -172,12 +173,12 @@ pub(crate) async fn sync_events_route( process_presence_updates(&services, &mut presence_updates, since, &sender_user).await?; } - let all_joined_rooms = services + let all_joined_rooms: Vec<_> = services .rooms .state_cache .rooms_joined(&sender_user) .map(ToOwned::to_owned) - .collect::>() + .collect() .await; // Coalesce database writes for the remainder of this scope. 
@@ -869,15 +870,13 @@ async fn load_joined_room( .rooms .state_cache .room_members(room_id) - .ready_filter(|user_id| { - // Don't send key updates from the sender to the sender - sender_user != *user_id - }) - .filter_map(|user_id| async move { - // Only send keys if the sender doesn't share an encrypted room with the target - // already - (!share_encrypted_room(services, sender_user, user_id, Some(room_id)).await) - .then_some(user_id.to_owned()) + // Don't send key updates from the sender to the sender + .ready_filter(|user_id| sender_user != *user_id) + // Only send keys if the sender doesn't share an encrypted room with the target + // already + .filter_map(|user_id| { + share_encrypted_room(services, sender_user, user_id, Some(room_id)) + .map(|res| res.or_some(user_id.to_owned())) }) .collect::>() .await, @@ -1117,13 +1116,12 @@ async fn share_encrypted_room( .user .get_shared_rooms(sender_user, user_id) .ready_filter(|&room_id| Some(room_id) != ignore_room) - .any(|other_room_id| async move { + .any(|other_room_id| { services .rooms .state_accessor .room_state_get(other_room_id, &StateEventType::RoomEncryption, "") - .await - .is_ok() + .map(Result::into_is_ok) }) .await } @@ -1178,20 +1176,20 @@ pub(crate) async fn sync_events_v4_route( .sync .update_sync_request_with_cache(sender_user.clone(), sender_device.clone(), &mut body); - let all_joined_rooms = services + let all_joined_rooms: Vec<_> = services .rooms .state_cache .rooms_joined(&sender_user) .map(ToOwned::to_owned) - .collect::>() + .collect() .await; - let all_invited_rooms = services + let all_invited_rooms: Vec<_> = services .rooms .state_cache .rooms_invited(&sender_user) .map(|r| r.0) - .collect::>() + .collect() .await; let all_rooms = all_joined_rooms @@ -1364,15 +1362,13 @@ pub(crate) async fn sync_events_v4_route( .rooms .state_cache .room_members(room_id) - .ready_filter(|user_id| { - // Don't send key updates from the sender to the sender - sender_user != user_id - }) - 
.filter_map(|user_id| async move { - // Only send keys if the sender doesn't share an encrypted room with the target - // already - (!share_encrypted_room(&services, sender_user, user_id, Some(room_id)).await) - .then_some(user_id.to_owned()) + // Don't send key updates from the sender to the sender + .ready_filter(|user_id| sender_user != user_id) + // Only send keys if the sender doesn't share an encrypted room with the target + // already + .filter_map(|user_id| { + share_encrypted_room(&services, sender_user, user_id, Some(room_id)) + .map(|res| res.or_some(user_id.to_owned())) }) .collect::>() .await, @@ -1650,26 +1646,25 @@ pub(crate) async fn sync_events_v4_route( .await; // Heroes - let heroes = services + let heroes: Vec<_> = services .rooms .state_cache .room_members(room_id) .ready_filter(|member| member != &sender_user) - .filter_map(|member| async move { + .filter_map(|user_id| { services .rooms .state_accessor - .get_member(room_id, member) - .await - .map(|memberevent| SlidingSyncRoomHero { - user_id: member.to_owned(), + .get_member(room_id, user_id) + .map_ok(|memberevent| SlidingSyncRoomHero { + user_id: user_id.into(), name: memberevent.displayname, avatar: memberevent.avatar_url, }) .ok() }) .take(5) - .collect::>() + .collect() .await; let name = match heroes.len().cmp(&(1_usize)) { diff --git a/src/api/client/user_directory.rs b/src/api/client/user_directory.rs index 8ea7f1b82..868811a3f 100644 --- a/src/api/client/user_directory.rs +++ b/src/api/client/user_directory.rs @@ -1,4 +1,5 @@ use axum::extract::State; +use conduit::utils::TryFutureExtExt; use futures::{pin_mut, StreamExt}; use ruma::{ api::client::user_directory::search_users, @@ -56,16 +57,12 @@ pub(crate) async fn search_users_route( .rooms .state_cache .rooms_joined(&user.user_id) - .any(|room| async move { + .any(|room| { services .rooms .state_accessor - .room_state_get(room, &StateEventType::RoomJoinRules, "") - .await - .map_or(false, |event| { - 
serde_json::from_str(event.content.get()) - .map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public) - }) + .room_state_get_content::(room, &StateEventType::RoomJoinRules, "") + .map_ok_or(false, |content| content.join_rule == JoinRule::Public) }) .await; diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index cb957bc90..4fe413e93 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -83,6 +83,7 @@ ruma.workspace = true sanitize-filename.workspace = true serde_json.workspace = true serde_regex.workspace = true +serde_yaml.workspace = true serde.workspace = true thiserror.workspace = true tikv-jemallocator.optional = true diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs index 79e3d5b40..ad7f9f3ca 100644 --- a/src/core/error/mod.rs +++ b/src/core/error/mod.rs @@ -75,6 +75,8 @@ pub enum Error { TracingFilter(#[from] tracing_subscriber::filter::ParseError), #[error("Tracing reload error: {0}")] TracingReload(#[from] tracing_subscriber::reload::Error), + #[error(transparent)] + Yaml(#[from] serde_yaml::Error), // ruma/conduwuit #[error("Arithmetic operation failed: {0}")] diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index 4eb9d09e5..8fb7d9582 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use conduit::{err, utils::stream::TryIgnore, Result}; -use database::{Database, Deserialized, Map}; +use database::{Database, Map}; use futures::Stream; use ruma::api::appservice::Registration; @@ -40,7 +40,7 @@ impl Data { self.id_appserviceregistrations .get(id) .await - .deserialized() + .and_then(|ref bytes| serde_yaml::from_slice(bytes).map_err(Into::into)) .map_err(|e| err!(Database("Invalid appservice {id:?} registration: {e:?}"))) } diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs index 469159fc7..fc6e477b3 100644 --- a/src/service/globals/migrations.rs +++ b/src/service/globals/migrations.rs @@ -9,7 
+9,7 @@ use itertools::Itertools; use ruma::{ events::{push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType}, push::Ruleset, - UserId, + OwnedUserId, UserId, }; use crate::{media, Services}; @@ -385,11 +385,12 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) for room_id in &room_ids { debug_info!("Fixing room {room_id}"); - let users_in_room = services + let users_in_room: Vec = services .rooms .state_cache .room_members(room_id) - .collect::>() + .map(ToOwned::to_owned) + .collect() .await; let joined_members = users_in_room @@ -418,12 +419,12 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) .collect::>() .await; - for user_id in joined_members { + for user_id in &joined_members { debug_info!("User is joined, marking as joined"); services.rooms.state_cache.mark_as_joined(user_id, room_id); } - for user_id in non_joined_members { + for user_id in &non_joined_members { debug_info!("User is left or banned, marking as left"); services.rooms.state_cache.mark_as_left(user_id, room_id); } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 253880849..dbe385619 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -648,35 +648,19 @@ impl Service { self.db.userroomid_leftstate.remove(&userroom_id); self.db.roomuserid_leftcount.remove(&roomuser_id); - if let Some(servers) = invite_via { - let mut prev_servers = self - .servers_invite_via(room_id) - .map(ToOwned::to_owned) - .collect::>() - .await; - #[allow(clippy::redundant_clone)] // this is a necessary clone? 
- prev_servers.append(servers.clone().as_mut()); - let servers = prev_servers.iter().rev().unique().rev().collect_vec(); - - let servers = servers - .iter() - .map(|server| server.as_bytes()) - .collect_vec() - .join(&[0xFF][..]); - - self.db - .roomid_inviteviaservers - .insert(room_id.as_bytes(), &servers); + if let Some(servers) = invite_via.as_deref() { + self.add_servers_invite_via(room_id, servers).await; } } - #[tracing::instrument(skip(self), level = "debug")] + #[tracing::instrument(skip(self, servers), level = "debug")] pub async fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) { - let mut prev_servers = self + let mut prev_servers: Vec<_> = self .servers_invite_via(room_id) .map(ToOwned::to_owned) - .collect::>() + .collect() .await; + prev_servers.extend(servers.to_owned()); prev_servers.sort_unstable(); prev_servers.dedup(); diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 5360d2c96..6a26a1d53 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -408,7 +408,6 @@ impl Service { .get(None, user, GlobalAccountDataEventType::PushRules.to_string().into()) .await .and_then(|event| serde_json::from_str::(event.get()).map_err(Into::into)) - .map_err(|e| err!(Database(warn!(?user, ?e, "Invalid push rules event in db for user")))) .map_or_else(|_| Ruleset::server_default(user), |ev: PushRulesEvent| ev.content.global); let mut highlight = false; diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index eb77ef357..438c220bc 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -623,7 +623,9 @@ impl Service { pub async fn mark_device_key_update(&self, user_id: &UserId) { let count = self.services.globals.next_count().unwrap().to_be_bytes(); + let rooms_joined = self.services.state_cache.rooms_joined(user_id); + pin_mut!(rooms_joined); while let Some(room_id) = rooms_joined.next().await { // Don't send key updates to unencrypted 
rooms From ab06701ed08862bef04bf06800dbf021bd317497 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 1 Oct 2024 22:37:01 +0000 Subject: [PATCH 023/245] refactor multi-get to handle result type Signed-off-by: Jason Volk --- src/database/keyval.rs | 8 ------- src/database/map/get.rs | 43 ++++++++++++++++------------------ src/service/rooms/short/mod.rs | 39 ++++++++++-------------------- 3 files changed, 32 insertions(+), 58 deletions(-) diff --git a/src/database/keyval.rs b/src/database/keyval.rs index c9d25977d..a288f1842 100644 --- a/src/database/keyval.rs +++ b/src/database/keyval.rs @@ -3,10 +3,6 @@ use serde::Deserialize; use crate::de; -pub(crate) type OwnedKeyVal = (Vec, Vec); -pub(crate) type OwnedKey = Vec; -pub(crate) type OwnedVal = Vec; - pub type KeyVal<'a, K = &'a Slice, V = &'a Slice> = (Key<'a, K>, Val<'a, V>); pub type Key<'a, T = &'a Slice> = T; pub type Val<'a, T = &'a Slice> = T; @@ -72,10 +68,6 @@ where de::from_slice::(val) } -#[inline] -#[must_use] -pub fn to_owned(kv: KeyVal<'_>) -> OwnedKeyVal { (kv.0.to_owned(), kv.1.to_owned()) } - #[inline] pub fn key(kv: KeyVal<'_, K, V>) -> Key<'_, K> { kv.0 } diff --git a/src/database/map/get.rs b/src/database/map/get.rs index 71489402c..72382e367 100644 --- a/src/database/map/get.rs +++ b/src/database/map/get.rs @@ -3,14 +3,12 @@ use std::{convert::AsRef, fmt::Debug, future::Future, io::Write}; use arrayvec::ArrayVec; use conduit::{err, implement, Result}; use futures::future::ready; +use rocksdb::DBPinnableSlice; use serde::Serialize; -use crate::{ - keyval::{OwnedKey, OwnedVal}, - ser, - util::{map_err, or_else}, - Handle, -}; +use crate::{ser, util, Handle}; + +type RocksdbResult<'a> = Result>, rocksdb::Error>; /// Fetch a value from the database into cache, returning a reference-handle /// asynchronously. 
The key is serialized into an allocated buffer to perform @@ -68,17 +66,17 @@ pub fn get_blocking(&self, key: &K) -> Result> where K: AsRef<[u8]> + ?Sized + Debug, { - self.db + let res = self .db - .get_pinned_cf_opt(&self.cf(), key, &self.read_options) - .map_err(map_err)? - .map(Handle::from) - .ok_or(err!(Request(NotFound("Not found in database")))) + .db + .get_pinned_cf_opt(&self.cf(), key, &self.read_options); + + into_result_handle(res) } #[implement(super::Map)] #[tracing::instrument(skip(self, keys), fields(%self), level = "trace")] -pub fn get_batch_blocking<'a, I, K>(&self, keys: I) -> Vec> +pub fn get_batch_blocking<'a, I, K>(&self, keys: I) -> Vec>> where I: Iterator + ExactSizeIterator + Send + Debug, K: AsRef<[u8]> + Sized + Debug + 'a, @@ -87,19 +85,18 @@ where // comparator**. const SORTED: bool = false; - let mut ret: Vec> = Vec::with_capacity(keys.len()); let read_options = &self.read_options; - for res in self - .db + self.db .db .batched_multi_get_cf_opt(&self.cf(), keys, SORTED, read_options) - { - match res { - Ok(Some(res)) => ret.push(Some((*res).to_vec())), - Ok(None) => ret.push(None), - Err(e) => or_else(e).expect("database multiget error"), - } - } + .into_iter() + .map(into_result_handle) + .collect() +} - ret +fn into_result_handle(result: RocksdbResult<'_>) -> Result> { + result + .map_err(util::map_err)? 
+ .map(Handle::from) + .ok_or(err!(Request(NotFound("Not found in database")))) } diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 66da39485..825ee109b 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use conduit::{err, implement, utils, Error, Result}; +use conduit::{err, implement, utils, Result}; use database::{Deserialized, Map}; use ruma::{events::StateEventType, EventId, RoomId}; @@ -69,41 +69,26 @@ pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 { #[implement(Service)] pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec { - let mut ret: Vec = Vec::with_capacity(event_ids.len()); - let keys = event_ids - .iter() - .map(|id| id.as_bytes()) - .collect::>(); - - for (i, short) in self - .db + self.db .eventid_shorteventid - .get_batch_blocking(keys.iter()) - .iter() + .get_batch_blocking(event_ids.iter()) + .into_iter() .enumerate() - { - match short { - Some(short) => ret.push( - utils::u64_from_bytes(short) - .map_err(|_| Error::bad_database("Invalid shorteventid in db.")) - .unwrap(), - ), - None => { + .map(|(i, result)| match result { + Ok(ref short) => utils::u64_from_u8(short), + Err(_) => { let short = self.services.globals.next_count().unwrap(); self.db .eventid_shorteventid - .insert(keys[i], &short.to_be_bytes()); + .insert(event_ids[i], &short.to_be_bytes()); self.db .shorteventid_eventid - .insert(&short.to_be_bytes(), keys[i]); + .insert(&short.to_be_bytes(), event_ids[i]); - debug_assert!(ret.len() == i, "position of result must match input"); - ret.push(short); + short }, - } - } - - ret + }) + .collect() } #[implement(Service)] From 36677bb9828038294d06f2292eef755139216c40 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 1 Oct 2024 23:19:47 +0000 Subject: [PATCH 024/245] optimize auth_chain short_id to event_id translation step Signed-off-by: Jason Volk --- 
src/service/rooms/auth_chain/mod.rs | 30 ++++++++++++++++++-------- src/service/rooms/event_handler/mod.rs | 16 +++++++------- src/service/rooms/short/mod.rs | 17 +++++++++++++++ 3 files changed, 46 insertions(+), 17 deletions(-) diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index eae13b74a..f3861ca3f 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -6,7 +6,7 @@ use std::{ }; use conduit::{debug, debug_error, trace, utils::IterStream, validated, warn, Err, Result}; -use futures::{FutureExt, Stream, StreamExt}; +use futures::Stream; use ruma::{EventId, RoomId}; use self::data::Data; @@ -40,15 +40,27 @@ impl Service { pub async fn event_ids_iter( &self, room_id: &RoomId, starting_events: &[&EventId], ) -> Result> + Send + '_> { + let stream = self + .get_event_ids(room_id, starting_events) + .await? + .into_iter() + .stream(); + + Ok(stream) + } + + pub async fn get_event_ids(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result>> { let chain = self.get_auth_chain(room_id, starting_events).await?; - let iter = chain.into_iter().stream().filter_map(|sid| { - self.services - .short - .get_eventid_from_short(sid) - .map(Result::ok) - }); - - Ok(iter) + let event_ids = self + .services + .short + .multi_get_eventid_from_short(&chain) + .await + .into_iter() + .filter_map(Result::ok) + .collect(); + + Ok(event_ids) } #[tracing::instrument(skip_all, name = "auth_chain")] diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 57b877064..4708a86cb 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -797,13 +797,13 @@ impl Service { for state in &fork_states { let starting_events: Vec<&EventId> = state.values().map(Borrow::borrow).collect(); - let auth_chain = self + let auth_chain: HashSet> = self .services .auth_chain - .event_ids_iter(room_id, &starting_events) + .get_event_ids(room_id, 
&starting_events) .await? - .collect::>>() - .await; + .into_iter() + .collect(); auth_chain_sets.push(auth_chain); } @@ -983,13 +983,13 @@ impl Service { starting_events.push(id.borrow()); } - let auth_chain = self + let auth_chain: HashSet> = self .services .auth_chain - .event_ids_iter(room_id, &starting_events) + .get_event_ids(room_id, &starting_events) .await? - .collect() - .await; + .into_iter() + .collect(); auth_chain_sets.push(auth_chain); fork_states.push(state); diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 825ee109b..20082da23 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -141,6 +141,23 @@ pub async fn get_eventid_from_short(&self, shorteventid: u64) -> Result Vec>> { + const BUFSIZE: usize = size_of::(); + + let keys: Vec<[u8; BUFSIZE]> = shorteventid + .iter() + .map(|short| short.to_be_bytes()) + .collect(); + + self.db + .shorteventid_eventid + .get_batch_blocking(keys.iter()) + .into_iter() + .map(Deserialized::deserialized) + .collect() +} + #[implement(Service)] pub async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { const BUFSIZE: usize = size_of::(); From 83119526291f25a78d67a15f638eeaedc0b10b2d Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 28 Sep 2024 18:30:40 -0400 Subject: [PATCH 025/245] bump ruma, cargo.lock, and deps Signed-off-by: strawberry --- Cargo.lock | 26 +++++++++++++------------- Cargo.toml | 2 +- src/api/client/capabilities.rs | 8 +++++++- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 065aa1e4a..4d40c4589 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2975,7 +2975,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" +source = 
"git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "assign", "js_int", @@ -2997,7 +2997,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "js_int", "ruma-common", @@ -3009,7 +3009,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "as_variant", "assign", @@ -3032,7 +3032,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "as_variant", "base64 0.22.1", @@ -3062,7 +3062,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "as_variant", "indexmap 2.6.0", @@ -3086,7 +3086,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = 
"git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "bytes", "http", @@ -3104,7 +3104,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "js_int", "thiserror", @@ -3113,7 +3113,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "js_int", "ruma-common", @@ -3123,7 +3123,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "cfg-if", "once_cell", @@ -3139,7 +3139,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "js_int", "ruma-common", @@ -3151,7 
+3151,7 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.3.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "headers", "http", @@ -3164,7 +3164,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -3180,7 +3180,7 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.11.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "futures-util", "itertools 0.12.1", diff --git a/Cargo.toml b/Cargo.toml index 3bfb3bc81..28e280cfd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -315,7 +315,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = "https://github.com/girlbossceo/ruwuma" #branch = "conduwuit-changes" -rev = "e7db44989d68406393270d3a91815597385d3acb" +rev = "ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" features = [ "compat", "rand", diff --git a/src/api/client/capabilities.rs b/src/api/client/capabilities.rs index 83e1dc7e6..89157e471 100644 --- a/src/api/client/capabilities.rs +++ b/src/api/client/capabilities.rs @@ -3,7 +3,8 @@ use std::collections::BTreeMap; use axum::extract::State; use ruma::{ api::client::discovery::get_capabilities::{ - self, Capabilities, RoomVersionStability, 
RoomVersionsCapability, ThirdPartyIdChangesCapability, + self, Capabilities, GetLoginTokenCapability, RoomVersionStability, RoomVersionsCapability, + ThirdPartyIdChangesCapability, }, RoomVersionId, }; @@ -43,6 +44,11 @@ pub(crate) async fn get_capabilities_route( enabled: false, }; + // we dont support generating tokens yet + capabilities.get_login_token = GetLoginTokenCapability { + enabled: false, + }; + // MSC4133 capability capabilities .set("uk.tcpip.msc4133.profile_fields", json!({"enabled": true})) From fafe32089980eefc5bda1cb8991a0be762c30e6b Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 28 Sep 2024 19:38:35 -0400 Subject: [PATCH 026/245] send EDUs to appservices if in events to_device is not supported yet Signed-off-by: strawberry --- src/service/sending/sender.rs | 39 ++++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 4db9922ae..3a401995b 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -13,9 +13,14 @@ use conduit::{ }; use futures::{future::BoxFuture, pin_mut, stream::FuturesUnordered, FutureExt, StreamExt}; use ruma::{ - api::federation::transactions::{ - edu::{DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap}, - send_transaction_message, + api::{ + appservice::event::push_events::v1::Edu as RumaEdu, + federation::transactions::{ + edu::{ + DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap, + }, + send_transaction_message, + }, }, device_id, events::{push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent, GlobalAccountDataEventType}, @@ -441,7 +446,18 @@ impl Service { return Err((dest.clone(), err!(Database(warn!(?id, "Missing appservice registration"))))); }; - let mut pdu_jsons = Vec::new(); + let mut pdu_jsons = Vec::with_capacity( + events + .iter() + .filter(|event| 
matches!(event, SendingEvent::Pdu(_))) + .count(), + ); + let mut edu_jsons: Vec = Vec::with_capacity( + events + .iter() + .filter(|event| matches!(event, SendingEvent::Edu(_))) + .count(), + ); for event in &events { match event { SendingEvent::Pdu(pdu_id) => { @@ -449,10 +465,12 @@ impl Service { pdu_jsons.push(pdu.to_room_event()); } }, - SendingEvent::Edu(_) | SendingEvent::Flush => { - // Appservices don't need EDUs (?) and flush only; - // no new content + SendingEvent::Edu(edu) => { + if let Ok(edu) = serde_json::from_slice(edu) { + edu_jsons.push(edu); + } }, + SendingEvent::Flush => {}, // flush only; no new content } } @@ -466,7 +484,8 @@ impl Service { .collect::>(), )); - //debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction"); + //debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty + // transaction"); let client = &self.services.client.appservice; match appservice::send_request( client, @@ -474,8 +493,8 @@ impl Service { ruma::api::appservice::event::push_events::v1::Request { events: pdu_jsons, txn_id: txn_id.into(), - ephemeral: Vec::new(), - to_device: Vec::new(), + ephemeral: edu_jsons, + to_device: Vec::new(), // TODO }, ) .await From 890ee84f713c3f4905247934d5bfc277f61959cd Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 28 Sep 2024 21:44:38 -0400 Subject: [PATCH 027/245] dont send read receipts and typing indicators from ignored users Signed-off-by: strawberry --- src/api/client/sync.rs | 25 ++++++++++++++-- src/service/rooms/state_cache/mod.rs | 26 +--------------- src/service/rooms/typing/mod.rs | 45 +++++++++++++++++++++------- src/service/users/mod.rs | 32 ++++++++++++++++++-- 4 files changed, 88 insertions(+), 40 deletions(-) diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index adb4d8da7..1383f9022 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -1011,15 +1011,27 @@ async fn load_joined_room( .rooms .read_receipt .readreceipts_since(room_id, since) - .map(|(_, _, v)| v) + 
.filter_map(|(read_user, _, v)| async move { + (!services + .users + .user_is_ignored(&read_user, sender_user) + .await) + .then_some(v) + }) .collect() .await; if services.rooms.typing.last_typing_update(room_id).await? > since { edus.push( serde_json::from_str( - &serde_json::to_string(&services.rooms.typing.typings_all(room_id).await?) - .expect("event is valid, we just created it"), + &serde_json::to_string( + &services + .rooms + .typing + .typings_all(room_id, sender_user) + .await?, + ) + .expect("event is valid, we just created it"), ) .expect("event is valid, we just created it"), ); @@ -1583,6 +1595,13 @@ pub(crate) async fn sync_events_v4_route( .rooms .read_receipt .readreceipts_since(room_id, *roomsince) + .filter_map(|(read_user, ts, v)| async move { + (!services + .users + .user_is_ignored(&read_user, sender_user) + .await) + .then_some((read_user, ts, v)) + }) .collect() .await; diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index dbe385619..b1a71cafe 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -14,7 +14,6 @@ use itertools::Itertools; use ruma::{ events::{ direct::DirectEvent, - ignored_user_list::IgnoredUserListEvent, room::{ create::RoomCreateEventContent, member::{MembershipState, RoomMemberEventContent}, @@ -197,30 +196,7 @@ impl Service { }, MembershipState::Invite => { // We want to know if the sender is ignored by the receiver - let is_ignored = self - .services - .account_data - .get( - None, // Ignored users are in global account data - user_id, // Receiver - GlobalAccountDataEventType::IgnoredUserList - .to_string() - .into(), - ) - .await - .and_then(|event| { - serde_json::from_str::(event.get()) - .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) - }) - .map_or(false, |ignored| { - ignored - .content - .ignored_users - .iter() - .any(|(user, _details)| user == sender) - }); - - if is_ignored { + if 
self.services.users.user_is_ignored(sender, user_id).await { return Ok(()); } diff --git a/src/service/rooms/typing/mod.rs b/src/service/rooms/typing/mod.rs index bcfce6168..8ee34f44d 100644 --- a/src/service/rooms/typing/mod.rs +++ b/src/service/rooms/typing/mod.rs @@ -1,6 +1,11 @@ use std::{collections::BTreeMap, sync::Arc}; -use conduit::{debug_info, trace, utils, Result, Server}; +use conduit::{ + debug_info, trace, + utils::{self, IterStream}, + Result, Server, +}; +use futures::StreamExt; use ruma::{ api::federation::transactions::edu::{Edu, TypingContent}, events::SyncEphemeralRoomEvent, @@ -8,7 +13,7 @@ use ruma::{ }; use tokio::sync::{broadcast, RwLock}; -use crate::{globals, sending, Dep}; +use crate::{globals, sending, users, Dep}; pub struct Service { server: Arc, @@ -23,6 +28,7 @@ pub struct Service { struct Services { globals: Dep, sending: Dep, + users: Dep, } impl crate::Service for Service { @@ -32,6 +38,7 @@ impl crate::Service for Service { services: Services { globals: args.depend::("globals"), sending: args.depend::("sending"), + users: args.depend::("users"), }, typing: RwLock::new(BTreeMap::new()), last_typing_update: RwLock::new(BTreeMap::new()), @@ -170,17 +177,35 @@ impl Service { /// Returns a new typing EDU. 
pub async fn typings_all( - &self, room_id: &RoomId, + &self, room_id: &RoomId, sender_user: &UserId, ) -> Result> { + let room_typing_indicators = self.typing.read().await.get(room_id).cloned(); + + let Some(typing_indicators) = room_typing_indicators else { + return Ok(SyncEphemeralRoomEvent { + content: ruma::events::typing::TypingEventContent { + user_ids: Vec::new(), + }, + }); + }; + + let user_ids: Vec<_> = typing_indicators + .into_keys() + .stream() + .filter_map(|typing_user_id| async move { + (!self + .services + .users + .user_is_ignored(&typing_user_id, sender_user) + .await) + .then_some(typing_user_id) + }) + .collect() + .await; + Ok(SyncEphemeralRoomEvent { content: ruma::events::typing::TypingEventContent { - user_ids: self - .typing - .read() - .await - .get(room_id) - .map(|m| m.keys().cloned().collect()) - .unwrap_or_default(), + user_ids, }, }) } diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 438c220bc..1c079085e 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -10,13 +10,13 @@ use futures::{pin_mut, FutureExt, Stream, StreamExt, TryFutureExt}; use ruma::{ api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::{AnyToDeviceEvent, StateEventType}, + events::{ignored_user_list::IgnoredUserListEvent, AnyToDeviceEvent, GlobalAccountDataEventType, StateEventType}, serde::Raw, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, OwnedUserId, UInt, UserId, }; -use crate::{admin, globals, rooms, Dep}; +use crate::{account_data, admin, globals, rooms, Dep}; pub struct Service { services: Services, @@ -25,6 +25,7 @@ pub struct Service { struct Services { server: Arc, + account_data: Dep, admin: Dep, globals: Dep, state_accessor: Dep, @@ -58,6 +59,7 @@ impl crate::Service for Service { Ok(Arc::new(Self { services: Services { server: args.server.clone(), + 
account_data: args.depend::("account_data"), admin: args.depend::("admin"), globals: args.depend::("globals"), state_accessor: args.depend::("rooms::state_accessor"), @@ -91,6 +93,32 @@ impl crate::Service for Service { } impl Service { + /// Returns true/false based on whether the recipient/receiving user has + /// blocked the sender + pub async fn user_is_ignored(&self, sender_user: &UserId, recipient_user: &UserId) -> bool { + self.services + .account_data + .get( + None, + recipient_user, + GlobalAccountDataEventType::IgnoredUserList + .to_string() + .into(), + ) + .await + .and_then(|event| { + serde_json::from_str::(event.get()) + .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) + }) + .map_or(false, |ignored| { + ignored + .content + .ignored_users + .keys() + .any(|blocked_user| blocked_user == sender_user) + }) + } + /// Check if a user is an admin #[inline] pub async fn is_admin(&self, user_id: &UserId) -> bool { self.services.admin.user_is_admin(user_id).await } From 2083c38c764d5d144ff6355ce0688e8fe98d7d49 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 28 Sep 2024 22:12:17 -0400 Subject: [PATCH 028/245] dont send non-state events from ignored users over sync Signed-off-by: strawberry --- src/api/client/sync.rs | 91 +++++++++++++++++++++++++++++++----------- 1 file changed, 68 insertions(+), 23 deletions(-) diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index 1383f9022..51df88a30 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -35,6 +35,7 @@ use ruma::{ presence::PresenceEvent, room::member::{MembershipState, RoomMemberEventContent}, AnyRawAccountDataEvent, StateEventType, TimelineEventType, + TimelineEventType::*, }, serde::Raw, state_res::Event, @@ -1004,8 +1005,31 @@ async fn load_joined_room( let room_events: Vec<_> = timeline_pdus .iter() - .map(|(_, pdu)| pdu.to_sync_room_event()) - .collect(); + .stream() + .filter_map(|(_, pdu)| async move { + // list of safe and common 
non-state events to ignore + if matches!( + &pdu.kind, + RoomMessage + | Sticker | CallInvite + | CallNotify | RoomEncrypted + | Image | File | Audio + | Voice | Video | UnstablePollStart + | PollStart | KeyVerificationStart + | Reaction | Emote + | Location + ) && services + .users + .user_is_ignored(&pdu.sender, sender_user) + .await + { + return None; + } + + Some(pdu.to_sync_room_event()) + }) + .collect() + .await; let mut edus: Vec<_> = services .rooms @@ -1144,11 +1168,11 @@ async fn share_encrypted_room( pub(crate) async fn sync_events_v4_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.expect("user is authenticated"); let mut body = body.body; // Setup watchers, so if there's no response, we can wait for them - let watcher = services.globals.watch(&sender_user, &sender_device); + let watcher = services.globals.watch(sender_user, &sender_device); let next_batch = services.globals.next_count()?; @@ -1191,7 +1215,7 @@ pub(crate) async fn sync_events_v4_route( let all_joined_rooms: Vec<_> = services .rooms .state_cache - .rooms_joined(&sender_user) + .rooms_joined(sender_user) .map(ToOwned::to_owned) .collect() .await; @@ -1199,7 +1223,7 @@ pub(crate) async fn sync_events_v4_route( let all_invited_rooms: Vec<_> = services .rooms .state_cache - .rooms_invited(&sender_user) + .rooms_invited(sender_user) .map(|r| r.0) .collect() .await; @@ -1213,7 +1237,7 @@ pub(crate) async fn sync_events_v4_route( if body.extensions.to_device.enabled.unwrap_or(false) { services .users - .remove_to_device_events(&sender_user, &sender_device, globalsince) + .remove_to_device_events(sender_user, &sender_device, globalsince) .await; } @@ -1232,7 +1256,7 @@ pub(crate) async fn sync_events_v4_route( if body.extensions.account_data.enabled.unwrap_or(false) { account_data.global = services 
.account_data - .changes_since(None, &sender_user, globalsince) + .changes_since(None, sender_user, globalsince) .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Global)) @@ -1244,7 +1268,7 @@ pub(crate) async fn sync_events_v4_route( room.clone(), services .account_data - .changes_since(Some(&room), &sender_user, globalsince) + .changes_since(Some(&room), sender_user, globalsince) .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) @@ -1338,7 +1362,7 @@ pub(crate) async fn sync_events_v4_route( let user_id = UserId::parse(state_key.clone()) .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; - if user_id == sender_user { + if user_id == *sender_user { continue; } @@ -1350,7 +1374,7 @@ pub(crate) async fn sync_events_v4_route( match new_membership { MembershipState::Join => { // A new user joined an encrypted room - if !share_encrypted_room(&services, &sender_user, &user_id, Some(room_id)) + if !share_encrypted_room(&services, sender_user, &user_id, Some(room_id)) .await { device_list_changes.insert(user_id); @@ -1367,7 +1391,6 @@ pub(crate) async fn sync_events_v4_route( } } if joined_since_last_sync || new_encrypted_room { - let sender_user = &sender_user; // If the user is in a new encrypted room, give them all joined users device_list_changes.extend( services @@ -1400,7 +1423,7 @@ pub(crate) async fn sync_events_v4_route( } for user_id in left_encrypted_users { - let dont_share_encrypted_room = !share_encrypted_room(&services, &sender_user, &user_id, None).await; + let dont_share_encrypted_room = !share_encrypted_room(&services, sender_user, &user_id, None).await; // If the user doesn't share an encrypted room with the target anymore, we need // to tell them @@ -1564,14 +1587,14 @@ pub(crate) async fn sync_events_v4_route( invite_state = services .rooms .state_cache - .invite_state(&sender_user, room_id) + .invite_state(sender_user, room_id) .await .ok(); (timeline_pdus, 
limited) = (Vec::new(), true); } else { (timeline_pdus, limited) = - match load_timeline(&services, &sender_user, room_id, roomsincecount, *timeline_limit).await { + match load_timeline(&services, sender_user, room_id, roomsincecount, *timeline_limit).await { Ok(value) => value, Err(err) => { warn!("Encountered missing timeline in {}, error {}", room_id, err); @@ -1584,7 +1607,7 @@ pub(crate) async fn sync_events_v4_route( room_id.clone(), services .account_data - .changes_since(Some(room_id), &sender_user, *roomsince) + .changes_since(Some(room_id), sender_user, *roomsince) .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) @@ -1639,8 +1662,30 @@ pub(crate) async fn sync_events_v4_route( let room_events: Vec<_> = timeline_pdus .iter() - .map(|(_, pdu)| pdu.to_sync_room_event()) - .collect(); + .stream() + .filter_map(|(_, pdu)| async move { + // list of safe and common non-state events to ignore + if matches!( + &pdu.kind, + RoomMessage + | Sticker | CallInvite + | CallNotify | RoomEncrypted + | Image | File | Audio + | Voice | Video | UnstablePollStart + | PollStart | KeyVerificationStart + | Reaction | Emote | Location + ) && services + .users + .user_is_ignored(&pdu.sender, sender_user) + .await + { + return None; + } + + Some(pdu.to_sync_room_event()) + }) + .collect() + .await; for (_, pdu) in timeline_pdus { let ts = MilliSecondsSinceUnixEpoch(pdu.origin_server_ts); @@ -1669,7 +1714,7 @@ pub(crate) async fn sync_events_v4_route( .rooms .state_cache .room_members(room_id) - .ready_filter(|member| member != &sender_user) + .ready_filter(|member| member != sender_user) .filter_map(|user_id| { services .rooms @@ -1743,7 +1788,7 @@ pub(crate) async fn sync_events_v4_route( services .rooms .user - .highlight_count(&sender_user, room_id) + .highlight_count(sender_user, room_id) .await .try_into() .expect("notification count can't go that high"), @@ -1752,7 +1797,7 @@ pub(crate) async fn sync_events_v4_route( services .rooms 
.user - .notification_count(&sender_user, room_id) + .notification_count(sender_user, room_id) .await .try_into() .expect("notification count can't go that high"), @@ -1811,7 +1856,7 @@ pub(crate) async fn sync_events_v4_route( Some(sync_events::v4::ToDevice { events: services .users - .get_to_device_events(&sender_user, &sender_device) + .get_to_device_events(sender_user, &sender_device) .collect() .await, next_batch: next_batch.to_string(), @@ -1826,7 +1871,7 @@ pub(crate) async fn sync_events_v4_route( }, device_one_time_keys_count: services .users - .count_one_time_keys(&sender_user, &sender_device) + .count_one_time_keys(sender_user, &sender_device) .await, // Fallback keys are not yet supported device_unused_fallback_key_types: None, From 4413793f7e302c9e5b0880ba3eb3f20f8558e6b3 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 28 Sep 2024 23:15:43 -0400 Subject: [PATCH 029/245] dont allow sending/receiving room invites with ignored users Signed-off-by: strawberry --- src/api/client/membership.rs | 8 ++++++++ src/api/client/room.rs | 20 ++++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 6e3bc8940..f89903b4f 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -364,6 +364,14 @@ pub(crate) async fn invite_user_route( user_id, } = &body.recipient { + if services.users.user_is_ignored(sender_user, user_id).await { + return Err!(Request(Forbidden("You cannot invite users you have ignored to rooms."))); + } else if services.users.user_is_ignored(user_id, sender_user).await { + // silently drop the invite to the recipient if they've been ignored by the + // sender, pretend it worked + return Ok(invite_user::v3::Response {}); + } + invite_helper(&services, sender_user, user_id, &body.room_id, body.reason.clone(), false).await?; Ok(invite_user::v3::Response {}) } else { diff --git a/src/api/client/room.rs b/src/api/client/room.rs index 
1edf85d80..0d8e12a20 100644 --- a/src/api/client/room.rs +++ b/src/api/client/room.rs @@ -267,8 +267,16 @@ pub(crate) async fn create_room_route( let mut users = BTreeMap::from_iter([(sender_user.clone(), int!(100))]); if preset == RoomPreset::TrustedPrivateChat { - for invite_ in &body.invite { - users.insert(invite_.clone(), int!(100)); + for invite in &body.invite { + if services.users.user_is_ignored(sender_user, invite).await { + return Err!(Request(Forbidden("You cannot invite users you have ignored to rooms."))); + } else if services.users.user_is_ignored(invite, sender_user).await { + // silently drop the invite to the recipient if they've been ignored by the + // sender, pretend it worked + continue; + } + + users.insert(invite.clone(), int!(100)); } } @@ -476,6 +484,14 @@ pub(crate) async fn create_room_route( // 8. Events implied by invite (and TODO: invite_3pid) drop(state_lock); for user_id in &body.invite { + if services.users.user_is_ignored(sender_user, user_id).await { + return Err!(Request(Forbidden("You cannot invite users you have ignored to rooms."))); + } else if services.users.user_is_ignored(user_id, sender_user).await { + // silently drop the invite to the recipient if they've been ignored by the + // sender, pretend it worked + continue; + } + if let Err(e) = invite_helper(&services, sender_user, user_id, &room_id, None, body.is_direct) .boxed() .await From b64a23516520fe764e5d463a28d6b341942642c3 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 29 Sep 2024 00:28:05 -0400 Subject: [PATCH 030/245] use ok_or_else for a rare error Signed-off-by: strawberry --- src/service/rooms/event_handler/parse_incoming_pdu.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/service/rooms/event_handler/parse_incoming_pdu.rs b/src/service/rooms/event_handler/parse_incoming_pdu.rs index 2de3e28ef..9081fcbca 100644 --- a/src/service/rooms/event_handler/parse_incoming_pdu.rs +++ b/src/service/rooms/event_handler/parse_incoming_pdu.rs 
@@ -14,7 +14,7 @@ impl super::Service { let room_id: OwnedRoomId = value .get("room_id") .and_then(|id| RoomId::parse(id.as_str()?).ok()) - .ok_or(err!(Request(InvalidParam("Invalid room id in pdu"))))?; + .ok_or_else(|| err!(Request(InvalidParam("Invalid room id in pdu"))))?; let Ok(room_version_id) = self.services.state.get_room_version(&room_id).await else { return Err!("Server is not in room {room_id}"); From ee1580e4800f254cee39707b8e0c3d0a9339bb23 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 29 Sep 2024 00:50:12 -0400 Subject: [PATCH 031/245] fix list_rooms admin command filters Signed-off-by: strawberry --- src/admin/room/commands.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/admin/room/commands.rs b/src/admin/room/commands.rs index 1c90a9983..35e40c8be 100644 --- a/src/admin/room/commands.rs +++ b/src/admin/room/commands.rs @@ -6,7 +6,7 @@ use crate::{admin_command, get_room_info, PAGE_SIZE}; #[admin_command] pub(super) async fn list_rooms( - &self, page: Option, _exclude_disabled: bool, _exclude_banned: bool, no_details: bool, + &self, page: Option, exclude_disabled: bool, exclude_banned: bool, no_details: bool, ) -> Result { // TODO: i know there's a way to do this with clap, but i can't seem to find it let page = page.unwrap_or(1); @@ -15,8 +15,12 @@ pub(super) async fn list_rooms( .rooms .metadata .iter_ids() - //.filter(|room_id| async { !exclude_disabled || !self.services.rooms.metadata.is_disabled(room_id).await }) - //.filter(|room_id| async { !exclude_banned || !self.services.rooms.metadata.is_banned(room_id).await }) + .filter_map(|room_id| async move { + (!exclude_disabled || !self.services.rooms.metadata.is_disabled(room_id).await).then_some(room_id) + }) + .filter_map(|room_id| async move { + (!exclude_banned || !self.services.rooms.metadata.is_banned(room_id).await).then_some(room_id) + }) .then(|room_id| get_room_info(self.services, room_id)) .collect::>() .await; From 
7a59add8f1bc8d4697580823e8651f4b72e4b9d5 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 29 Sep 2024 01:54:07 -0400 Subject: [PATCH 032/245] add support for reading a registration token from a file Signed-off-by: strawberry --- conduwuit-example.toml | 18 +++++++++++++---- docs/deploying/docker-compose.for-traefik.yml | 2 +- docs/deploying/docker-compose.with-caddy.yml | 2 +- .../deploying/docker-compose.with-traefik.yml | 6 +++--- docs/deploying/docker-compose.yml | 2 +- src/api/client/account.rs | 6 +++--- src/core/config/check.rs | 20 ++++++++++++++++++- src/core/config/mod.rs | 15 +++++++++++--- src/service/globals/mod.rs | 16 +++++++++++++++ src/service/uiaa/mod.rs | 12 +++++++---- 10 files changed, 78 insertions(+), 21 deletions(-) diff --git a/conduwuit-example.toml b/conduwuit-example.toml index b532d381f..117356165 100644 --- a/conduwuit-example.toml +++ b/conduwuit-example.toml @@ -195,11 +195,14 @@ allow_guests_auto_join_rooms = false # Enables registration. If set to false, no users can register on this # server. +# # If set to true without a token configured, users can register with no form of 2nd- # step only if you set # `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` to -# true in your config. If you would like -# registration only via token reg, please configure the `registration_token` key. +# true in your config. +# +# If you would like registration only via token reg, please configure +# `registration_token` or `registration_token_file`. allow_registration = false # Please note that an open registration homeserver with no second-step verification # is highly prone to abuse and potential defederation by homeservers, including @@ -208,7 +211,14 @@ allow_registration = false # A static registration token that new users will have to provide when creating # an account. If unset and `allow_registration` is true, registration is open # without any condition. YOU NEED TO EDIT THIS. 
-registration_token = "change this token for something specific to your server" +registration_token = "change this token/string here or set registration_token_file" + +# Path to a file on the system that gets read for the registration token +# +# conduwuit must be able to access the file, and it must not be empty +# +# no default +#registration_token_file = "/etc/conduwuit/.reg_token" # controls whether federation is allowed or not # defaults to true @@ -344,7 +354,7 @@ allow_profile_lookup_federation_requests = true # Controls the max log level for admin command log captures (logs generated from running admin commands) # # Defaults to "info" on release builds, else "debug" on debug builds -#admin_log_capture = info +#admin_log_capture = "info" # Allows admins to enter commands in rooms other than #admins by prefixing with \!admin. The reply # will be publicly visible to the room, originating from the sender. diff --git a/docs/deploying/docker-compose.for-traefik.yml b/docs/deploying/docker-compose.for-traefik.yml index 1c615673a..ae93d52fa 100644 --- a/docs/deploying/docker-compose.for-traefik.yml +++ b/docs/deploying/docker-compose.for-traefik.yml @@ -16,7 +16,7 @@ services: CONDUWUIT_DATABASE_PATH: /var/lib/conduwuit CONDUWUIT_DATABASE_BACKEND: rocksdb CONDUWUIT_PORT: 6167 # should match the loadbalancer traefik label - CONDUWUIT_MAX_REQUEST_SIZE: 20_000_000 # in bytes, ~20 MB + CONDUWUIT_MAX_REQUEST_SIZE: 20000000 # in bytes, ~20 MB CONDUWUIT_ALLOW_REGISTRATION: 'true' CONDUWUIT_ALLOW_FEDERATION: 'true' CONDUWUIT_ALLOW_CHECK_FOR_UPDATES: 'true' diff --git a/docs/deploying/docker-compose.with-caddy.yml b/docs/deploying/docker-compose.with-caddy.yml index 899f4d679..369242126 100644 --- a/docs/deploying/docker-compose.with-caddy.yml +++ b/docs/deploying/docker-compose.with-caddy.yml @@ -32,7 +32,7 @@ services: CONDUWUIT_DATABASE_PATH: /var/lib/conduwuit CONDUWUIT_DATABASE_BACKEND: rocksdb CONDUWUIT_PORT: 6167 - CONDUWUIT_MAX_REQUEST_SIZE: 20_000_000 # in bytes, 
~20 MB + CONDUWUIT_MAX_REQUEST_SIZE: 20000000 # in bytes, ~20 MB CONDUWUIT_ALLOW_REGISTRATION: 'true' CONDUWUIT_ALLOW_FEDERATION: 'true' CONDUWUIT_ALLOW_CHECK_FOR_UPDATES: 'true' diff --git a/docs/deploying/docker-compose.with-traefik.yml b/docs/deploying/docker-compose.with-traefik.yml index f05006a55..89118c742 100644 --- a/docs/deploying/docker-compose.with-traefik.yml +++ b/docs/deploying/docker-compose.with-traefik.yml @@ -15,7 +15,8 @@ services: CONDUWUIT_SERVER_NAME: your.server.name.example # EDIT THIS CONDUWUIT_TRUSTED_SERVERS: '["matrix.org"]' CONDUWUIT_ALLOW_REGISTRATION: 'false' # After setting a secure registration token, you can enable this - CONDUWUIT_REGISTRATION_TOKEN: # This is a token you can use to register on the server + CONDUWUIT_REGISTRATION_TOKEN: "" # This is a token you can use to register on the server + #CONDUWUIT_REGISTRATION_TOKEN_FILE: "" # Alternatively you can configure a path to a token file to read CONDUWUIT_ADDRESS: 0.0.0.0 CONDUWUIT_PORT: 6167 # you need to match this with the traefik load balancer label if you're want to change it CONDUWUIT_DATABASE_PATH: /var/lib/conduwuit @@ -23,7 +24,6 @@ services: ### Uncomment and change values as desired, note that conduwuit has plenty of config options, so you should check out the example example config too # Available levels are: error, warn, info, debug, trace - more info at: https://docs.rs/env_logger/*/env_logger/#enabling-logging # CONDUWUIT_LOG: info # default is: "warn,state_res=warn" - # CONDUWUIT_ALLOW_JAEGER: 'false' # CONDUWUIT_ALLOW_ENCRYPTION: 'true' # CONDUWUIT_ALLOW_FEDERATION: 'true' # CONDUWUIT_ALLOW_CHECK_FOR_UPDATES: 'true' @@ -31,7 +31,7 @@ services: # CONDUWUIT_ALLOW_OUTGOING_PRESENCE: true # CONDUWUIT_ALLOW_LOCAL_PRESENCE: true # CONDUWUIT_WORKERS: 10 - # CONDUWUIT_MAX_REQUEST_SIZE: 20_000_000 # in bytes, ~20 MB + # CONDUWUIT_MAX_REQUEST_SIZE: 20000000 # in bytes, ~20 MB # CONDUWUIT_NEW_USER_DISPLAYNAME_SUFFIX = "🏳<200d>⚧" # We need some way to serve the client and 
server .well-known json. The simplest way is via the CONDUWUIT_WELL_KNOWN diff --git a/docs/deploying/docker-compose.yml b/docs/deploying/docker-compose.yml index bc9f24777..26145c5ae 100644 --- a/docs/deploying/docker-compose.yml +++ b/docs/deploying/docker-compose.yml @@ -16,7 +16,7 @@ services: CONDUWUIT_DATABASE_PATH: /var/lib/conduwuit CONDUWUIT_DATABASE_BACKEND: rocksdb CONDUWUIT_PORT: 6167 - CONDUWUIT_MAX_REQUEST_SIZE: 20_000_000 # in bytes, ~20 MB + CONDUWUIT_MAX_REQUEST_SIZE: 20000000 # in bytes, ~20 MB CONDUWUIT_ALLOW_REGISTRATION: 'true' CONDUWUIT_ALLOW_FEDERATION: 'true' CONDUWUIT_ALLOW_CHECK_FOR_UPDATES: 'true' diff --git a/src/api/client/account.rs b/src/api/client/account.rs index 63d02f8f8..1ededa368 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -111,7 +111,7 @@ pub(crate) async fn register_route( if is_guest && (!services.globals.allow_guest_registration() - || (services.globals.allow_registration() && services.globals.config.registration_token.is_some())) + || (services.globals.allow_registration() && services.globals.registration_token.is_some())) { info!( "Guest registration disabled / registration enabled with token configured, rejecting guest registration \ @@ -183,7 +183,7 @@ pub(crate) async fn register_route( // UIAA let mut uiaainfo; - let skip_auth = if services.globals.config.registration_token.is_some() { + let skip_auth = if services.globals.registration_token.is_some() { // Registration token required uiaainfo = UiaaInfo { flows: vec![AuthFlow { @@ -685,7 +685,7 @@ pub(crate) async fn request_3pid_management_token_via_msisdn_route( pub(crate) async fn check_registration_token_validity( State(services): State, body: Ruma, ) -> Result { - let Some(reg_token) = services.globals.config.registration_token.clone() else { + let Some(reg_token) = services.globals.registration_token.clone() else { return Err(Error::BadRequest( ErrorKind::forbidden(), "Server does not allow token registration.", diff --git 
a/src/core/config/check.rs b/src/core/config/check.rs index 8dea55d83..c0d055337 100644 --- a/src/core/config/check.rs +++ b/src/core/config/check.rs @@ -94,6 +94,22 @@ pub fn check(config: &Config) -> Result<()> { )); } + // check if we can read the token file path, and check if the file is empty + if config.registration_token_file.as_ref().is_some_and(|path| { + let Ok(token) = std::fs::read_to_string(path).inspect_err(|e| { + error!("Failed to read the registration token file: {e}"); + }) else { + return true; + }; + + token == String::new() + }) { + return Err!(Config( + "registration_token_file", + "Registration token file was specified but is empty or failed to be read" + )); + } + if config.max_request_size < 5_120_000 { return Err!(Config( "max_request_size", @@ -111,12 +127,13 @@ pub fn check(config: &Config) -> Result<()> { if config.allow_registration && !config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse && config.registration_token.is_none() + && config.registration_token_file.is_none() { return Err!(Config( "registration_token", "!! You have `allow_registration` enabled without a token configured in your config which means you are \ allowing ANYONE to register on your conduwuit instance without any 2nd-step (e.g. registration token).\n -If this is not the intended behaviour, please set a registration token with the `registration_token` config option.\n +If this is not the intended behaviour, please set a registration token.\n For security and safety reasons, conduwuit will shut down. If you are extra sure this is the desired behaviour you \ want, please set the following config option to true: `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse`" @@ -126,6 +143,7 @@ For security and safety reasons, conduwuit will shut down. 
If you are extra sure if config.allow_registration && config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse && config.registration_token.is_none() + && config.registration_token_file.is_none() { warn!( "Open registration is enabled via setting \ diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index d8e1c7d93..126b3123e 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -139,6 +139,7 @@ pub struct Config { #[serde(default)] pub yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse: bool, pub registration_token: Option, + pub registration_token_file: Option, #[serde(default = "true_fn")] pub allow_encryption: bool, #[serde(default = "true_fn")] @@ -572,12 +573,20 @@ impl fmt::Display for Config { line("Allow registration", &self.allow_registration.to_string()); line( "Registration token", - if self.registration_token.is_some() { - "set" + if self.registration_token.is_none() && self.registration_token_file.is_none() && self.allow_registration { + "not set (⚠️ open registration!)" + } else if self.registration_token.is_none() && self.registration_token_file.is_none() { + "not set" } else { - "not set (open registration!)" + "set" }, ); + line( + "Registration token file path", + self.registration_token_file + .as_ref() + .map_or("", |path| path.to_str().unwrap_or_default()), + ); line( "Allow guest registration (inherently false if allow registration is false)", &self.allow_guest_registration.to_string(), diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index f24e8a274..fb970f078 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -41,6 +41,7 @@ pub struct Service { pub server_user: OwnedUserId, pub admin_alias: OwnedRoomAliasId, pub turn_secret: String, + pub registration_token: Option, } type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries @@ -96,6 +97,20 @@ impl crate::Service for Service { }) }); + let 
registration_token = + config + .registration_token_file + .as_ref() + .map_or(config.registration_token.clone(), |path| { + let Ok(token) = std::fs::read_to_string(path).inspect_err(|e| { + error!("Failed to read the registration token file: {e}"); + }) else { + return config.registration_token.clone(); + }; + + Some(token) + }); + let mut s = Self { db, config: config.clone(), @@ -112,6 +127,7 @@ impl crate::Service for Service { server_user: UserId::parse_with_server_name(String::from("conduit"), &config.server_name) .expect("@conduit:server_name is valid"), turn_secret, + registration_token, }; if !s diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 0415bfc23..f75f1bcd8 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -6,7 +6,7 @@ use std::{ use conduit::{ err, error, implement, utils, utils::{hash, string::EMPTY}, - Error, Result, Server, + Error, Result, }; use database::{Deserialized, Map}; use ruma::{ @@ -26,7 +26,6 @@ pub struct Service { } struct Services { - server: Arc, globals: Dep, users: Dep, } @@ -48,7 +47,6 @@ impl crate::Service for Service { userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(), }, services: Services { - server: args.server.clone(), globals: args.depend::("globals"), users: args.depend::("users"), }, @@ -135,7 +133,13 @@ pub async fn try_auth( uiaainfo.completed.push(AuthType::Password); }, AuthData::RegistrationToken(t) => { - if Some(t.token.trim()) == self.services.server.config.registration_token.as_deref() { + if self + .services + .globals + .registration_token + .as_ref() + .is_some_and(|reg_token| t.token.trim() == reg_token) + { uiaainfo.completed.push(AuthType::RegistrationToken); } else { uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { From 6a81bf23dec75e97be30c44f248ef9bd1493835e Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 29 Sep 2024 20:13:37 -0400 Subject: [PATCH 033/245] dont send events from ignored users over 
/messages Signed-off-by: strawberry --- src/api/client/message.rs | 80 +++++++++++++++++++++++++++++++-------- 1 file changed, 65 insertions(+), 15 deletions(-) diff --git a/src/api/client/message.rs b/src/api/client/message.rs index bab5fa54f..d577e3c83 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -1,7 +1,11 @@ use std::collections::{BTreeMap, HashSet}; use axum::extract::State; -use conduit::{err, utils::ReadyExt, Err, PduCount}; +use conduit::{ + err, + utils::{IterStream, ReadyExt}, + Err, PduCount, +}; use futures::{FutureExt, StreamExt}; use ruma::{ api::client::{ @@ -9,7 +13,7 @@ use ruma::{ filter::{RoomEventFilter, UrlFilter}, message::{get_message_events, send_message_event}, }, - events::{MessageLikeEventType, StateEventType}, + events::{MessageLikeEventType, StateEventType, TimelineEventType::*}, UserId, }; use serde_json::{from_str, Value}; @@ -182,8 +186,30 @@ pub(crate) async fn get_message_events_route( let events_after: Vec<_> = events_after .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) - .collect(); + .stream() + .filter_map(|(_, pdu)| async move { + // list of safe and common non-state events to ignore + if matches!( + &pdu.kind, + RoomMessage + | Sticker | CallInvite + | CallNotify | RoomEncrypted + | Image | File | Audio + | Voice | Video | UnstablePollStart + | PollStart | KeyVerificationStart + | Reaction | Emote | Location + ) && services + .users + .user_is_ignored(&pdu.sender, sender_user) + .await + { + return None; + } + + Some(pdu.to_room_event()) + }) + .collect() + .await; resp.start = from.stringify(); resp.end = next_token.map(|count| count.stringify()); @@ -203,6 +229,27 @@ pub(crate) async fn get_message_events_route( .pdus_until(sender_user, room_id, from) .await? 
.ready_filter_map(|item| contains_url_filter(item, filter)) + .filter_map(|(count, pdu)| async move { + // list of safe and common non-state events to ignore + if matches!( + &pdu.kind, + RoomMessage + | Sticker | CallInvite + | CallNotify | RoomEncrypted + | Image | File | Audio + | Voice | Video | UnstablePollStart + | PollStart | KeyVerificationStart + | Reaction | Emote | Location + ) && services + .users + .user_is_ignored(&pdu.sender, sender_user) + .await + { + return None; + } + + Some((count, pdu)) + }) .filter_map(|item| visibility_filter(&services, item, sender_user)) .ready_take_while(|(count, _)| Some(*count) != to) // Stop at `to` .take(limit) @@ -243,17 +290,20 @@ pub(crate) async fn get_message_events_route( }, } - resp.state = Vec::new(); - for ll_id in &lazy_loaded { - if let Ok(member_event) = services - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomMember, ll_id.as_str()) - .await - { - resp.state.push(member_event.to_state_event()); - } - } + resp.state = lazy_loaded + .iter() + .stream() + .filter_map(|ll_user_id| async move { + services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, ll_user_id.as_str()) + .await + .map(|member_event| member_event.to_state_event()) + .ok() + }) + .collect() + .await; // remove the feature check when we are sure clients like element can handle it if !cfg!(feature = "element_hacks") { From a9e3e8f77ad38549f7bb21c9447e5e7549ac31fe Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 29 Sep 2024 20:40:54 -0400 Subject: [PATCH 034/245] dont send non-state events from ignored users over /context/{eventId} Signed-off-by: strawberry --- src/api/client/context.rs | 64 ++++++++++++++++++++++++++++++--------- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/src/api/client/context.rs b/src/api/client/context.rs index cc49b763f..9a5c4e826 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -5,12 +5,12 @@ use conduit::{err, 
error, Err}; use futures::StreamExt; use ruma::{ api::client::{context::get_context, filter::LazyLoadOptions}, - events::StateEventType, + events::{StateEventType, TimelineEventType::*}, }; use crate::{Result, Ruma}; -/// # `GET /_matrix/client/r0/rooms/{roomId}/context` +/// # `GET /_matrix/client/r0/rooms/{roomId}/context/{eventId}` /// /// Allows loading room history around an event. /// @@ -31,7 +31,7 @@ pub(crate) async fn get_context_route( LazyLoadOptions::Disabled => (false, cfg!(feature = "element_hacks")), }; - let mut lazy_loaded = HashSet::new(); + let mut lazy_loaded = HashSet::with_capacity(100); let base_token = services .rooms @@ -79,6 +79,25 @@ pub(crate) async fn get_context_route( .await? .take(limit / 2) .filter_map(|(count, pdu)| async move { + // list of safe and common non-state events to ignore + if matches!( + &pdu.kind, + RoomMessage + | Sticker | CallInvite + | CallNotify | RoomEncrypted + | Image | File | Audio + | Voice | Video | UnstablePollStart + | PollStart | KeyVerificationStart + | Reaction | Emote + | Location + ) && services + .users + .user_is_ignored(&pdu.sender, sender_user) + .await + { + return None; + } + services .rooms .state_accessor @@ -104,11 +123,6 @@ pub(crate) async fn get_context_route( .last() .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()); - let events_before: Vec<_> = events_before - .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) - .collect(); - let events_after: Vec<_> = services .rooms .timeline @@ -116,6 +130,25 @@ pub(crate) async fn get_context_route( .await? 
.take(limit / 2) .filter_map(|(count, pdu)| async move { + // list of safe and common non-state events to ignore + if matches!( + &pdu.kind, + RoomMessage + | Sticker | CallInvite + | CallNotify | RoomEncrypted + | Image | File | Audio + | Voice | Video | UnstablePollStart + | PollStart | KeyVerificationStart + | Reaction | Emote + | Location + ) && services + .users + .user_is_ignored(&pdu.sender, sender_user) + .await + { + return None; + } + services .rooms .state_accessor @@ -167,11 +200,6 @@ pub(crate) async fn get_context_route( .last() .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()); - let events_after: Vec<_> = events_after - .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) - .collect(); - let mut state = Vec::with_capacity(state_ids.len()); for (shortstatekey, id) in state_ids { @@ -201,9 +229,15 @@ pub(crate) async fn get_context_route( Ok(get_context::v3::Response { start: Some(start_token), end: Some(end_token), - events_before, + events_before: events_before + .iter() + .map(|(_, pdu)| pdu.to_room_event()) + .collect(), event: Some(base_event), - events_after, + events_after: events_after + .iter() + .map(|(_, pdu)| pdu.to_room_event()) + .collect(), state, }) } From 115ea03edfc1cf785ad280abcc850bf14a2b76cc Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 29 Sep 2024 20:57:33 -0400 Subject: [PATCH 035/245] remove unnecessary full type annos Signed-off-by: strawberry --- src/api/client/sync.rs | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index 51df88a30..c4ff1eeb5 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -34,8 +34,8 @@ use ruma::{ events::{ presence::PresenceEvent, room::member::{MembershipState, RoomMemberEventContent}, - AnyRawAccountDataEvent, StateEventType, TimelineEventType, - TimelineEventType::*, + AnyRawAccountDataEvent, StateEventType, + TimelineEventType::{self, *}, }, serde::Raw, 
state_res::Event, @@ -50,14 +50,8 @@ use crate::{ }; const SINGLE_CONNECTION_SYNC: &str = "single_connection_sync"; -const DEFAULT_BUMP_TYPES: &[TimelineEventType] = &[ - TimelineEventType::RoomMessage, - TimelineEventType::RoomEncrypted, - TimelineEventType::Sticker, - TimelineEventType::CallInvite, - TimelineEventType::PollStart, - TimelineEventType::Beacon, -]; +const DEFAULT_BUMP_TYPES: &[TimelineEventType; 6] = + &[RoomMessage, RoomEncrypted, Sticker, CallInvite, PollStart, Beacon]; macro_rules! extract_variant { ($e:expr, $variant:path) => { @@ -376,7 +370,7 @@ async fn handle_left_room( origin_server_ts: utils::millis_since_unix_epoch() .try_into() .expect("Timestamp is valid js_int value"), - kind: TimelineEventType::RoomMember, + kind: RoomMember, content: serde_json::from_str(r#"{"membership":"leave"}"#).expect("this is valid JSON"), state_key: Some(sender_user.to_string()), unsigned: None, @@ -639,7 +633,7 @@ async fn load_joined_room( .timeline .all_pdus(sender_user, room_id) .await? 
- .ready_filter(|(_, pdu)| pdu.kind == TimelineEventType::RoomMember) + .ready_filter(|(_, pdu)| pdu.kind == RoomMember) .filter_map(|(_, pdu)| async move { let Ok(content) = serde_json::from_str::(pdu.content.get()) else { return None; @@ -827,11 +821,11 @@ async fn load_joined_room( let send_member_count = delta_state_events .iter() - .any(|event| event.kind == TimelineEventType::RoomMember); + .any(|event| event.kind == RoomMember); if encrypted_room { for state_event in &delta_state_events { - if state_event.kind != TimelineEventType::RoomMember { + if state_event.kind != RoomMember { continue; } @@ -895,7 +889,7 @@ async fn load_joined_room( // Mark all member events we're returning as lazy-loaded for pdu in &state_events { - if pdu.kind == TimelineEventType::RoomMember { + if pdu.kind == RoomMember { match UserId::parse( pdu.state_key .as_ref() @@ -1357,7 +1351,7 @@ pub(crate) async fn sync_events_v4_route( error!("Pdu in state not found: {id}"); continue; }; - if pdu.kind == TimelineEventType::RoomMember { + if pdu.kind == RoomMember { if let Some(state_key) = &pdu.state_key { let user_id = UserId::parse(state_key.clone()) .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; From 4eb7ad79d18c1d89784d2756a8652a1c0ce2d347 Mon Sep 17 00:00:00 2001 From: strawberry Date: Tue, 1 Oct 2024 01:59:24 -0400 Subject: [PATCH 036/245] update last_seen_ip and last_seen_ts on updating device metadata Signed-off-by: strawberry --- src/api/client/device.rs | 20 +++++++++++++++----- src/service/users/mod.rs | 11 ----------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/api/client/device.rs b/src/api/client/device.rs index 93eaa393d..7e56f27e1 100644 --- a/src/api/client/device.rs +++ b/src/api/client/device.rs @@ -1,10 +1,14 @@ use axum::extract::State; +use axum_client_ip::InsecureClientIp; use conduit::{err, Err}; use futures::StreamExt; -use ruma::api::client::{ - device::{self, delete_device, delete_devices, get_device, get_devices, 
update_device}, - error::ErrorKind, - uiaa::{AuthFlow, AuthType, UiaaInfo}, +use ruma::{ + api::client::{ + device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, + error::ErrorKind, + uiaa::{AuthFlow, AuthType, UiaaInfo}, + }, + MilliSecondsSinceUnixEpoch, }; use super::SESSION_ID_LENGTH; @@ -51,8 +55,10 @@ pub(crate) async fn get_device_route( /// # `PUT /_matrix/client/r0/devices/{deviceId}` /// /// Updates the metadata on a given device of the sender user. +#[tracing::instrument(skip_all, fields(%client), name = "update_device")] pub(crate) async fn update_device_route( - State(services): State, body: Ruma, + State(services): State, InsecureClientIp(client): InsecureClientIp, + body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -63,6 +69,10 @@ pub(crate) async fn update_device_route( .map_err(|_| err!(Request(NotFound("Device not found."))))?; device.display_name.clone_from(&body.display_name); + device.last_seen_ip.clone_from(&Some(client.to_string())); + device + .last_seen_ts + .clone_from(&Some(MilliSecondsSinceUnixEpoch::now())); services .users diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 1c079085e..44d169dd4 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -791,17 +791,6 @@ impl Service { } pub async fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { - // Only existing devices should be able to call this, but we shouldn't assert - // either... 
- let key = (user_id, device_id); - if self.db.userdeviceid_metadata.qry(&key).await.is_err() { - return Err!(Database(error!( - ?user_id, - ?device_id, - "Called update_device_metadata for a non-existent user and/or device" - ))); - } - increment(&self.db.userid_devicelistversion, user_id.as_bytes()); let mut userdeviceid = user_id.as_bytes().to_vec(); From 98363852b18c4f2cc1525c671b770a6fbf4a7f3a Mon Sep 17 00:00:00 2001 From: strawberry Date: Wed, 2 Oct 2024 00:56:09 -0400 Subject: [PATCH 037/245] fix: dont add remote users for push targets, use hashset instead of vec Signed-off-by: strawberry --- src/service/rooms/timeline/mod.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 6a26a1d53..f8f770bc4 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -43,7 +43,7 @@ use self::data::Data; pub use self::data::PdusIterItem; use crate::{ account_data, admin, appservice, appservice::NamespaceRegex, globals, pusher, rooms, - rooms::state_compressor::CompressedStateEvent, sending, server_keys, Dep, + rooms::state_compressor::CompressedStateEvent, sending, server_keys, users, Dep, }; // Update Relationships @@ -90,6 +90,7 @@ struct Services { sending: Dep, server_keys: Dep, user: Dep, + users: Dep, pusher: Dep, threads: Dep, search: Dep, @@ -119,6 +120,7 @@ impl crate::Service for Service { sending: args.depend::("sending"), server_keys: args.depend::("server_keys"), user: args.depend::("rooms::user"), + users: args.depend::("users"), pusher: args.depend::("pusher"), threads: args.depend::("rooms::threads"), search: args.depend::("rooms::search"), @@ -378,20 +380,20 @@ impl Service { let mut notifies = Vec::new(); let mut highlights = Vec::new(); - let mut push_target = self + let mut push_target: HashSet<_> = self .services .state_cache .active_local_users_in_room(&pdu.room_id) .map(ToOwned::to_owned) - .collect::>() + 
.collect() .await; if pdu.kind == TimelineEventType::RoomMember { if let Some(state_key) = &pdu.state_key { - let target_user_id = UserId::parse(state_key.clone()).expect("This state_key was previously validated"); + let target_user_id = UserId::parse(state_key.clone())?; - if !push_target.contains(&target_user_id) { - push_target.push(target_user_id); + if self.services.users.is_active_local(&target_user_id).await { + push_target.insert(target_user_id); } } } From 54a107c3c473049f5049e96f807d4c505f5a13db Mon Sep 17 00:00:00 2001 From: strawberry Date: Wed, 2 Oct 2024 01:47:19 -0400 Subject: [PATCH 038/245] drop unnecessary error to debug_warn Signed-off-by: strawberry --- src/api/client/state.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/api/client/state.rs b/src/api/client/state.rs index f9a4a7636..d89c23e8c 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -130,7 +130,7 @@ pub(crate) async fn get_state_events_for_key_route( .room_state_get(&body.room_id, &body.event_type, &body.state_key) .await .map_err(|_| { - err!(Request(NotFound(error!( + err!(Request(NotFound(debug_warn!( room_id = ?body.room_id, event_type = ?body.event_type, "State event not found in room.", From ab9a65db5d8501c60c523bc69d704032459db482 Mon Sep 17 00:00:00 2001 From: strawberry Date: Wed, 2 Oct 2024 01:47:53 -0400 Subject: [PATCH 039/245] add MSC4151 room reporting support Signed-off-by: strawberry --- Cargo.lock | 28 +++++----- Cargo.toml | 2 +- src/api/client/report.rs | 107 +++++++++++++++++++++++++++------------ src/api/router.rs | 1 + 4 files changed, 90 insertions(+), 48 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4d40c4589..e72c7e805 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2975,7 +2975,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" +source = 
"git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" dependencies = [ "assign", "js_int", @@ -2997,7 +2997,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" dependencies = [ "js_int", "ruma-common", @@ -3009,7 +3009,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" dependencies = [ "as_variant", "assign", @@ -3032,7 +3032,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" dependencies = [ "as_variant", "base64 0.22.1", @@ -3062,7 +3062,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" dependencies = [ "as_variant", "indexmap 2.6.0", @@ -3086,7 +3086,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = 
"git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" dependencies = [ "bytes", "http", @@ -3104,7 +3104,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" dependencies = [ "js_int", "thiserror", @@ -3113,7 +3113,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" dependencies = [ "js_int", "ruma-common", @@ -3123,7 +3123,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" dependencies = [ "cfg-if", "once_cell", @@ -3139,7 +3139,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" dependencies = [ "js_int", "ruma-common", @@ -3151,7 
+3151,7 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.3.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" dependencies = [ "headers", "http", @@ -3164,7 +3164,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -3180,10 +3180,10 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.11.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" dependencies = [ "futures-util", - "itertools 0.12.1", + "itertools 0.13.0", "js_int", "ruma-common", "ruma-events", diff --git a/Cargo.toml b/Cargo.toml index 28e280cfd..18f33375f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -315,7 +315,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = "https://github.com/girlbossceo/ruwuma" #branch = "conduwuit-changes" -rev = "ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" +rev = "e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" features = [ "compat", "rand", diff --git a/src/api/client/report.rs b/src/api/client/report.rs index a40c35a28..cf789246a 100644 --- a/src/api/client/report.rs +++ b/src/api/client/report.rs @@ -1,10 +1,14 @@ use std::time::Duration; use axum::extract::State; +use axum_client_ip::InsecureClientIp; use conduit::{utils::ReadyExt, 
Err}; use rand::Rng; use ruma::{ - api::client::{error::ErrorKind, room::report_content}, + api::client::{ + error::ErrorKind, + room::{report_content, report_room}, + }, events::room::message, int, EventId, RoomId, UserId, }; @@ -14,22 +18,75 @@ use tracing::info; use crate::{ debug_info, service::{pdu::PduEvent, Services}, - utils::HtmlEscape, Error, Result, Ruma, }; +/// # `POST /_matrix/client/v3/rooms/{roomId}/report` +/// +/// Reports an abusive room to homeserver admins +#[tracing::instrument(skip_all, fields(%client), name = "report_room")] +pub(crate) async fn report_room_route( + State(services): State, InsecureClientIp(client): InsecureClientIp, + body: Ruma, +) -> Result { + // user authentication + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + info!( + "Received room report by user {sender_user} for room {} with reason: {:?}", + body.room_id, body.reason + ); + + delay_response().await; + + if !services + .rooms + .state_cache + .server_in_room(&services.globals.config.server_name, &body.room_id) + .await + { + return Err!(Request(NotFound( + "Room does not exist to us, no local users have joined at all" + ))); + } + + if body.reason.as_ref().is_some_and(|s| s.len() > 750) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Reason too long, should be 750 characters or fewer", + )); + }; + + // send admin room message that we received the report with an @room ping for + // urgency + services + .admin + .send_message(message::RoomMessageEventContent::text_markdown(format!( + "@room Room report received from {} -\n\nRoom ID: {}\n\nReport Reason: {}", + sender_user.to_owned(), + body.room_id, + body.reason.as_deref().unwrap_or("") + ))) + .await + .ok(); + + Ok(report_room::v3::Response {}) +} + /// # `POST /_matrix/client/v3/rooms/{roomId}/report/{eventId}` /// /// Reports an inappropriate event to homeserver admins +#[tracing::instrument(skip_all, fields(%client), name = "report_event")] pub(crate) async fn 
report_event_route( - State(services): State, body: Ruma, + State(services): State, InsecureClientIp(client): InsecureClientIp, + body: Ruma, ) -> Result { // user authentication let sender_user = body.sender_user.as_ref().expect("user is authenticated"); info!( - "Received /report request by user {sender_user} for room {} and event ID {}", - body.room_id, body.event_id + "Received event report by user {sender_user} for room {} and event ID {}, with reason: {:?}", + body.room_id, body.event_id, body.reason ); delay_response().await; @@ -39,7 +96,7 @@ pub(crate) async fn report_event_route( return Err!(Request(NotFound("Event ID is not known to us or Event ID is invalid"))); }; - is_report_valid( + is_event_report_valid( &services, &pdu.event_id, &body.room_id, @@ -54,32 +111,16 @@ pub(crate) async fn report_event_route( // urgency services .admin - .send_message(message::RoomMessageEventContent::text_html( - format!( - "@room Report received from: {}\n\nEvent ID: {}\nRoom ID: {}\nSent By: {}\n\nReport Score: {}\nReport \ - Reason: {}", - sender_user.to_owned(), - pdu.event_id, - pdu.room_id, - pdu.sender.clone(), - body.score.unwrap_or_else(|| ruma::Int::from(0)), - body.reason.as_deref().unwrap_or("") - ), - format!( - "

    @room Report received from: {0}\ -
    • Event Info
      • Event ID: {1}\ - 🔗
      • Room ID: {2}\ -
      • Sent By: {3}
    • \ - Report Info
      • Report Score: {4}
      • Report Reason: {5}
    • \ -
    ", - sender_user.to_owned(), - pdu.event_id.clone(), - pdu.room_id.clone(), - pdu.sender.clone(), - body.score.unwrap_or_else(|| ruma::Int::from(0)), - HtmlEscape(body.reason.as_deref().unwrap_or("")) - ), - )) + .send_message(message::RoomMessageEventContent::text_markdown(format!( + "@room Event report received from {} -\n\nEvent ID: {}\nRoom ID: {}\nSent By: {}\n\nReport Score: \ + {}\nReport Reason: {}", + sender_user.to_owned(), + pdu.event_id, + pdu.room_id, + pdu.sender, + body.score.unwrap_or_else(|| ruma::Int::from(0)), + body.reason.as_deref().unwrap_or("") + ))) .await .ok(); @@ -92,7 +133,7 @@ pub(crate) async fn report_event_route( /// check if score is in valid range /// check if report reasoning is less than or equal to 750 characters /// check if reporting user is in the reporting room -async fn is_report_valid( +async fn is_event_report_valid( services: &Services, event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: &Option, score: Option, pdu: &std::sync::Arc, ) -> Result<()> { diff --git a/src/api/router.rs b/src/api/router.rs index c4275f054..ddd91d11f 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -91,6 +91,7 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(&client::create_room_route) .ruma_route(&client::redact_event_route) .ruma_route(&client::report_event_route) + .ruma_route(&client::report_room_route) .ruma_route(&client::create_alias_route) .ruma_route(&client::delete_alias_route) .ruma_route(&client::get_alias_route) From bd56d8304561bae45f8578a61e89e020a8387888 Mon Sep 17 00:00:00 2001 From: strawberry Date: Wed, 2 Oct 2024 09:26:28 -0400 Subject: [PATCH 040/245] fix room directory regression Signed-off-by: strawberry --- src/service/rooms/directory/mod.rs | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index 5666a91a7..2112ecefb 100644 --- a/src/service/rooms/directory/mod.rs +++ 
b/src/service/rooms/directory/mod.rs @@ -1,8 +1,8 @@ use std::sync::Arc; use conduit::{implement, utils::stream::TryIgnore, Result}; -use database::{Ignore, Map}; -use futures::{Stream, StreamExt}; +use database::Map; +use futures::Stream; use ruma::RoomId; pub struct Service { @@ -35,10 +35,4 @@ pub fn set_not_public(&self, room_id: &RoomId) { self.db.publicroomids.remove(ro pub async fn is_public_room(&self, room_id: &RoomId) -> bool { self.db.publicroomids.get(room_id).await.is_ok() } #[implement(Service)] -pub fn public_rooms(&self) -> impl Stream + Send { - self.db - .publicroomids - .keys() - .ignore_err() - .map(|(room_id, _): (&RoomId, Ignore)| room_id) -} +pub fn public_rooms(&self) -> impl Stream + Send { self.db.publicroomids.keys().ignore_err() } From fa7c1200b55a1d90df57dc31c3d54c92fb89fff0 Mon Sep 17 00:00:00 2001 From: strawberry Date: Wed, 2 Oct 2024 21:38:52 -0400 Subject: [PATCH 041/245] miniscule spaces code optimisations still terrible though Signed-off-by: strawberry --- src/service/rooms/spaces/mod.rs | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 17fbf0ef0..920424a42 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -62,11 +62,11 @@ impl FromStr for PaginationToken { let mut values = value.split('_'); let mut pag_tok = || { - let mut rooms = vec![]; - - for room in values.next()?.split(',') { - rooms.push(u64::from_str(room).ok()?); - } + let rooms = values + .next()? 
+ .split(',') + .filter_map(|room_s| u64::from_str(room_s).ok()) + .collect(); Some(Self { short_room_ids: rooms, @@ -469,7 +469,7 @@ impl Service { }, )]]; - let mut results = Vec::new(); + let mut results = Vec::with_capacity(limit); while let Some((current_room, via)) = { next_room_to_traverse(&mut stack, &mut parents) } { if results.len() >= limit { @@ -548,11 +548,12 @@ impl Service { parents.pop_front(); parents.push_back(room); - let mut short_room_ids = vec![]; - - for room in parents { - short_room_ids.push(self.services.short.get_or_create_shortroomid(&room).await); - } + let short_room_ids: Vec<_> = parents + .iter() + .stream() + .filter_map(|room_id| async move { self.services.short.get_shortroomid(room_id).await.ok() }) + .collect() + .await; Some( PaginationToken { @@ -585,7 +586,7 @@ impl Service { .await .map_err(|e| err!(Database("State in space not found: {e}")))?; - let mut children_pdus = Vec::new(); + let mut children_pdus = Vec::with_capacity(state.len()); for (key, id) in state { let (event_type, state_key) = self.services.short.get_statekey_from_short(key).await?; From c6b7c24e99891a8374a7444048c493240f7dbca5 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 3 Oct 2024 21:42:25 +0000 Subject: [PATCH 042/245] consume all bytes for top-level Ignore; add comments/tweaks Signed-off-by: Jason Volk --- src/database/de.rs | 37 ++++++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/src/database/de.rs b/src/database/de.rs index fc36560d6..9ee52267e 100644 --- a/src/database/de.rs +++ b/src/database/de.rs @@ -12,6 +12,7 @@ where let mut deserializer = Deserializer { buf, pos: 0, + seq: false, }; T::deserialize(&mut deserializer).debug_inspect(|_| { @@ -24,6 +25,7 @@ where pub(crate) struct Deserializer<'de> { buf: &'de [u8], pos: usize, + seq: bool, } /// Directive to ignore a record. 
This type can be used to skip deserialization @@ -32,8 +34,11 @@ pub(crate) struct Deserializer<'de> { pub struct Ignore; impl<'de> Deserializer<'de> { + /// Record separator; an intentionally invalid-utf8 byte. const SEP: u8 = b'\xFF'; + /// Determine if the input was fully consumed and error if bytes remaining. + /// This is intended for debug assertions; not optimized for parsing logic. fn finished(&self) -> Result<()> { let pos = self.pos; let len = self.buf.len(); @@ -48,6 +53,20 @@ impl<'de> Deserializer<'de> { ))) } + /// Consume the current record to ignore it. Inside a sequence the next + /// record is skipped but at the top-level all records are skipped such that + /// deserialization completes with self.finished() == Ok. + #[inline] + fn record_ignore(&mut self) { + if self.seq { + self.record_next(); + } else { + self.record_trail(); + } + } + + /// Consume the current record. The position pointer is moved to the start + /// of the next record. Slice of the current record is returned. #[inline] fn record_next(&mut self) -> &'de [u8] { self.buf[self.pos..] @@ -57,8 +76,10 @@ impl<'de> Deserializer<'de> { .expect("remainder of buf even if SEP was not found") } + /// Peek at the first byte of the current record. If all records were + /// consumed None is returned instead. #[inline] - fn record_next_peek_byte(&self) -> Option { + fn record_peek_byte(&self) -> Option { let started = self.pos != 0; let buf = &self.buf[self.pos..]; debug_assert!( @@ -69,6 +90,8 @@ impl<'de> Deserializer<'de> { buf.get::(started.into()).copied() } + /// Consume the record separator such that the position cleanly points to + /// the start of the next record. (Case for some sequences) #[inline] fn record_start(&mut self) { let started = self.pos != 0; @@ -78,8 +101,11 @@ impl<'de> Deserializer<'de> { ); self.inc_pos(started.into()); + self.seq = true; } + /// Consume all remaining bytes, which may include record separators, + /// returning a raw slice. 
#[inline] fn record_trail(&mut self) -> &'de [u8] { let record = &self.buf[self.pos..]; @@ -87,6 +113,7 @@ impl<'de> Deserializer<'de> { record } + /// Increment the position pointer. #[inline] fn inc_pos(&mut self, n: usize) { self.pos = self.pos.saturating_add(n); @@ -142,7 +169,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { V: Visitor<'de>, { match name { - "Ignore" => self.record_next(), + "Ignore" => self.record_ignore(), _ => unimplemented!("Unrecognized deserialization Directive {name:?}"), }; @@ -190,7 +217,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { fn deserialize_i64>(self, visitor: V) -> Result { let bytes: [u8; size_of::()] = self.buf[self.pos..].try_into()?; - self.pos = self.pos.saturating_add(size_of::()); + self.inc_pos(size_of::()); visitor.visit_i64(i64::from_be_bytes(bytes)) } @@ -208,7 +235,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { fn deserialize_u64>(self, visitor: V) -> Result { let bytes: [u8; size_of::()] = self.buf[self.pos..].try_into()?; - self.pos = self.pos.saturating_add(size_of::()); + self.inc_pos(size_of::()); visitor.visit_u64(u64::from_be_bytes(bytes)) } @@ -267,7 +294,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { "deserialize_any: type not expected" ); - match self.record_next_peek_byte() { + match self.record_peek_byte() { Some(b'{') => self.deserialize_map(visitor), _ => self.deserialize_str(visitor), } From 2d049dacc37a7f3c7265006a8ebc39516ce7ee55 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 4 Oct 2024 02:23:28 +0000 Subject: [PATCH 043/245] fix get_all_media_keys deserialization Signed-off-by: Jason Volk --- src/service/media/data.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/service/media/data.rs b/src/service/media/data.rs index 248e9e1d2..b22718836 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -122,8 +122,9 @@ impl Data { let keys: 
Vec> = self .mediaid_file - .keys_prefix_raw(&prefix) + .raw_keys_prefix(&prefix) .ignore_err() + .map(<[u8]>::to_vec) .collect() .await; From bd9a9cc5f84066d7131ded8263a46b1ab57667b8 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 3 Oct 2024 22:03:39 +0000 Subject: [PATCH 044/245] fix trait-solver issue requiring recursion_limit increase Signed-off-by: Jason Volk --- src/api/mod.rs | 2 -- src/main/main.rs | 2 -- src/router/mod.rs | 2 -- src/service/mod.rs | 1 - src/service/service.rs | 47 +++++++++++++++++++++++++++++++---------- src/service/services.rs | 2 ++ 6 files changed, 38 insertions(+), 18 deletions(-) diff --git a/src/api/mod.rs b/src/api/mod.rs index 82b857db3..96837470b 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,5 +1,3 @@ -#![recursion_limit = "192"] - pub mod client; pub mod router; pub mod server; diff --git a/src/main/main.rs b/src/main/main.rs index 8703eef2b..8e644a158 100644 --- a/src/main/main.rs +++ b/src/main/main.rs @@ -1,5 +1,3 @@ -#![recursion_limit = "192"] - pub(crate) mod clap; mod mods; mod restart; diff --git a/src/router/mod.rs b/src/router/mod.rs index 67ebc0e3f..e123442ca 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -1,5 +1,3 @@ -#![recursion_limit = "160"] - mod layers; mod request; mod router; diff --git a/src/service/mod.rs b/src/service/mod.rs index cb8bfcd95..604e34045 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,4 +1,3 @@ -#![recursion_limit = "192"] #![allow(refining_impl_trait)] mod manager; diff --git a/src/service/service.rs b/src/service/service.rs index 031650506..7ec2ea0fe 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -51,7 +51,7 @@ pub(crate) struct Args<'a> { /// Dep is a reference to a service used within another service. /// Circular-dependencies between services require this indirection. 
-pub(crate) struct Dep { +pub(crate) struct Dep { dep: OnceLock>, service: Weak, name: &'static str, @@ -62,24 +62,47 @@ pub(crate) type MapType = BTreeMap; pub(crate) type MapVal = (Weak, Weak); pub(crate) type MapKey = String; -impl Deref for Dep { +/// SAFETY: Workaround for a compiler limitation (or bug) where it is Hard to +/// prove the Sync'ness of Dep because services contain circular references +/// to other services through Dep's. The Sync'ness of Dep can still be +/// proved without unsafety by declaring the crate-attribute #![recursion_limit +/// = "192"] but this may take a while. Re-evaluate this when a new trait-solver +/// (such as Chalk) becomes available. +unsafe impl Sync for Dep {} + +/// SAFETY: Ancillary to unsafe impl Sync; while this is not needed to prevent +/// violating the recursion_limit, the trait-solver still spends an inordinate +/// amount of time to prove this. +unsafe impl Send for Dep {} + +impl Deref for Dep { type Target = Arc; /// Dereference a dependency. The dependency must be ready or panics. + #[inline] fn deref(&self) -> &Self::Target { - self.dep.get_or_init(|| { - let service = self - .service - .upgrade() - .expect("services map exists for dependency initialization."); - - require::(&service, self.name) - }) + self.dep.get_or_init( + #[inline(never)] + || self.init(), + ) + } +} + +impl Dep { + #[inline] + fn init(&self) -> Arc { + let service = self + .service + .upgrade() + .expect("services map exists for dependency initialization."); + + require::(&service, self.name) } } impl<'a> Args<'a> { /// Create a lazy-reference to a service when constructing another Service. + #[inline] pub(crate) fn depend(&'a self, name: &'static str) -> Dep { Dep:: { dep: OnceLock::new(), @@ -90,12 +113,14 @@ impl<'a> Args<'a> { /// Create a reference immediately to a service when constructing another /// Service. The other service must be constructed. 
+ #[inline] pub(crate) fn require(&'a self, name: &str) -> Arc { require::(self.service, name) } } /// Reference a Service by name. Panics if the Service does not exist or was /// incorrectly cast. -pub(crate) fn require(map: &Map, name: &str) -> Arc { +#[inline] +fn require(map: &Map, name: &str) -> Arc { try_get::(map, name) .inspect_err(inspect_log) .expect("Failure to reference service required by another service.") diff --git a/src/service/services.rs b/src/service/services.rs index da22fb2d4..0b63a5cae 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -195,6 +195,7 @@ impl Services { } } + #[inline] pub fn try_get(&self, name: &str) -> Result> where T: Any + Send + Sync + Sized, @@ -202,6 +203,7 @@ impl Services { service::try_get::(&self.service, name) } + #[inline] pub fn get(&self, name: &str) -> Option> where T: Any + Send + Sync + Sized, From ba683cf5340ff4321b8e8789b101d923b07bd9d4 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 4 Oct 2024 17:17:10 +0000 Subject: [PATCH 045/245] fix aliasid_alias key deserialization Signed-off-by: Jason Volk --- src/service/rooms/alias/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 1d44cd2d8..f50cc46c0 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -101,9 +101,9 @@ impl Service { let prefix = (&room_id, Interfix); self.db .aliasid_alias - .keys_prefix(&prefix) + .keys_raw_prefix(&prefix) .ignore_err() - .ready_for_each(|key: &[u8]| self.db.aliasid_alias.remove(&key)) + .ready_for_each(|key| self.db.aliasid_alias.remove(key)) .await; self.db.alias_roomid.remove(alias.as_bytes()); @@ -161,7 +161,7 @@ impl Service { .aliasid_alias .stream_prefix(&prefix) .ignore_err() - .map(|((Ignore, Ignore), alias): ((Ignore, Ignore), &RoomAliasId)| alias) + .map(|(_, alias): (Ignore, &RoomAliasId)| alias) } #[tracing::instrument(skip(self), level = "debug")] From 
9eace1fbbb8eaaf819d12649e0a777dd5c7f4cf6 Mon Sep 17 00:00:00 2001 From: morguldir Date: Sat, 5 Oct 2024 12:30:05 -0400 Subject: [PATCH 046/245] fix sliding sync room type filter regression Signed-off-by: strawberry --- src/api/client/sync.rs | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index c4ff1eeb5..5940d7cf2 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -6,7 +6,9 @@ use std::{ use axum::extract::State; use conduit::{ - debug, err, error, is_equal_to, + debug, err, error, + error::is_not_found, + is_equal_to, result::IntoIsOk, utils::{ math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, @@ -1887,18 +1889,21 @@ async fn filter_rooms( .iter() .stream() .filter_map(|r| async move { - match services.rooms.state_accessor.get_room_type(r).await { - Err(_) => false, - Ok(result) => { - let result = RoomTypeFilter::from(Some(result)); - if negate { - !filter.contains(&result) - } else { - filter.is_empty() || filter.contains(&result) - } - }, + let room_type = services.rooms.state_accessor.get_room_type(r).await; + + if room_type.as_ref().is_err_and(|e| !is_not_found(e)) { + return None; } - .then_some(r.to_owned()) + + let room_type_filter = RoomTypeFilter::from(room_type.ok()); + + let include = if negate { + !filter.contains(&room_type_filter) + } else { + filter.is_empty() || filter.contains(&room_type_filter) + }; + + include.then_some(r.to_owned()) }) .collect() .await From 8eec78e9e0e5076289e98b82b2bb9b4a139be70d Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 5 Oct 2024 12:51:15 -0400 Subject: [PATCH 047/245] mark the server user bot as online/offline on shutdown/startup Signed-off-by: strawberry --- src/service/services.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/service/services.rs b/src/service/services.rs index 0b63a5cae..ea81f434f 100644 --- a/src/service/services.rs +++ 
b/src/service/services.rs @@ -123,6 +123,14 @@ impl Services { .start() .await?; + // set the server user as online + if self.server.config.allow_local_presence { + _ = self + .presence + .ping_presence(&self.globals.server_user, &ruma::presence::PresenceState::Online) + .await; + } + debug_info!("Services startup complete."); Ok(Arc::clone(self)) } @@ -130,6 +138,14 @@ impl Services { pub async fn stop(&self) { info!("Shutting down services..."); + // set the server user as offline + if self.server.config.allow_local_presence { + _ = self + .presence + .ping_presence(&self.globals.server_user, &ruma::presence::PresenceState::Offline) + .await; + } + self.interrupt(); if let Some(manager) = self.manager.lock().await.as_ref() { manager.stop().await; From 814b9e28b68dbe6af91d3a397f6a631f1f4e9113 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 9 Oct 2024 03:37:13 +0000 Subject: [PATCH 048/245] fix unnecessary re-serializations Signed-off-by: Jason Volk --- src/database/map/rev_stream_from.rs | 3 +-- src/database/map/stream_from.rs | 3 +-- src/service/rooms/metadata/mod.rs | 4 ++-- src/service/rooms/pdu_metadata/data.rs | 2 +- src/service/rooms/short/mod.rs | 2 +- src/service/rooms/state_cache/mod.rs | 4 ++-- src/service/updates/mod.rs | 2 +- 7 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/database/map/rev_stream_from.rs b/src/database/map/rev_stream_from.rs index 650cf038c..c48f406b2 100644 --- a/src/database/map/rev_stream_from.rs +++ b/src/database/map/rev_stream_from.rs @@ -18,8 +18,7 @@ where K: Deserialize<'a> + Send, V: Deserialize<'a> + Send, { - let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); - self.rev_stream_raw_from(&key) + self.rev_stream_raw_from(from) .map(keyval::result_deserialize::) } diff --git a/src/database/map/stream_from.rs b/src/database/map/stream_from.rs index 153d5bb61..db8281250 100644 --- a/src/database/map/stream_from.rs +++ b/src/database/map/stream_from.rs @@ -18,8 +18,7 @@ where K: 
Deserialize<'a> + Send, V: Deserialize<'a> + Send, { - let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); - self.stream_raw_from(&key) + self.stream_raw_from(from) .map(keyval::result_deserialize::) } diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index 5d4a47c71..d8be6aab6 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -85,8 +85,8 @@ pub fn list_banned_rooms(&self) -> impl Stream + Send + '_ { sel #[implement(Service)] #[inline] -pub async fn is_disabled(&self, room_id: &RoomId) -> bool { self.db.disabledroomids.qry(room_id).await.is_ok() } +pub async fn is_disabled(&self, room_id: &RoomId) -> bool { self.db.disabledroomids.get(room_id).await.is_ok() } #[implement(Service)] #[inline] -pub async fn is_banned(&self, room_id: &RoomId) -> bool { self.db.bannedroomids.qry(room_id).await.is_ok() } +pub async fn is_banned(&self, room_id: &RoomId) -> bool { self.db.bannedroomids.get(room_id).await.is_ok() } diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index f23234752..8e0456582 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -94,6 +94,6 @@ impl Data { } pub(super) async fn is_event_soft_failed(&self, event_id: &EventId) -> bool { - self.softfailedeventids.qry(event_id).await.is_ok() + self.softfailedeventids.get(event_id).await.is_ok() } } diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 20082da23..bd8fdcc94 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -197,7 +197,7 @@ pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, boo #[implement(Service)] pub async fn get_shortroomid(&self, room_id: &RoomId) -> Result { - self.db.roomid_shortroomid.qry(room_id).await.deserialized() + self.db.roomid_shortroomid.get(room_id).await.deserialized() } #[implement(Service)] diff --git 
a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index b1a71cafe..a6c468f5c 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -342,7 +342,7 @@ impl Service { /// Returns the number of users which are currently in a room #[tracing::instrument(skip(self), level = "debug")] pub async fn room_joined_count(&self, room_id: &RoomId) -> Result { - self.db.roomid_joinedcount.qry(room_id).await.deserialized() + self.db.roomid_joinedcount.get(room_id).await.deserialized() } #[tracing::instrument(skip(self), level = "debug")] @@ -366,7 +366,7 @@ impl Service { pub async fn room_invited_count(&self, room_id: &RoomId) -> Result { self.db .roomid_invitedcount - .qry(room_id) + .get(room_id) .await .deserialized() } diff --git a/src/service/updates/mod.rs b/src/service/updates/mod.rs index 4e16e22b0..fca637255 100644 --- a/src/service/updates/mod.rs +++ b/src/service/updates/mod.rs @@ -128,7 +128,7 @@ impl Service { pub async fn last_check_for_updates_id(&self) -> u64 { self.db - .qry(LAST_CHECK_FOR_UPDATES_COUNT) + .get(LAST_CHECK_FOR_UPDATES_COUNT) .await .deserialized() .unwrap_or(0_u64) From 56dd0f51392cd3f21f62cc054838d6f01160f6de Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 6 Oct 2024 22:08:55 +0000 Subject: [PATCH 049/245] use loop condition to account for loole channel close Signed-off-by: Jason Volk --- src/service/presence/mod.rs | 7 ++++--- src/service/sending/sender.rs | 3 +-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index 3b5c4caf4..82a99bd56 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -55,14 +55,13 @@ impl crate::Service for Service { async fn worker(self: Arc) -> Result<()> { let mut presence_timers = FuturesUnordered::new(); let receiver = self.timer_receiver.lock().await; - loop { - debug_assert!(!receiver.is_closed(), "channel error"); + while !receiver.is_closed() { 
tokio::select! { Some(user_id) = presence_timers.next() => { self.process_presence_timer(&user_id).await.log_err().ok(); }, event = receiver.recv_async() => match event { - Err(_e) => return Ok(()), + Err(_) => break, Ok((user_id, timeout)) => { debug!("Adding timer {}: {user_id} timeout:{timeout:?}", presence_timers.len()); presence_timers.push(presence_timer(user_id, timeout)); @@ -70,6 +69,8 @@ impl crate::Service for Service { }, } } + + Ok(()) } fn interrupt(&self) { diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 3a401995b..19205a656 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -57,8 +57,7 @@ impl Service { let receiver = self.receiver.lock().await; self.initial_requests(&mut futures, &mut statuses).await; - loop { - debug_assert!(!receiver.is_closed(), "channel error"); + while !receiver.is_closed() { tokio::select! { request = receiver.recv_async() => match request { Ok(request) => self.handle_request(request, &mut futures, &mut statuses).await, From 89a3c807002ea7f6278e30541df7f3249b8fc681 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 2 Oct 2024 21:18:32 +0000 Subject: [PATCH 050/245] split admin-room branch from build_and_append_pdu (fixes large stack warning) Signed-off-by: Jason Volk --- src/service/rooms/timeline/mod.rs | 146 ++++++++++++++---------------- 1 file changed, 68 insertions(+), 78 deletions(-) diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index f8f770bc4..b49e9fad5 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -8,13 +8,13 @@ use std::{ }; use conduit::{ - debug, err, error, info, + debug, err, error, implement, info, pdu::{EventHash, PduBuilder, PduCount, PduEvent}, utils, utils::{stream::TryIgnore, IterStream, MutexMap, MutexMapGuard, ReadyExt}, validated, warn, Err, Error, Result, Server, }; -use futures::{future, future::ready, Future, Stream, StreamExt, TryStreamExt}; +use 
futures::{future, future::ready, Future, FutureExt, Stream, StreamExt, TryStreamExt}; use ruma::{ api::{client::error::ErrorKind, federation}, canonical_json::to_canonical_value, @@ -858,82 +858,7 @@ impl Service { .await?; if self.services.admin.is_admin_room(&pdu.room_id).await { - match pdu.event_type() { - TimelineEventType::RoomEncryption => { - warn!("Encryption is not allowed in the admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Encryption is not allowed in the admins room", - )); - }, - TimelineEventType::RoomMember => { - let target = pdu - .state_key() - .filter(|v| v.starts_with('@')) - .unwrap_or(sender.as_str()); - let server_user = &self.services.globals.server_user.to_string(); - - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu"))?; - - if content.membership == MembershipState::Leave { - if target == server_user { - warn!("Server user cannot leave from admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server user cannot leave from admins room.", - )); - } - - let count = self - .services - .state_cache - .room_members(&pdu.room_id) - .ready_filter(|user| self.services.globals.user_is_local(user)) - .ready_filter(|user| *user != target) - .boxed() - .count() - .await; - - if count < 2 { - warn!("Last admin cannot leave from admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Last admin cannot leave from admins room.", - )); - } - } - - if content.membership == MembershipState::Ban && pdu.state_key().is_some() { - if target == server_user { - warn!("Server user cannot be banned in admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server user cannot be banned in admins room.", - )); - } - - let count = self - .services - .state_cache - .room_members(&pdu.room_id) - .ready_filter(|user| self.services.globals.user_is_local(user)) - .ready_filter(|user| *user != target) - .boxed() - 
.count() - .await; - - if count < 2 { - warn!("Last admin cannot be banned in admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Last admin cannot be banned in admins room.", - )); - } - } - }, - _ => {}, - } + self.check_pdu_for_admin_room(&pdu, sender).boxed().await?; } // If redaction event is not authorized, do not append it to the timeline @@ -1298,6 +1223,71 @@ impl Service { } } +#[implement(Service)] +#[tracing::instrument(skip_all, level = "debug")] +async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Result<()> { + match pdu.event_type() { + TimelineEventType::RoomEncryption => { + return Err!(Request(Forbidden(error!("Encryption not supported in admins room.")))); + }, + TimelineEventType::RoomMember => { + let target = pdu + .state_key() + .filter(|v| v.starts_with('@')) + .unwrap_or(sender.as_str()); + + let server_user = &self.services.globals.server_user.to_string(); + + let content: RoomMemberEventContent = pdu.get_content()?; + match content.membership { + MembershipState::Leave => { + if target == server_user { + return Err!(Request(Forbidden(error!("Server user cannot leave the admins room.")))); + } + + let count = self + .services + .state_cache + .room_members(&pdu.room_id) + .ready_filter(|user| self.services.globals.user_is_local(user)) + .ready_filter(|user| *user != target) + .boxed() + .count() + .await; + + if count < 2 { + return Err!(Request(Forbidden(error!("Last admin cannot leave the admins room.")))); + } + }, + + MembershipState::Ban if pdu.state_key().is_some() => { + if target == server_user { + return Err!(Request(Forbidden(error!("Server cannot be banned from admins room.")))); + } + + let count = self + .services + .state_cache + .room_members(&pdu.room_id) + .ready_filter(|user| self.services.globals.user_is_local(user)) + .ready_filter(|user| *user != target) + .boxed() + .count() + .await; + + if count < 2 { + return Err!(Request(Forbidden(error!("Last admin cannot be 
banned from admins room.")))); + } + }, + _ => {}, + }; + }, + _ => {}, + }; + + Ok(()) +} + #[cfg(test)] mod tests { use super::*; From 08a2fecc0ed2e0404446d16e40bc136dfff7b7c7 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 26 Sep 2024 04:59:16 +0000 Subject: [PATCH 051/245] catch panics at base functions to integrate with other fatal errors. Signed-off-by: Jason Volk --- Cargo.lock | 1 + src/router/Cargo.toml | 13 +++++++------ src/router/mod.rs | 23 ++++++++++++++++++----- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e72c7e805..b9f366e79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -763,6 +763,7 @@ dependencies = [ "conduit_core", "conduit_service", "const-str", + "futures", "http", "http-body-util", "hyper", diff --git a/src/router/Cargo.toml b/src/router/Cargo.toml index 62690194e..e15358687 100644 --- a/src/router/Cargo.toml +++ b/src/router/Cargo.toml @@ -54,20 +54,18 @@ axum-server-dual-protocol.workspace = true axum-server-dual-protocol.optional = true axum-server.workspace = true axum.workspace = true +bytes.workspace = true conduit-admin.workspace = true conduit-api.workspace = true conduit-core.workspace = true conduit-service.workspace = true const-str.workspace = true -log.workspace = true -tokio.workspace = true -tower.workspace = true -tracing.workspace = true -bytes.workspace = true -http-body-util.workspace = true +futures.workspace = true http.workspace = true +http-body-util.workspace = true hyper.workspace = true hyper-util.workspace = true +log.workspace = true ruma.workspace = true rustls.workspace = true rustls.optional = true @@ -78,7 +76,10 @@ sentry-tracing.optional = true sentry-tracing.workspace = true sentry.workspace = true serde_json.workspace = true +tokio.workspace = true +tower.workspace = true tower-http.workspace = true +tracing.workspace = true [target.'cfg(unix)'.dependencies] sd-notify.workspace = true diff --git a/src/router/mod.rs b/src/router/mod.rs index 
e123442ca..1580f6051 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -6,10 +6,11 @@ mod serve; extern crate conduit_core as conduit; -use std::{future::Future, pin::Pin, sync::Arc}; +use std::{panic::AssertUnwindSafe, pin::Pin, sync::Arc}; -use conduit::{Result, Server}; +use conduit::{Error, Result, Server}; use conduit_service::Services; +use futures::{Future, FutureExt, TryFutureExt}; conduit::mod_ctor! {} conduit::mod_dtor! {} @@ -17,15 +18,27 @@ conduit::rustc_flags_capture! {} #[no_mangle] pub extern "Rust" fn start(server: &Arc) -> Pin>> + Send>> { - Box::pin(run::start(server.clone())) + AssertUnwindSafe(run::start(server.clone())) + .catch_unwind() + .map_err(Error::from_panic) + .unwrap_or_else(Err) + .boxed() } #[no_mangle] pub extern "Rust" fn stop(services: Arc) -> Pin> + Send>> { - Box::pin(run::stop(services)) + AssertUnwindSafe(run::stop(services)) + .catch_unwind() + .map_err(Error::from_panic) + .unwrap_or_else(Err) + .boxed() } #[no_mangle] pub extern "Rust" fn run(services: &Arc) -> Pin> + Send>> { - Box::pin(run::run(services.clone())) + AssertUnwindSafe(run::run(services.clone())) + .catch_unwind() + .map_err(Error::from_panic) + .unwrap_or_else(Err) + .boxed() } From a2e5c3d5d3bc9253fb634bf8b1b30ec7087e886f Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 3 Oct 2024 00:30:16 +0000 Subject: [PATCH 052/245] add FlatOk trait to Result/Option suite Signed-off-by: Jason Volk --- src/core/result.rs | 3 ++- src/core/result/flat_ok.rs | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 src/core/result/flat_ok.rs diff --git a/src/core/result.rs b/src/core/result.rs index 82d67a9c5..9a60d19e2 100644 --- a/src/core/result.rs +++ b/src/core/result.rs @@ -1,4 +1,5 @@ mod debug_inspect; +mod flat_ok; mod into_is_ok; mod log_debug_err; mod log_err; @@ -7,7 +8,7 @@ mod not_found; mod unwrap_infallible; pub use self::{ - debug_inspect::DebugInspect, into_is_ok::IntoIsOk, 
log_debug_err::LogDebugErr, log_err::LogErr, + debug_inspect::DebugInspect, flat_ok::FlatOk, into_is_ok::IntoIsOk, log_debug_err::LogDebugErr, log_err::LogErr, map_expect::MapExpect, not_found::NotFound, unwrap_infallible::UnwrapInfallible, }; diff --git a/src/core/result/flat_ok.rs b/src/core/result/flat_ok.rs new file mode 100644 index 000000000..e378e5d05 --- /dev/null +++ b/src/core/result/flat_ok.rs @@ -0,0 +1,34 @@ +use super::Result; + +pub trait FlatOk { + /// Equivalent to .transpose().ok().flatten() + fn flat_ok(self) -> Option; + + /// Equivalent to .transpose().ok().flatten().ok_or(...) + fn flat_ok_or(self, err: E) -> Result; + + /// Equivalent to .transpose().ok().flatten().ok_or_else(...) + fn flat_ok_or_else E>(self, err: F) -> Result; +} + +impl FlatOk for Option> { + #[inline] + fn flat_ok(self) -> Option { self.transpose().ok().flatten() } + + #[inline] + fn flat_ok_or(self, err: Ep) -> Result { self.flat_ok().ok_or(err) } + + #[inline] + fn flat_ok_or_else Ep>(self, err: F) -> Result { self.flat_ok().ok_or_else(err) } +} + +impl FlatOk for Result, E> { + #[inline] + fn flat_ok(self) -> Option { self.ok().flatten() } + + #[inline] + fn flat_ok_or(self, err: Ep) -> Result { self.flat_ok().ok_or(err) } + + #[inline] + fn flat_ok_or_else Ep>(self, err: F) -> Result { self.flat_ok().ok_or_else(err) } +} From 4485f36e34d9da010b37d0db832ac6e38c794e7e Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 3 Oct 2024 01:00:00 +0000 Subject: [PATCH 053/245] add mactors for true/false Signed-off-by: Jason Volk --- src/core/mod.rs | 56 ++++++++++++++++++++++++++++++++++++++++++ src/core/utils/math.rs | 32 ------------------------ 2 files changed, 56 insertions(+), 32 deletions(-) diff --git a/src/core/mod.rs b/src/core/mod.rs index e45531864..491d8b4ce 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -38,3 +38,59 @@ pub mod mods { () => {}; } } + +/// Functor for falsy +#[macro_export] +macro_rules! 
is_false { + () => { + |x| !x + }; +} + +/// Functor for truthy +#[macro_export] +macro_rules! is_true { + () => { + |x| !!x + }; +} + +/// Functor for equality to zero +#[macro_export] +macro_rules! is_zero { + () => { + $crate::is_matching!(0) + }; +} + +/// Functor for equality i.e. .is_some_and(is_equal!(2)) +#[macro_export] +macro_rules! is_equal_to { + ($val:expr) => { + |x| x == $val + }; +} + +/// Functor for less i.e. .is_some_and(is_less_than!(2)) +#[macro_export] +macro_rules! is_less_than { + ($val:expr) => { + |x| x < $val + }; +} + +/// Functor for matches! i.e. .is_some_and(is_matching!('A'..='Z')) +#[macro_export] +macro_rules! is_matching { + ($val:expr) => { + |x| matches!(x, $val) + }; +} + +/// Functor for !is_empty() +#[macro_export] +macro_rules! is_not_empty { + () => { + |x| !x.is_empty() + }; +} diff --git a/src/core/utils/math.rs b/src/core/utils/math.rs index 215de339c..ccff6400d 100644 --- a/src/core/utils/math.rs +++ b/src/core/utils/math.rs @@ -53,38 +53,6 @@ macro_rules! validated { ($($input:tt)+) => { $crate::expected!($($input)+) } } -/// Functor for equality to zero -#[macro_export] -macro_rules! is_zero { - () => { - $crate::is_matching!(0) - }; -} - -/// Functor for equality i.e. .is_some_and(is_equal!(2)) -#[macro_export] -macro_rules! is_equal_to { - ($val:expr) => { - |x| (x == $val) - }; -} - -/// Functor for less i.e. .is_some_and(is_less_than!(2)) -#[macro_export] -macro_rules! is_less_than { - ($val:expr) => { - |x| (x < $val) - }; -} - -/// Functor for matches! i.e. .is_some_and(is_matching!('A'..='Z')) -#[macro_export] -macro_rules! 
is_matching { - ($val:expr) => { - |x| matches!(x, $val) - }; -} - /// Returns false if the exponential backoff has expired based on the inputs #[inline] #[must_use] From dd9f53080acb354905a610dc235c8720c72d742c Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 3 Oct 2024 08:04:16 +0000 Subject: [PATCH 054/245] add unwrap_or to TryFutureExtExt Signed-off-by: Jason Volk --- src/core/utils/future/try_ext_ext.rs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/core/utils/future/try_ext_ext.rs b/src/core/utils/future/try_ext_ext.rs index e444ad94a..d30d2cac7 100644 --- a/src/core/utils/future/try_ext_ext.rs +++ b/src/core/utils/future/try_ext_ext.rs @@ -1,6 +1,9 @@ //! Extended external extensions to futures::TryFutureExt -use futures::{future::MapOkOrElse, TryFuture, TryFutureExt}; +use futures::{ + future::{MapOkOrElse, UnwrapOrElse}, + TryFuture, TryFutureExt, +}; /// This interface is not necessarily complete; feel free to add as-needed. pub trait TryExtExt @@ -19,6 +22,10 @@ where ) -> MapOkOrElse Option, impl FnOnce(Self::Error) -> Option> where Self: Sized; + + fn unwrap_or(self, default: Self::Ok) -> UnwrapOrElse Self::Ok> + where + Self: Sized; } impl TryExtExt for Fut @@ -45,4 +52,12 @@ where { self.map_ok_or(None, Some) } + + #[inline] + fn unwrap_or(self, default: Self::Ok) -> UnwrapOrElse Self::Ok> + where + Self: Sized, + { + self.unwrap_or_else(move |_| default) + } } From 685eadb1713d0d09a48025b74f05407cd6f65742 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 4 Oct 2024 17:07:31 +0000 Subject: [PATCH 055/245] add is_not_found as Error member function; tweak interface; add doc comments Signed-off-by: Jason Volk --- src/api/client/sync.rs | 8 +++----- src/api/server/send.rs | 4 ++-- src/core/error/mod.rs | 21 +++++++++++++++++---- src/core/result/not_found.rs | 4 ++-- 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index 5940d7cf2..f0b26e800 
100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -6,10 +6,8 @@ use std::{ use axum::extract::State; use conduit::{ - debug, err, error, - error::is_not_found, - is_equal_to, - result::IntoIsOk, + debug, err, error, is_equal_to, + result::{FlatOk, IntoIsOk}, utils::{ math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, BoolExt, IterStream, ReadyExt, TryFutureExtExt, @@ -1891,7 +1889,7 @@ async fn filter_rooms( .filter_map(|r| async move { let room_type = services.rooms.state_accessor.get_room_type(r).await; - if room_type.as_ref().is_err_and(|e| !is_not_found(e)) { + if room_type.as_ref().is_err_and(|e| !e.is_not_found()) { return None; } diff --git a/src/api/server/send.rs b/src/api/server/send.rs index bb4249881..50a79e002 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -2,7 +2,7 @@ use std::{collections::BTreeMap, net::IpAddr, time::Instant}; use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{debug, debug_warn, err, result::LogErr, trace, utils::ReadyExt, warn, Err, Error, Result}; +use conduit::{debug, debug_warn, err, error, result::LogErr, trace, utils::ReadyExt, warn, Err, Error, Result}; use futures::StreamExt; use ruma::{ api::{ @@ -85,7 +85,7 @@ pub(crate) async fn send_transaction_message_route( Ok(send_transaction_message::v1::Response { pdus: resolved_map .into_iter() - .map(|(e, r)| (e, r.map_err(|e| e.sanitized_string()))) + .map(|(e, r)| (e, r.map_err(error::sanitized_message))) .collect(), }) } diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs index ad7f9f3ca..39fa43404 100644 --- a/src/core/error/mod.rs +++ b/src/core/error/mod.rs @@ -120,17 +120,19 @@ pub enum Error { } impl Error { + //#[deprecated] pub fn bad_database(message: &'static str) -> Self { crate::err!(Database(error!("{message}"))) } /// Sanitizes public-facing errors that can leak sensitive information. 
- pub fn sanitized_string(&self) -> String { + pub fn sanitized_message(&self) -> String { match self { Self::Database(..) => String::from("Database error occurred."), Self::Io(..) => String::from("I/O error occurred."), - _ => self.to_string(), + _ => self.message(), } } + /// Generate the error message string. pub fn message(&self) -> String { match self { Self::Federation(ref origin, ref error) => format!("Answer from {origin}: {error}"), @@ -151,6 +153,8 @@ impl Error { } } + /// Returns the HTTP error code or closest approximation based on error + /// variant. pub fn status_code(&self) -> http::StatusCode { use http::StatusCode; @@ -163,10 +167,17 @@ impl Error { _ => StatusCode::INTERNAL_SERVER_ERROR, } } + + /// Returns true for "not found" errors. This means anything that qualifies + /// as a "not found" from any variant's contained error type. This call is + /// often used as a special case to eliminate a contained Option with a + /// Result where Ok(None) is instead Err(e) if e.is_not_found(). + #[inline] + pub fn is_not_found(&self) -> bool { self.status_code() == http::StatusCode::NOT_FOUND } } impl fmt::Debug for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self}") } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.message()) } } #[allow(clippy::fallible_impl_from)] @@ -184,6 +195,8 @@ pub fn infallible(_e: &Infallible) { panic!("infallible error should never exist"); } +/// Convenience functor for fundamental Error::sanitized_message(); see member. 
#[inline] #[must_use] -pub fn is_not_found(e: &Error) -> bool { e.status_code() == http::StatusCode::NOT_FOUND } +#[allow(clippy::needless_pass_by_value)] +pub fn sanitized_message(e: Error) -> String { e.sanitized_message() } diff --git a/src/core/result/not_found.rs b/src/core/result/not_found.rs index 69ce821b8..d61825afa 100644 --- a/src/core/result/not_found.rs +++ b/src/core/result/not_found.rs @@ -1,5 +1,5 @@ use super::Result; -use crate::{error, Error}; +use crate::Error; pub trait NotFound { #[must_use] @@ -8,5 +8,5 @@ pub trait NotFound { impl NotFound for Result { #[inline] - fn is_not_found(&self) -> bool { self.as_ref().is_err_and(error::is_not_found) } + fn is_not_found(&self) -> bool { self.as_ref().is_err_and(Error::is_not_found) } } From 2b2055fe8a47ba9dd1981237ee130923766938f4 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 4 Oct 2024 03:40:00 +0000 Subject: [PATCH 056/245] parallelize calculate_invite_state Signed-off-by: Jason Volk --- src/api/client/membership.rs | 2 +- src/service/rooms/state/mod.rs | 80 ++++++++++--------------------- src/service/rooms/timeline/mod.rs | 5 +- 3 files changed, 27 insertions(+), 60 deletions(-) diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index f89903b4f..ae56094ce 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -1452,7 +1452,7 @@ pub(crate) async fn invite_helper( ) .await?; - let invite_room_state = services.rooms.state.calculate_invite_state(&pdu).await?; + let invite_room_state = services.rooms.state.summary_stripped(&pdu).await; drop(state_lock); diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index c7f6605c7..177b7e9b2 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -3,6 +3,7 @@ mod data; use std::{ collections::{HashMap, HashSet}, fmt::Write, + iter::once, sync::Arc, }; @@ -13,7 +14,7 @@ use conduit::{ }; use data::Data; use database::{Ignore, Interfix}; -use 
futures::{pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; +use futures::{future::join_all, pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; use ruma::{ events::{ room::{create::RoomCreateEventContent, member::RoomMemberEventContent}, @@ -288,61 +289,30 @@ impl Service { } } - #[tracing::instrument(skip(self, invite_event), level = "debug")] - pub async fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result>> { - let mut state = Vec::new(); - // Add recommended events - if let Ok(e) = self - .services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "") - .await - { - state.push(e.to_stripped_state_event()); - } - if let Ok(e) = self - .services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "") - .await - { - state.push(e.to_stripped_state_event()); - } - if let Ok(e) = self - .services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomCanonicalAlias, "") - .await - { - state.push(e.to_stripped_state_event()); - } - if let Ok(e) = self - .services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "") - .await - { - state.push(e.to_stripped_state_event()); - } - if let Ok(e) = self - .services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomName, "") - .await - { - state.push(e.to_stripped_state_event()); - } - if let Ok(e) = self - .services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomMember, invite_event.sender.as_str()) - .await - { - state.push(e.to_stripped_state_event()); - } + #[tracing::instrument(skip_all, level = "debug")] + pub async fn summary_stripped(&self, invite: &PduEvent) -> Vec> { + let cells = [ + (&StateEventType::RoomCreate, ""), + (&StateEventType::RoomJoinRules, ""), + (&StateEventType::RoomCanonicalAlias, ""), + (&StateEventType::RoomName, ""), + (&StateEventType::RoomAvatar, 
""), + (&StateEventType::RoomMember, invite.sender.as_str()), // Add recommended events + ]; + + let fetches = cells.iter().map(|(event_type, state_key)| { + self.services + .state_accessor + .room_state_get(&invite.room_id, event_type, state_key) + }); - state.push(invite_event.to_stripped_state_event()); - Ok(state) + join_all(fetches) + .await + .into_iter() + .filter_map(Result::ok) + .map(|e| e.to_stripped_state_event()) + .chain(once(invite.to_stripped_state_event())) + .collect() } /// Set the state hash to a new version, but does not update state_cache. diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index b49e9fad5..84f29c865 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -513,10 +513,7 @@ impl Service { })?; let invite_state = match content.membership { - MembershipState::Invite => { - let state = self.services.state.calculate_invite_state(pdu).await?; - Some(state) - }, + MembershipState::Invite => self.services.state.summary_stripped(pdu).await.into(), _ => None, }; From 48a767d52c3d24c6a460a6defdddaa9c7c707387 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 3 Oct 2024 01:01:25 +0000 Subject: [PATCH 057/245] abstract common patterns as core pdu memberfns Signed-off-by: Jason Volk --- src/core/pdu/mod.rs | 67 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 60 insertions(+), 7 deletions(-) diff --git a/src/core/pdu/mod.rs b/src/core/pdu/mod.rs index cf9ffe645..a94e2bdc6 100644 --- a/src/core/pdu/mod.rs +++ b/src/core/pdu/mod.rs @@ -18,11 +18,11 @@ use ruma::{ use serde::{Deserialize, Serialize}; use serde_json::{ json, - value::{to_raw_value, RawValue as RawJsonValue}, + value::{to_raw_value, RawValue as RawJsonValue, Value as JsonValue}, }; pub use self::{builder::PduBuilder, count::PduCount}; -use crate::{err, warn, Error, Result}; +use crate::{err, is_true, warn, Error, Result}; #[derive(Deserialize)] struct ExtractRedactedBecause { @@ -58,8 +58,8 @@ pub 
struct PduEvent { pub unsigned: Option>, pub hashes: EventHash, #[serde(default, skip_serializing_if = "Option::is_none")] - pub signatures: Option>, /* BTreeMap, BTreeMap> */ + // BTreeMap, BTreeMap> + pub signatures: Option>, } impl PduEvent { @@ -170,6 +170,54 @@ impl PduEvent { (self.redacts.clone(), self.content.clone()) } + #[must_use] + pub fn get_content_as_value(&self) -> JsonValue { + self.get_content() + .expect("pdu content must be a valid JSON value") + } + + pub fn get_content(&self) -> Result + where + T: for<'de> Deserialize<'de>, + { + serde_json::from_str(self.content.get()) + .map_err(|e| err!(Database("Failed to deserialize pdu content into type: {e}"))) + } + + pub fn contains_unsigned_property(&self, property: &str, is_type: F) -> bool + where + F: FnOnce(&JsonValue) -> bool, + { + self.get_unsigned_as_value() + .get(property) + .map(is_type) + .is_some_and(is_true!()) + } + + pub fn get_unsigned_property(&self, property: &str) -> Result + where + T: for<'de> Deserialize<'de>, + { + self.get_unsigned_as_value() + .get_mut(property) + .map(JsonValue::take) + .map(serde_json::from_value) + .ok_or(err!(Request(NotFound("property not found in unsigned object"))))? + .map_err(|e| err!(Database("Failed to deserialize unsigned.{property} into type: {e}"))) + } + + #[must_use] + pub fn get_unsigned_as_value(&self) -> JsonValue { self.get_unsigned::().unwrap_or_default() } + + pub fn get_unsigned(&self) -> Result { + self.unsigned + .as_ref() + .map(|raw| raw.get()) + .map(serde_json::from_str) + .ok_or(err!(Request(NotFound("\"unsigned\" property not found in pdu"))))? 
+ .map_err(|e| err!(Database("Failed to deserialize \"unsigned\" into value: {e}"))) + } + #[tracing::instrument(skip(self), level = "debug")] pub fn to_sync_room_event(&self) -> Raw { let (redacts, content) = self.copy_redacts(); @@ -270,8 +318,8 @@ impl PduEvent { serde_json::from_value(json).expect("Raw::from_value always works") } - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_state_event(&self) -> Raw { + #[must_use] + pub fn to_state_event_value(&self) -> JsonValue { let mut json = json!({ "content": self.content, "type": self.kind, @@ -286,7 +334,12 @@ impl PduEvent { json["unsigned"] = json!(unsigned); } - serde_json::from_value(json).expect("Raw::from_value always works") + json + } + + #[tracing::instrument(skip(self), level = "debug")] + pub fn to_state_event(&self) -> Raw { + serde_json::from_value(self.to_state_event_value()).expect("Raw::from_value always works") } #[tracing::instrument(skip(self), level = "debug")] From da34b43302d8e0d66dc218a1612e5a6eb18cb710 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 2 Oct 2024 07:57:18 +0000 Subject: [PATCH 058/245] abstract account-data deserializations for serde_json::from_elim Signed-off-by: Jason Volk --- src/admin/query/account_data.rs | 9 +-- src/admin/user/commands.rs | 42 ++++------ src/api/client/config.rs | 20 ++--- src/api/client/push.rs | 117 ++++++++++----------------- src/api/client/tag.rs | 68 ++++++---------- src/service/account_data/mod.rs | 39 ++++++--- src/service/admin/grant.rs | 3 +- src/service/globals/migrations.rs | 11 +-- src/service/rooms/state_cache/mod.rs | 14 +--- src/service/rooms/timeline/mod.rs | 3 +- src/service/sending/sender.rs | 3 +- src/service/users/mod.rs | 14 +--- 12 files changed, 133 insertions(+), 210 deletions(-) diff --git a/src/admin/query/account_data.rs b/src/admin/query/account_data.rs index 896bf95cf..ea45eb166 100644 --- a/src/admin/query/account_data.rs +++ b/src/admin/query/account_data.rs @@ -1,9 +1,6 @@ use clap::Subcommand; use 
conduit::Result; -use ruma::{ - events::{room::message::RoomMessageEventContent, RoomAccountDataEventType}, - RoomId, UserId, -}; +use ruma::{events::room::message::RoomMessageEventContent, RoomId, UserId}; use crate::Command; @@ -25,7 +22,7 @@ pub(crate) enum AccountDataCommand { /// Full user ID user_id: Box, /// Account data event type - kind: RoomAccountDataEventType, + kind: String, /// Optional room ID of the account data room_id: Option>, }, @@ -60,7 +57,7 @@ pub(super) async fn process(subcommand: AccountDataCommand, context: &Command<'_ let timer = tokio::time::Instant::now(); let results = services .account_data - .get(room_id.as_deref(), &user_id, kind) + .get_raw(room_id.as_deref(), &user_id, &kind) .await; let query_time = timer.elapsed(); diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index 1b086856a..562bb9c74 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -501,20 +501,16 @@ pub(super) async fn put_room_tag( ) -> Result { let user_id = parse_active_local_user_id(self.services, &user_id).await?; - let event = self + let mut tags_event = self .services .account_data - .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag) - .await; - - let mut tags_event = event.map_or_else( - |_| TagEvent { + .get_room(&room_id, &user_id, RoomAccountDataEventType::Tag) + .await + .unwrap_or(TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, - }, - |e| serde_json::from_str(e.get()).expect("Bad account data in database for user {user_id}"), - ); + }); tags_event .content @@ -542,20 +538,16 @@ pub(super) async fn delete_room_tag( ) -> Result { let user_id = parse_active_local_user_id(self.services, &user_id).await?; - let event = self + let mut tags_event = self .services .account_data - .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag) - .await; - - let mut tags_event = event.map_or_else( - |_| TagEvent { + .get_room(&room_id, &user_id, RoomAccountDataEventType::Tag) + .await + 
.unwrap_or(TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, - }, - |e| serde_json::from_str(e.get()).expect("Bad account data in database for user {user_id}"), - ); + }); tags_event.content.tags.remove(&tag.clone().into()); @@ -578,20 +570,16 @@ pub(super) async fn delete_room_tag( pub(super) async fn get_room_tags(&self, user_id: String, room_id: Box) -> Result { let user_id = parse_active_local_user_id(self.services, &user_id).await?; - let event = self + let tags_event = self .services .account_data - .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag) - .await; - - let tags_event = event.map_or_else( - |_| TagEvent { + .get_room(&room_id, &user_id, RoomAccountDataEventType::Tag) + .await + .unwrap_or(TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, - }, - |e| serde_json::from_str(e.get()).expect("Bad account data in database for user {user_id}"), - ); + }); Ok(RoomMessageEventContent::notice_markdown(format!( "```\n{:#?}\n```", diff --git a/src/api/client/config.rs b/src/api/client/config.rs index 33b85136c..d06cc0729 100644 --- a/src/api/client/config.rs +++ b/src/api/client/config.rs @@ -58,18 +58,14 @@ pub(crate) async fn get_global_account_data_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: Box = services + let account_data: ExtractGlobalEventContent = services .account_data - .get(None, sender_user, body.event_type.to_string().into()) + .get_global(sender_user, body.event_type.clone()) .await .map_err(|_| err!(Request(NotFound("Data not found."))))?; - let account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))? 
- .content; - Ok(get_global_account_data::v3::Response { - account_data, + account_data: account_data.content, }) } @@ -81,18 +77,14 @@ pub(crate) async fn get_room_account_data_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: Box = services + let account_data: ExtractRoomEventContent = services .account_data - .get(Some(&body.room_id), sender_user, body.event_type.clone()) + .get_room(&body.room_id, sender_user, body.event_type.clone()) .await .map_err(|_| err!(Request(NotFound("Data not found."))))?; - let account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))? - .content; - Ok(get_room_account_data::v3::Response { - account_data, + account_data: account_data.content, }) } diff --git a/src/api/client/push.rs b/src/api/client/push.rs index 390951999..103c0c5e1 100644 --- a/src/api/client/push.rs +++ b/src/api/client/push.rs @@ -13,7 +13,7 @@ use ruma::{ GlobalAccountDataEventType, }, push::{InsertPushRuleError, RemovePushRuleError, Ruleset}, - CanonicalJsonObject, + CanonicalJsonObject, CanonicalJsonValue, }; use service::Services; @@ -27,38 +27,23 @@ pub(crate) async fn get_pushrules_all_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let global_ruleset: Ruleset; - - let event = services + let Some(content_value) = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) - .await; - - let Ok(event) = event else { + .get_global::(sender_user, GlobalAccountDataEventType::PushRules) + .await + .ok() + .and_then(|event| event.get("content").cloned()) + .filter(CanonicalJsonValue::is_object) + else { // user somehow has non-existent push rule event. 
recreate it and return server // default silently return recreate_push_rules_and_return(&services, sender_user).await; }; - let value = serde_json::from_str::(event.get()) - .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; - - let Some(content_value) = value.get("content") else { - // user somehow has a push rule event with no content key, recreate it and - // return server default silently - return recreate_push_rules_and_return(&services, sender_user).await; - }; - - if content_value.to_string().is_empty() { - // user somehow has a push rule event with empty content, recreate it and return - // server default silently - return recreate_push_rules_and_return(&services, sender_user).await; - } - - let account_data_content = serde_json::from_value::(content_value.clone().into()) + let account_data_content = serde_json::from_value::(content_value.into()) .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; - global_ruleset = account_data_content.global; + let global_ruleset: Ruleset = account_data_content.global; Ok(get_pushrules_all::v3::Response { global: global_ruleset, @@ -73,17 +58,14 @@ pub(crate) async fn get_pushrule_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services + let event: PushRulesEvent = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .get_global(sender_user, GlobalAccountDataEventType::PushRules) .await - .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; - - let account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))? 
- .content; + .map_err(|_| err!(Request(NotFound("PushRules event not found."))))?; - let rule = account_data + let rule = event + .content .global .get(body.kind.clone(), &body.rule_id) .map(Into::into); @@ -113,14 +95,11 @@ pub(crate) async fn set_pushrule_route( )); } - let event = services + let mut account_data: PushRulesEvent = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .get_global(sender_user, GlobalAccountDataEventType::PushRules) .await - .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; - - let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + .map_err(|_| err!(Request(NotFound("PushRules event not found."))))?; if let Err(error) = account_data @@ -181,21 +160,18 @@ pub(crate) async fn get_pushrule_actions_route( )); } - let event = services + let event: PushRulesEvent = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .get_global(sender_user, GlobalAccountDataEventType::PushRules) .await - .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; - - let account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))? 
- .content; + .map_err(|_| err!(Request(NotFound("PushRules event not found."))))?; - let global = account_data.global; - let actions = global + let actions = event + .content + .global .get(body.kind.clone(), &body.rule_id) .map(|rule| rule.actions().to_owned()) - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))?; + .ok_or(err!(Request(NotFound("Push rule not found."))))?; Ok(get_pushrule_actions::v3::Response { actions, @@ -217,14 +193,11 @@ pub(crate) async fn set_pushrule_actions_route( )); } - let event = services + let mut account_data: PushRulesEvent = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .get_global(sender_user, GlobalAccountDataEventType::PushRules) .await - .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; - - let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + .map_err(|_| err!(Request(NotFound("PushRules event not found."))))?; if account_data .content @@ -263,20 +236,18 @@ pub(crate) async fn get_pushrule_enabled_route( )); } - let event = services + let event: PushRulesEvent = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .get_global(sender_user, GlobalAccountDataEventType::PushRules) .await - .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .map_err(|_| err!(Request(NotFound("PushRules event not found."))))?; - let account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))?; - - let global = account_data.content.global; - let enabled = global + let enabled = event + .content + .global .get(body.kind.clone(), &body.rule_id) .map(ruma::push::AnyPushRuleRef::enabled) - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))?; + .ok_or(err!(Request(NotFound("Push rule 
not found."))))?; Ok(get_pushrule_enabled::v3::Response { enabled, @@ -298,14 +269,11 @@ pub(crate) async fn set_pushrule_enabled_route( )); } - let event = services + let mut account_data: PushRulesEvent = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .get_global(sender_user, GlobalAccountDataEventType::PushRules) .await - .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; - - let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + .map_err(|_| err!(Request(NotFound("PushRules event not found."))))?; if account_data .content @@ -344,14 +312,11 @@ pub(crate) async fn delete_pushrule_route( )); } - let event = services + let mut account_data: PushRulesEvent = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .get_global(sender_user, GlobalAccountDataEventType::PushRules) .await - .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; - - let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + .map_err(|_| err!(Request(NotFound("PushRules event not found."))))?; if let Err(error) = account_data .content diff --git a/src/api/client/tag.rs b/src/api/client/tag.rs index bcd0f8170..b5fa19e3a 100644 --- a/src/api/client/tag.rs +++ b/src/api/client/tag.rs @@ -9,7 +9,7 @@ use ruma::{ }, }; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}` /// @@ -21,21 +21,15 @@ pub(crate) async fn update_tag_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services + let mut tags_event = services .account_data - .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag) - .await; - - let mut 
tags_event = event.map_or_else( - |_| { - Ok(TagEvent { - content: TagEventContent { - tags: BTreeMap::new(), - }, - }) - }, - |e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")), - )?; + .get_room(&body.room_id, sender_user, RoomAccountDataEventType::Tag) + .await + .unwrap_or(TagEvent { + content: TagEventContent { + tags: BTreeMap::new(), + }, + }); tags_event .content @@ -65,21 +59,15 @@ pub(crate) async fn delete_tag_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services + let mut tags_event = services .account_data - .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag) - .await; - - let mut tags_event = event.map_or_else( - |_| { - Ok(TagEvent { - content: TagEventContent { - tags: BTreeMap::new(), - }, - }) - }, - |e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")), - )?; + .get_room(&body.room_id, sender_user, RoomAccountDataEventType::Tag) + .await + .unwrap_or(TagEvent { + content: TagEventContent { + tags: BTreeMap::new(), + }, + }); tags_event.content.tags.remove(&body.tag.clone().into()); @@ -106,21 +94,15 @@ pub(crate) async fn get_tags_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services + let tags_event = services .account_data - .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag) - .await; - - let tags_event = event.map_or_else( - |_| { - Ok(TagEvent { - content: TagEventContent { - tags: BTreeMap::new(), - }, - }) - }, - |e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")), - )?; + .get_room(&body.room_id, sender_user, RoomAccountDataEventType::Tag) + .await + .unwrap_or(TagEvent { + content: TagEventContent { + tags: BTreeMap::new(), + }, + }); Ok(get_tags::v3::Response { tags: tags_event.content.tags, diff --git 
a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index 482229e7f..8065ac55b 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -5,14 +5,17 @@ use conduit::{ utils::{stream::TryIgnore, ReadyExt}, Err, Error, Result, }; -use database::{Deserialized, Map}; +use database::{Deserialized, Handle, Map}; use futures::{StreamExt, TryFutureExt}; use ruma::{ - events::{AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent, RoomAccountDataEventType}, + events::{ + AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent, GlobalAccountDataEventType, + RoomAccountDataEventType, + }, serde::Raw, RoomId, UserId, }; -use serde_json::value::RawValue; +use serde::Deserialize; use crate::{globals, Dep}; @@ -97,18 +100,36 @@ pub async fn update( Ok(()) } -/// Searches the account data for a specific kind. +/// Searches the global account data for a specific kind. #[implement(Service)] -pub async fn get( - &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType, -) -> Result> { - let key = (room_id, user_id, kind.to_string()); +pub async fn get_global(&self, user_id: &UserId, kind: GlobalAccountDataEventType) -> Result +where + T: for<'de> Deserialize<'de>, +{ + self.get_raw(None, user_id, &kind.to_string()) + .await + .deserialized() +} + +/// Searches the room account data for a specific kind. 
+#[implement(Service)] +pub async fn get_room(&self, room_id: &RoomId, user_id: &UserId, kind: RoomAccountDataEventType) -> Result +where + T: for<'de> Deserialize<'de>, +{ + self.get_raw(Some(room_id), user_id, &kind.to_string()) + .await + .deserialized() +} + +#[implement(Service)] +pub async fn get_raw(&self, room_id: Option<&RoomId>, user_id: &UserId, kind: &str) -> Result> { + let key = (room_id, user_id, kind.to_owned()); self.db .roomusertype_roomuserdataid .qry(&key) .and_then(|roomuserdataid| self.db.roomuserdataid_accountdata.get(&roomuserdataid)) .await - .deserialized() } /// Returns all changes to the account data that happened after `since`. diff --git a/src/service/admin/grant.rs b/src/service/admin/grant.rs index 4b3ebb887..6e266ca9b 100644 --- a/src/service/admin/grant.rs +++ b/src/service/admin/grant.rs @@ -143,9 +143,8 @@ async fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> R let mut event = self .services .account_data - .get(Some(room_id), user_id, RoomAccountDataEventType::Tag) + .get_room(room_id, user_id, RoomAccountDataEventType::Tag) .await - .and_then(|event| serde_json::from_str(event.get()).map_err(Into::into)) .unwrap_or_else(|_| TagEvent { content: TagEventContent { tags: BTreeMap::new(), diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs index fc6e477b3..334e71c6f 100644 --- a/src/service/globals/migrations.rs +++ b/src/service/globals/migrations.rs @@ -215,13 +215,12 @@ async fn db_lt_12(services: &Services) -> Result<()> { }, }; - let raw_rules_list = services + let mut account_data: PushRulesEvent = services .account_data - .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) + .get_global(&user, GlobalAccountDataEventType::PushRules) .await .expect("Username is invalid"); - let mut account_data = serde_json::from_str::(raw_rules_list.get()).unwrap(); let rules_list = &mut account_data.content.global; //content rule @@ -294,14 +293,12 @@ async fn 
db_lt_13(services: &Services) -> Result<()> { }, }; - let raw_rules_list = services + let mut account_data: PushRulesEvent = services .account_data - .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) + .get_global(&user, GlobalAccountDataEventType::PushRules) .await .expect("Username is invalid"); - let mut account_data = serde_json::from_str::(raw_rules_list.get()).unwrap(); - let user_default_rules = Ruleset::server_default(&user); account_data .content diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index a6c468f5c..8539c9402 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -146,12 +146,9 @@ impl Service { if let Ok(tag_event) = self .services .account_data - .get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag) + .get_room(&predecessor.room_id, user_id, RoomAccountDataEventType::Tag) .await - .and_then(|event| { - serde_json::from_str(event.get()) - .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) - }) { + { self.services .account_data .update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event) @@ -163,12 +160,9 @@ impl Service { if let Ok(mut direct_event) = self .services .account_data - .get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into()) + .get_global::(user_id, GlobalAccountDataEventType::Direct) .await - .and_then(|event| { - serde_json::from_str::(event.get()) - .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) - }) { + { let mut room_ids_updated = false; for room_ids in direct_event.content.0.values_mut() { if room_ids.iter().any(|r| r == &predecessor.room_id) { diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 84f29c865..7cf06522d 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -407,9 +407,8 @@ impl Service { let rules_for_user = self .services 
.account_data - .get(None, user, GlobalAccountDataEventType::PushRules.to_string().into()) + .get_global(user, GlobalAccountDataEventType::PushRules) .await - .and_then(|event| serde_json::from_str::(event.get()).map_err(Into::into)) .map_or_else(|_| Ruleset::server_default(user), |ev: PushRulesEvent| ev.content.global); let mut highlight = false; diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 19205a656..90977abe9 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -539,9 +539,8 @@ impl Service { let rules_for_user = self .services .account_data - .get(None, userid, GlobalAccountDataEventType::PushRules.to_string().into()) + .get_global(userid, GlobalAccountDataEventType::PushRules) .await - .and_then(|event| serde_json::from_str::(event.get()).map_err(Into::into)) .map_or_else( |_| push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global, diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 44d169dd4..3ab6b3c33 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -98,19 +98,9 @@ impl Service { pub async fn user_is_ignored(&self, sender_user: &UserId, recipient_user: &UserId) -> bool { self.services .account_data - .get( - None, - recipient_user, - GlobalAccountDataEventType::IgnoredUserList - .to_string() - .into(), - ) + .get_global(recipient_user, GlobalAccountDataEventType::IgnoredUserList) .await - .and_then(|event| { - serde_json::from_str::(event.get()) - .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) - }) - .map_or(false, |ignored| { + .map_or(false, |ignored: IgnoredUserListEvent| { ignored .content .ignored_users From 68315ac1128196c46802216054b31ec517dbfcb2 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 3 Oct 2024 09:38:10 +0000 Subject: [PATCH 059/245] Add state_get_content(shortid) for serde_json::from elim Signed-off-by: Jason Volk --- src/api/client/sync.rs | 39 ++++++------- 
src/service/rooms/state_accessor/mod.rs | 78 ++++++++++--------------- 2 files changed, 48 insertions(+), 69 deletions(-) diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index f0b26e800..65d62a786 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -14,7 +14,7 @@ use conduit::{ }, warn, PduCount, }; -use futures::{pin_mut, FutureExt, StreamExt, TryFutureExt}; +use futures::{future::OptionFuture, pin_mut, FutureExt, StreamExt, TryFutureExt}; use ruma::{ api::client::{ error::ErrorKind, @@ -681,20 +681,22 @@ async fn load_joined_room( )) }; - let since_sender_member: Option = if let Some(short) = since_shortstatehash { + let get_sender_member_content = |short| { services .rooms .state_accessor - .state_get(short, &StateEventType::RoomMember, sender_user.as_str()) - .await - .and_then(|pdu| serde_json::from_str(pdu.content.get()).map_err(Into::into)) + .state_get_content(short, &StateEventType::RoomMember, sender_user.as_str()) .ok() - } else { - None }; - let joined_since_last_sync = - since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); + let since_sender_member: OptionFuture<_> = since_shortstatehash.map(get_sender_member_content).into(); + + let joined_since_last_sync = since_sender_member + .await + .flatten() + .map_or(true, |content: RoomMemberEventContent| { + content.membership != MembershipState::Join + }); if since_shortstatehash.is_none() || joined_since_last_sync { // Probably since = 0, we will do an initial sync @@ -1296,18 +1298,6 @@ pub(crate) async fn sync_events_v4_route( .await .ok(); - let since_sender_member: Option = if let Some(short) = since_shortstatehash { - services - .rooms - .state_accessor - .state_get(short, &StateEventType::RoomMember, sender_user.as_str()) - .await - .and_then(|pdu| serde_json::from_str(pdu.content.get()).map_err(Into::into)) - .ok() - } else { - None - }; - let encrypted_room = services .rooms .state_accessor @@ -1327,6 +1317,13 @@ pub(crate) async fn 
sync_events_v4_route( .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "") .await; + let since_sender_member: Option = services + .rooms + .state_accessor + .state_get_content(since_shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) + .ok() + .await; + let joined_since_last_sync = since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 4c28483cb..ece8679d3 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -33,8 +33,8 @@ use ruma::{ }, room::RoomType, space::SpaceRoomJoinRule, - EventEncryptionAlgorithm, EventId, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, - UserId, + EventEncryptionAlgorithm, EventId, JsOption, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, + ServerName, UserId, }; use serde::Deserialize; use serde_json::value::to_raw_value; @@ -125,16 +125,23 @@ impl Service { .await } + /// Returns the content of a state event from `shortstatehash` with key (`event_type`, `state_key`), deserialized into `T`. 
+ pub async fn state_get_content( + &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + ) -> Result + where + T: for<'de> Deserialize<'de> + Send, + { + self.state_get(shortstatehash, event_type, state_key) + .await + .and_then(|event| event.get_content()) + } + /// Get membership for given user in state async fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> MembershipState { - self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str()) + self.state_get_content(shortstatehash, &StateEventType::RoomMember, user_id.as_str()) .await - .map_or(MembershipState::Leave, |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomMemberEventContent| c.membership) - .map_err(|_| Error::bad_database("Invalid room membership event in database.")) - .unwrap() - }) + .map_or(MembershipState::Leave, |c: RoomMemberEventContent| c.membership) } /// The user was a joined member at this state (potentially in the past) @@ -171,19 +178,10 @@ impl Service { } let history_visibility = self - .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "") + .state_get_content(shortstatehash, &StateEventType::RoomHistoryVisibility, "") .await - .map_or(HistoryVisibility::Shared, |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) - .map_err(|e| { - error!( - "Invalid history visibility event in database for room {room_id}, assuming is \"shared\": \ - {e}" - ); - Error::bad_database("Invalid history visibility event in database.") - }) - .unwrap() + .map_or(HistoryVisibility::Shared, |c: RoomHistoryVisibilityEventContent| { + c.history_visibility }); let current_server_members = self @@ -240,19 +238,10 @@ impl Service { let currently_member = self.services.state_cache.is_joined(user_id, room_id).await; let history_visibility = self - .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "") + .state_get_content(shortstatehash, 
&StateEventType::RoomHistoryVisibility, "") .await - .map_or(HistoryVisibility::Shared, |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) - .map_err(|e| { - error!( - "Invalid history visibility event in database for room {room_id}, assuming is \"shared\": \ - {e}" - ); - Error::bad_database("Invalid history visibility event in database.") - }) - .unwrap() + .map_or(HistoryVisibility::Shared, |c: RoomHistoryVisibilityEventContent| { + c.history_visibility }); let visibility = match history_visibility { @@ -284,25 +273,18 @@ impl Service { /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, user_id, room_id))] pub async fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> bool { - let currently_member = self.services.state_cache.is_joined(user_id, room_id).await; + if self.services.state_cache.is_joined(user_id, room_id).await { + return true; + } let history_visibility = self - .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "") + .room_state_get_content(room_id, &StateEventType::RoomHistoryVisibility, "") .await - .map_or(Ok(HistoryVisibility::Shared), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) - .map_err(|e| { - error!( - "Invalid history visibility event in database for room {room_id}, assuming is \"shared\": \ - {e}" - ); - Error::bad_database("Invalid history visibility event in database.") - }) - }) - .unwrap_or(HistoryVisibility::Shared); + .map_or(HistoryVisibility::Shared, |c: RoomHistoryVisibilityEventContent| { + c.history_visibility + }); - currently_member || history_visibility == HistoryVisibility::WorldReadable + history_visibility == HistoryVisibility::WorldReadable } /// Returns the state hash for this pdu. 
From f7af6966b7fbfac3de141f7101bfe8b5e3904c85 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 3 Oct 2024 09:44:43 +0000 Subject: [PATCH 060/245] refactor to room_state_get_content() for serde_json::from_ elim Signed-off-by: Jason Volk --- src/api/client/membership.rs | 78 ++++++++++--------------- src/api/client/profile.rs | 57 ++++++++---------- src/api/client/room.rs | 32 ++++------ src/service/rooms/alias/mod.rs | 27 +++++---- src/service/rooms/spaces/mod.rs | 9 +-- src/service/rooms/state_accessor/mod.rs | 23 +++----- src/service/rooms/timeline/mod.rs | 7 +-- 7 files changed, 91 insertions(+), 142 deletions(-) diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index ae56094ce..a260b8c5c 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -389,17 +389,12 @@ pub(crate) async fn kick_user_route( let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; - let mut event: RoomMemberEventContent = serde_json::from_str( - services - .rooms - .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) - .await - .map_err(|_| err!(Request(BadState("Cannot kick member that's not in the room."))))? 
- .content - .get(), - ) - .map_err(|_| err!(Database("Invalid member event in database.")))?; + let mut event: RoomMemberEventContent = services + .rooms + .state_accessor + .room_state_get_content(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) + .await + .map_err(|_| err!(Request(BadState("Cannot kick member that's not in the room."))))?; event.membership = MembershipState::Leave; event.reason.clone_from(&body.reason); @@ -442,10 +437,10 @@ pub(crate) async fn ban_user_route( let event = services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) + .room_state_get_content(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) .await - .map_or( - Ok(RoomMemberEventContent { + .map_or_else( + |_| RoomMemberEventContent { membership: MembershipState::Ban, displayname: None, avatar_url: None, @@ -454,21 +449,17 @@ pub(crate) async fn ban_user_route( blurhash: blurhash.clone(), reason: body.reason.clone(), join_authorized_via_users_server: None, - }), - |event| { - serde_json::from_str(event.content.get()) - .map(|event: RoomMemberEventContent| RoomMemberEventContent { - membership: MembershipState::Ban, - displayname: None, - avatar_url: None, - blurhash: blurhash.clone(), - reason: body.reason.clone(), - join_authorized_via_users_server: None, - ..event - }) - .map_err(|e| err!(Database("Invalid member event in database: {e:?}"))) }, - )?; + |event| RoomMemberEventContent { + membership: MembershipState::Ban, + displayname: None, + avatar_url: None, + blurhash: blurhash.clone(), + reason: body.reason.clone(), + join_authorized_via_users_server: None, + ..event + }, + ); services .rooms @@ -503,17 +494,12 @@ pub(crate) async fn unban_user_route( let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; - let mut event: RoomMemberEventContent = serde_json::from_str( - services - .rooms - .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, 
body.user_id.as_ref()) - .await - .map_err(|_| err!(Request(BadState("Cannot unban a user who is not banned."))))? - .content - .get(), - ) - .map_err(|e| err!(Database("Invalid member event in database: {e:?}")))?; + let mut event: RoomMemberEventContent = services + .rooms + .state_accessor + .room_state_get_content(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) + .await + .map_err(|_| err!(Request(BadState("Cannot unban a user who is not banned."))))?; event.membership = MembershipState::Leave; event.reason.clone_from(&body.reason); @@ -1650,14 +1636,13 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, } else { let state_lock = services.rooms.state.mutex.lock(room_id).await; - let member_event = services + let Ok(mut event) = services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str()) - .await; - - // Fix for broken rooms - let Ok(member_event) = member_event else { + .room_state_get_content::(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await + else { + // Fix for broken rooms error!("Trying to leave a room you are not a member of."); services @@ -1677,9 +1662,6 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, return Ok(()); }; - let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get()) - .map_err(|e| err!(Database(error!("Invalid room member event in database: {e}"))))?; - event.membership = MembershipState::Leave; event.reason = reason; diff --git a/src/api/client/profile.rs b/src/api/client/profile.rs index 495bc8ec3..cdc047f07 100644 --- a/src/api/client/profile.rs +++ b/src/api/client/profile.rs @@ -301,10 +301,10 @@ pub async fn update_displayname( // Send a new join membership event into all joined rooms let mut joined_rooms = Vec::new(); for room_id in all_joined_rooms { - let Ok(event) = services + let Ok(content) = services .rooms .state_accessor - .room_state_get(room_id, 
&StateEventType::RoomMember, user_id.as_str()) + .room_state_get_content(room_id, &StateEventType::RoomMember, user_id.as_str()) .await else { continue; @@ -315,7 +315,7 @@ pub async fn update_displayname( content: to_raw_value(&RoomMemberEventContent { displayname: displayname.clone(), join_authorized_via_users_server: None, - ..serde_json::from_str(event.content.get()).expect("Database contains invalid PDU.") + ..content }) .expect("event is valid, we just created it"), unsigned: None, @@ -354,35 +354,28 @@ pub async fn update_avatar_url( .iter() .try_stream() .and_then(|room_id: &OwnedRoomId| async move { - Ok(( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - avatar_url: avatar_url.clone(), - blurhash: blurhash.clone(), - join_authorized_via_users_server: None, - ..serde_json::from_str( - services - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str()) - .await - .map_err(|_| { - Error::bad_database("Tried to send avatar URL update for user not in the room.") - })? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Database contains invalid PDU."))? 
- }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - room_id, - )) + let content = services + .rooms + .state_accessor + .room_state_get_content(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await?; + + let pdu = PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + avatar_url: avatar_url.clone(), + blurhash: blurhash.clone(), + join_authorized_via_users_server: None, + ..content + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + timestamp: None, + }; + + Ok((pdu, room_id)) }) .ignore_err() .collect() diff --git a/src/api/client/room.rs b/src/api/client/room.rs index 0d8e12a20..e22ad7963 100644 --- a/src/api/client/room.rs +++ b/src/api/client/room.rs @@ -664,16 +664,12 @@ pub(crate) async fn upgrade_room_route( let state_lock = services.rooms.state.mutex.lock(&replacement_room).await; // Get the old room creation event - let mut create_event_content = serde_json::from_str::( - services - .rooms - .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomCreate, "") - .await - .map_err(|_| err!(Database("Found room without m.room.create event.")))? 
- .content - .get(), - )?; + let mut create_event_content: CanonicalJsonObject = services + .rooms + .state_accessor + .room_state_get_content(&body.room_id, &StateEventType::RoomCreate, "") + .await + .map_err(|_| err!(Database("Found room without m.room.create event.")))?; // Use the m.room.tombstone event as the predecessor let predecessor = Some(ruma::events::room::create::PreviousRoom::new( @@ -825,16 +821,12 @@ pub(crate) async fn upgrade_room_route( } // Get the old room power levels - let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str( - services - .rooms - .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "") - .await - .map_err(|_| err!(Database("Found room without m.room.create event.")))? - .content - .get(), - )?; + let mut power_levels_event_content: RoomPowerLevelsEventContent = services + .rooms + .state_accessor + .room_state_get_content(&body.room_id, &StateEventType::RoomPowerLevels, "") + .await + .map_err(|_| err!(Database("Found room without m.room.power_levels event.")))?; // Setting events_default and invite to the greater of 50 and users_default + 1 let new_level = max( diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index f50cc46c0..7fac6be69 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -190,32 +190,31 @@ impl Service { // Always allow the server service account to remove the alias, since there may not be an admin room || server_user == user_id { - Ok(true) - // Checking whether the user is able to change canonical aliases of the - // room - } else if let Ok(event) = self + return Ok(true); + } + + // Checking whether the user is able to change canonical aliases of the room + if let Ok(content) = self .services .state_accessor - .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "") + .room_state_get_content::(&room_id, &StateEventType::RoomPowerLevels, "") .await { - 
serde_json::from_str(event.content.get()) - .map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels")) - .map(|content: RoomPowerLevelsEventContent| { - RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomCanonicalAlias) - }) + return Ok(RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomCanonicalAlias)); + } + // If there is no power levels event, only the room creator can change // canonical aliases - } else if let Ok(event) = self + if let Ok(event) = self .services .state_accessor .room_state_get(&room_id, &StateEventType::RoomCreate, "") .await { - Ok(event.sender == user_id) - } else { - Err(Error::bad_database("Room has no m.room.create event")) + return Ok(event.sender == user_id); } + + Err!(Database("Room has no m.room.create event")) } async fn who_created_alias(&self, alias: &RoomAliasId) -> Result { diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 920424a42..a30c2cfc1 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -380,14 +380,9 @@ impl Service { let join_rule = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "") + .room_state_get_content(room_id, &StateEventType::RoomJoinRules, "") .await - .map_or(JoinRule::Invite, |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| c.join_rule) - .map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}")))) - .unwrap() - }); + .map_or(JoinRule::Invite, |c: RoomJoinRulesEventContent| c.join_rule); let allowed_room_ids = self .services diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index ece8679d3..3b2c29313 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -338,14 +338,13 @@ impl Service { .map(|c: RoomNameEventContent| c.name) } - pub async fn get_avatar(&self, 
room_id: &RoomId) -> ruma::JsOption { - self.room_state_get(room_id, &StateEventType::RoomAvatar, "") + pub async fn get_avatar(&self, room_id: &RoomId) -> JsOption { + let content = self + .room_state_get_content(room_id, &StateEventType::RoomAvatar, "") .await - .map_or(ruma::JsOption::Undefined, |s| { - serde_json::from_str(s.content.get()) - .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) - .unwrap() - }) + .ok(); + + JsOption::from_option(content) } pub async fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result { @@ -416,16 +415,10 @@ impl Service { &self, redacts: &EventId, sender: &UserId, room_id: &RoomId, federation: bool, ) -> Result { if let Ok(event) = self - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "") + .room_state_get_content::(room_id, &StateEventType::RoomPowerLevels, "") .await { - let Ok(event) = serde_json::from_str(event.content.get()) - .map(|content: RoomPowerLevelsEventContent| content.into()) - .map(|event: RoomPowerLevels| event) - else { - return Ok(false); - }; - + let event: RoomPowerLevels = event.into(); Ok(event.user_can_redact_event_of_other(sender) || event.user_can_redact_own_event(sender) && if let Ok(pdu) = self.services.timeline.get_pdu(redacts).await { diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 7cf06522d..cc5940e6c 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1061,13 +1061,8 @@ impl Service { let power_levels: RoomPowerLevelsEventContent = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "") + .room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "") .await - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - .unwrap() - }) .unwrap_or_default(); let room_mods = power_levels.users.iter().filter_map(|(user_id, level)| { From 
55c85f685177eb22f126fdd7382e99959e32e3d8 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 3 Oct 2024 09:49:28 +0000 Subject: [PATCH 061/245] refactor to pdu.get_content() for serde_json::from_ elim Signed-off-by: Jason Volk --- src/api/client/sync.rs | 46 +++++++------------------ src/service/rooms/event_handler/mod.rs | 10 +++--- src/service/rooms/pdu_metadata/mod.rs | 2 +- src/service/rooms/spaces/mod.rs | 10 +++--- src/service/rooms/state/mod.rs | 17 ++------- src/service/rooms/state_accessor/mod.rs | 4 +-- src/service/rooms/timeline/mod.rs | 36 +++++-------------- 7 files changed, 35 insertions(+), 90 deletions(-) diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index 65d62a786..8c4c6a445 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -635,17 +635,8 @@ async fn load_joined_room( .await? .ready_filter(|(_, pdu)| pdu.kind == RoomMember) .filter_map(|(_, pdu)| async move { - let Ok(content) = serde_json::from_str::(pdu.content.get()) else { - return None; - }; - - let Some(state_key) = &pdu.state_key else { - return None; - }; - - let Ok(user_id) = UserId::parse(state_key) else { - return None; - }; + let content: RoomMemberEventContent = pdu.get_content().ok()?; + let user_id: &UserId = pdu.state_key.as_deref().map(TryInto::try_into).flat_ok()?; if user_id == sender_user { return None; @@ -656,22 +647,17 @@ async fn load_joined_room( return None; } - if !services - .rooms - .state_cache - .is_joined(&user_id, room_id) - .await && services - .rooms - .state_cache - .is_invited(&user_id, room_id) - .await - { + let is_invited = services.rooms.state_cache.is_invited(user_id, room_id); + + let is_joined = services.rooms.state_cache.is_joined(user_id, room_id); + + if !is_joined.await && is_invited.await { return None; } - Some(user_id) + Some(user_id.to_owned()) }) - .collect::>() + .collect::>() .await; Ok::<_, Error>(( @@ -839,11 +825,9 @@ async fn load_joined_room( continue; } - let new_membership = 
serde_json::from_str::(state_event.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database."))? - .membership; + let content: RoomMemberEventContent = state_event.get_content()?; - match new_membership { + match content.membership { MembershipState::Join => { // A new user joined an encrypted room if !share_encrypted_room(services, sender_user, &user_id, Some(room_id)).await { @@ -1357,12 +1341,8 @@ pub(crate) async fn sync_events_v4_route( continue; } - let new_membership = - serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database."))? - .membership; - - match new_membership { + let content: RoomMemberEventContent = pdu.get_content()?; + match content.membership { MembershipState::Join => { // A new user joined an encrypted room if !share_encrypted_room(&services, sender_user, &user_id, Some(room_id)) diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 4708a86cb..05f9a27a9 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -614,9 +614,7 @@ impl Service { } }, _ => { - let content = serde_json::from_str::(incoming_pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; - + let content: RoomRedactionEventContent = incoming_pdu.get_content()?; if let Some(redact_id) = &content.redacts { !self .services @@ -1432,10 +1430,10 @@ impl Service { } fn get_room_version_id(create_event: &PduEvent) -> Result { - let create_event_content: RoomCreateEventContent = serde_json::from_str(create_event.content.get()) - .map_err(|e| err!(Database("Invalid create event: {e}")))?; + let content: RoomCreateEventContent = create_event.get_content()?; + let room_version = content.room_version; - Ok(create_event_content.room_version) + Ok(room_version) } #[inline] diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index dbaebfbf3..fb85d031b 100644 --- 
a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -102,7 +102,7 @@ impl Service { return false; } - let Ok(content) = serde_json::from_str::(pdu.content.get()) else { + let Ok(content) = pdu.get_content::() else { return false; }; diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index a30c2cfc1..5aea5f6a0 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -596,12 +596,10 @@ impl Service { .await .map_err(|e| err!(Database("Event {id:?} in space state not found: {e:?}")))?; - if serde_json::from_str::(pdu.content.get()) - .ok() - .map(|c| c.via) - .map_or(true, |v| v.is_empty()) - { - continue; + if let Ok(content) = pdu.get_content::() { + if content.via.is_empty() { + continue; + } } if OwnedRoomId::try_from(state_key).is_ok() { diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 177b7e9b2..81760b368 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -93,28 +93,17 @@ impl Service { pin_mut!(event_ids); while let Some(event_id) = event_ids.next().await { - let Ok(pdu) = self.services.timeline.get_pdu_json(&event_id).await else { + let Ok(pdu) = self.services.timeline.get_pdu(&event_id).await else { continue; }; - let pdu: PduEvent = match serde_json::from_str( - &serde_json::to_string(&pdu).expect("CanonicalJsonObj can be serialized to JSON"), - ) { - Ok(pdu) => pdu, - Err(_) => continue, - }; - match pdu.kind { TimelineEventType::RoomMember => { - let Ok(membership_event) = serde_json::from_str::(pdu.content.get()) else { - continue; - }; - - let Some(state_key) = pdu.state_key else { + let Some(user_id) = pdu.state_key.as_ref().map(UserId::parse).flat_ok() else { continue; }; - let Ok(user_id) = UserId::parse(state_key) else { + let Ok(membership_event) = pdu.get_content::() else { continue; }; diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 
3b2c29313..3855d92a2 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -325,11 +325,9 @@ impl Service { where T: for<'de> Deserialize<'de> + Send, { - use serde_json::from_str; - self.room_state_get(room_id, event_type, state_key) .await - .and_then(|event| from_str::(event.content.get()).map_err(Into::into)) + .and_then(|event| event.get_content()) } pub async fn get_name(&self, room_id: &RoomId) -> Result { diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index cc5940e6c..487262e68 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -471,12 +471,7 @@ impl Service { } }, _ => { - let content = - serde_json::from_str::(pdu.content.get()).map_err(|e| { - warn!("Invalid content in redaction pdu: {e}"); - Error::bad_database("Invalid content in redaction pdu") - })?; - + let content: RoomRedactionEventContent = pdu.get_content()?; if let Some(redact_id) = &content.redacts { if self .services @@ -506,11 +501,7 @@ impl Service { let target_user_id = UserId::parse(state_key.clone()).expect("This state_key was previously validated"); - let content = serde_json::from_str::(pdu.content.get()).map_err(|e| { - error!("Invalid room member event content in pdu: {e}"); - Error::bad_database("Invalid room member event content in pdu.") - })?; - + let content: RoomMemberEventContent = pdu.get_content()?; let invite_state = match content.membership { MembershipState::Invite => self.services.state.summary_stripped(pdu).await.into(), _ => None, @@ -533,9 +524,7 @@ impl Service { } }, TimelineEventType::RoomMessage => { - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - + let content: ExtractBody = pdu.get_content()?; if let Some(body) = content.body { self.services.search.index_pdu(shortroomid, &pdu_id, &body); @@ -549,7 +538,7 @@ impl Service { _ => {}, } - if let Ok(content) = 
serde_json::from_str::(pdu.content.get()) { + if let Ok(content) = pdu.get_content::() { if let Ok(related_pducount) = self.get_pdu_count(&content.relates_to.event_id).await { self.services .pdu_metadata @@ -557,7 +546,7 @@ impl Service { } } - if let Ok(content) = serde_json::from_str::(pdu.content.get()) { + if let Ok(content) = pdu.get_content::() { match content.relates_to { Relation::Reply { in_reply_to, @@ -712,10 +701,7 @@ impl Service { .room_state_get(room_id, &event_type.to_string().into(), state_key) .await { - unsigned.insert( - "prev_content".to_owned(), - serde_json::from_str(prev_pdu.content.get()).expect("string is valid json"), - ); + unsigned.insert("prev_content".to_owned(), prev_pdu.get_content_as_value()); unsigned.insert( "prev_sender".to_owned(), serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), @@ -874,9 +860,7 @@ impl Service { }; }, _ => { - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|e| err!(Database("Invalid content in redaction pdu: {e:?}")))?; - + let content: RoomRedactionEventContent = pdu.get_content()?; if let Some(redact_id) = &content.redacts { if !self .services @@ -1026,7 +1010,7 @@ impl Service { .await .map_err(|e| err!(Database(error!(?pdu_id, ?event_id, ?e, "PDU ID points to invalid PDU."))))?; - if let Ok(content) = serde_json::from_str::(pdu.content.get()) { + if let Ok(content) = pdu.get_content::() { if let Some(body) = content.body { self.services .search @@ -1200,9 +1184,7 @@ impl Service { drop(insert_lock); if pdu.kind == TimelineEventType::RoomMessage { - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|e| err!(Database("Invalid content in pdu: {e:?}")))?; - + let content: ExtractBody = pdu.get_content()?; if let Some(body) = content.body { self.services.search.index_pdu(shortroomid, &pdu_id, &body); } From d526db681f045f28519a3757f761090599d2a14e Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 3 Oct 2024 09:57:43 +0000 Subject: [PATCH 
062/245] refactor various patterns for serde_json::from_ elim bump ruma Signed-off-by: Jason Volk --- Cargo.lock | 26 ++++++++++----------- Cargo.toml | 2 +- src/api/client/membership.rs | 24 +++++++------------- src/api/client/state.rs | 28 +++++++---------------- src/api/server/get_missing_events.rs | 24 +++++++++++--------- src/api/server/invite.rs | 31 ++++++++++---------------- src/service/rooms/event_handler/mod.rs | 16 +++++-------- src/service/rooms/state/mod.rs | 1 + src/service/sending/sender.rs | 8 ++----- 9 files changed, 64 insertions(+), 96 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b9f366e79..cae6994c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2976,7 +2976,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" +source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" dependencies = [ "assign", "js_int", @@ -2998,7 +2998,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" +source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" dependencies = [ "js_int", "ruma-common", @@ -3010,7 +3010,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" +source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" dependencies = [ "as_variant", "assign", @@ -3033,7 +3033,7 @@ dependencies = [ [[package]] name = "ruma-common" version = 
"0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" +source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" dependencies = [ "as_variant", "base64 0.22.1", @@ -3063,7 +3063,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" +source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" dependencies = [ "as_variant", "indexmap 2.6.0", @@ -3087,7 +3087,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" +source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" dependencies = [ "bytes", "http", @@ -3105,7 +3105,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" +source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" dependencies = [ "js_int", "thiserror", @@ -3114,7 +3114,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" +source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" dependencies = [ 
"js_int", "ruma-common", @@ -3124,7 +3124,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" +source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" dependencies = [ "cfg-if", "once_cell", @@ -3140,7 +3140,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" +source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" dependencies = [ "js_int", "ruma-common", @@ -3152,7 +3152,7 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.3.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" +source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" dependencies = [ "headers", "http", @@ -3165,7 +3165,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" +source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -3181,7 +3181,7 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.11.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37#e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" +source = 
"git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" dependencies = [ "futures-util", "itertools 0.13.0", diff --git a/Cargo.toml b/Cargo.toml index 18f33375f..25d1001da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -315,7 +315,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = "https://github.com/girlbossceo/ruwuma" #branch = "conduwuit-changes" -rev = "e81ed2741b4ebe98fe41cabdfee2ac28a52a8e37" +rev = "f485a0265c67a59df75fc6686787538172fa4cac" features = [ "compat", "rand", diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index a260b8c5c..fde6099a4 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -183,10 +183,8 @@ pub(crate) async fn join_room_by_id_route( .await .unwrap_or_default() .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) + .filter_map(|event| event.get_field("sender").ok().flatten()) + .filter_map(|sender: &str| UserId::parse(sender).ok()) .map(|user| user.server_name().to_owned()), ); @@ -248,10 +246,8 @@ pub(crate) async fn join_room_by_id_or_alias_route( .await .unwrap_or_default() .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) + .filter_map(|event| event.get_field("sender").ok().flatten()) + .filter_map(|sender: &str| UserId::parse(sender).ok()) .map(|user| user.server_name().to_owned()), ); @@ -294,10 +290,8 @@ pub(crate) async fn join_room_by_id_or_alias_route( .await .unwrap_or_default() .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: 
serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) + .filter_map(|event| event.get_field("sender").ok().flatten()) + .filter_map(|sender: &str| UserId::parse(sender).ok()) .map(|user| user.server_name().to_owned()), ); @@ -1708,10 +1702,8 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room servers.extend( invite_state .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) + .filter_map(|event| event.get_field("sender").ok().flatten()) + .filter_map(|sender: &str| UserId::parse(sender).ok()) .map(|user| user.server_name().to_owned()), ); diff --git a/src/api/client/state.rs b/src/api/client/state.rs index d89c23e8c..2a13ba1f4 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use axum::extract::State; -use conduit::{err, error, pdu::PduBuilder, Err, Error, Result}; +use conduit::{err, pdu::PduBuilder, utils::BoolExt, Err, Error, Result}; use ruma::{ api::client::{ error::ErrorKind, @@ -137,27 +137,15 @@ pub(crate) async fn get_state_events_for_key_route( )))) })?; - if body + let event_format = body .format .as_ref() - .is_some_and(|f| f.to_lowercase().eq("event")) - { - Ok(get_state_events_for_key::v3::Response { - content: None, - event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| { - error!("Invalid room state event in database: {}", e); - Error::bad_database("Invalid room state event in database") - })?, - }) - } else { - Ok(get_state_events_for_key::v3::Response { - content: Some(serde_json::from_str(event.content.get()).map_err(|e| { - error!("Invalid room state event content in database: {}", e); - Error::bad_database("Invalid room state 
event content in database") - })?), - event: None, - }) - } + .is_some_and(|f| f.to_lowercase().eq("event")); + + Ok(get_state_events_for_key::v3::Response { + content: event_format.or(|| event.get_content_as_value()), + event: event_format.then(|| event.to_state_event_value()), + }) } /// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}` diff --git a/src/api/server/get_missing_events.rs b/src/api/server/get_missing_events.rs index 7ae0ff608..e267898fe 100644 --- a/src/api/server/get_missing_events.rs +++ b/src/api/server/get_missing_events.rs @@ -2,7 +2,7 @@ use axum::extract::State; use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::event::get_missing_events}, - OwnedEventId, RoomId, + CanonicalJsonValue, EventId, RoomId, }; use crate::Ruma; @@ -78,17 +78,19 @@ pub(crate) async fn get_missing_events_route( continue; } - queued_events.extend_from_slice( - &serde_json::from_value::>( - serde_json::to_value( - pdu.get("prev_events") - .cloned() - .ok_or_else(|| Error::bad_database("Event in db has no prev_events property."))?, - ) - .expect("canonical json is valid json value"), - ) - .map_err(|_| Error::bad_database("Invalid prev_events in event in database."))?, + let prev_events = pdu + .get("prev_events") + .and_then(CanonicalJsonValue::as_array) + .unwrap_or_default(); + + queued_events.extend( + prev_events + .iter() + .map(<&EventId>::try_from) + .filter_map(Result::ok) + .map(ToOwned::to_owned), ); + events.push( services .sending diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index 9968bdf72..dd2374b6d 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -1,11 +1,11 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{utils, warn, Error, PduEvent, Result}; +use conduit::{err, utils, warn, Err, Error, PduEvent, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_invite}, events::room::member::{MembershipState, 
RoomMemberEventContent}, serde::JsonObject, - CanonicalJsonValue, EventId, OwnedUserId, + CanonicalJsonValue, EventId, OwnedUserId, UserId, }; use crate::Ruma; @@ -79,14 +79,11 @@ pub(crate) async fn create_invite_route( let mut signed_event = utils::to_canonical_object(&body.event) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?; - let invited_user: OwnedUserId = serde_json::from_value( - signed_event - .get("state_key") - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event has no state_key property."))? - .clone() - .into(), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "state_key is not a user ID."))?; + let invited_user: OwnedUserId = signed_event + .get("state_key") + .try_into() + .map(UserId::to_owned) + .map_err(|e| err!(Request(InvalidParam("Invalid state_key property: {e}"))))?; if !services.globals.server_is_ours(invited_user.server_name()) { return Err(Error::BadRequest( @@ -121,14 +118,10 @@ pub(crate) async fn create_invite_route( // Add event_id back signed_event.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.to_string())); - let sender: OwnedUserId = serde_json::from_value( - signed_event - .get("sender") - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event had no sender property."))? 
- .clone() - .into(), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user ID."))?; + let sender: &UserId = signed_event + .get("sender") + .try_into() + .map_err(|e| err!(Request(InvalidParam("Invalid sender property: {e}"))))?; if services.rooms.metadata.is_banned(&body.room_id).await && !services.users.is_admin(&invited_user).await { return Err(Error::BadRequest( @@ -171,7 +164,7 @@ pub(crate) async fn create_invite_route( &body.room_id, &invited_user, RoomMemberEventContent::new(MembershipState::Invite), - &sender, + sender, Some(invite_state), body.via.clone(), true, diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 05f9a27a9..f8042b67b 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -30,8 +30,8 @@ use ruma::{ int, serde::Base64, state_res::{self, EventTypeExt, RoomVersion, StateMap}, - uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, - RoomVersionId, ServerName, + uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, + ServerName, UserId, }; use tokio::sync::RwLock; @@ -157,14 +157,10 @@ impl Service { self.acl_check(origin, room_id).await?; // 1.3.2 Check room ACL on sender's server name - let sender: OwnedUserId = serde_json::from_value( - value - .get("sender") - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "PDU does not have a sender key"))? 
- .clone() - .into(), - ) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "User ID in sender is invalid"))?; + let sender: &UserId = value + .get("sender") + .try_into() + .map_err(|e| err!(Request(InvalidParam("PDU does not have a valid sender key: {e}"))))?; self.acl_check(sender.server_name(), room_id).await?; diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 81760b368..cfcb2da6f 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -9,6 +9,7 @@ use std::{ use conduit::{ err, + result::FlatOk, utils::{calculate_hash, stream::TryIgnore, IterStream, MutexMap, MutexMapGuard}, warn, PduEvent, Result, }; diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 90977abe9..5c0a324bc 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -528,12 +528,8 @@ impl Service { for pdu in pdus { // Redacted events are not notification targets (we don't send push for them) - if let Some(unsigned) = &pdu.unsigned { - if let Ok(unsigned) = serde_json::from_str::(unsigned.get()) { - if unsigned.get("redacted_because").is_some() { - continue; - } - } + if pdu.contains_unsigned_property("redacted_because", serde_json::Value::is_string) { + continue; } let rules_for_user = self From 57e0a5f65dce2be514d0bc45dbfb26b5c5b0cd00 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 3 Oct 2024 10:02:24 +0000 Subject: [PATCH 063/245] additional database stream deserializations for serde_json::from_ elim Signed-off-by: Jason Volk --- src/service/key_backups/mod.rs | 64 ++++++++------------------- src/service/pusher/mod.rs | 2 +- src/service/rooms/state_cache/data.rs | 64 +++++++++++++-------------- src/service/users/mod.rs | 11 ++--- 4 files changed, 58 insertions(+), 83 deletions(-) diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index decf32f7f..55263eeb1 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -1,9 +1,9 
@@ use std::{collections::BTreeMap, sync::Arc}; use conduit::{ - err, implement, utils, + err, implement, utils::stream::{ReadyExt, TryIgnore}, - Err, Error, Result, + Err, Result, }; use database::{Deserialized, Ignore, Interfix, Map}; use futures::StreamExt; @@ -110,57 +110,35 @@ pub async fn update_backup( #[implement(Service)] pub async fn get_latest_backup_version(&self, user_id: &UserId) -> Result { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + type Key<'a> = (&'a UserId, &'a str); + let last_possible_key = (user_id, u64::MAX); self.db .backupid_algorithm - .rev_raw_keys_from(&last_possible_key) + .rev_keys_from(&last_possible_key) .ignore_err() - .ready_take_while(move |key| key.starts_with(&prefix)) + .ready_take_while(|(user_id_, _): &Key<'_>| *user_id_ == user_id) + .map(|(_, version): Key<'_>| version.to_owned()) .next() .await .ok_or_else(|| err!(Request(NotFound("No backup versions found")))) - .and_then(|key| { - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) - }) } #[implement(Service)] pub async fn get_latest_backup(&self, user_id: &UserId) -> Result<(String, Raw)> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + type Key<'a> = (&'a UserId, &'a str); + type KeyVal<'a> = (Key<'a>, Raw); + let last_possible_key = (user_id, u64::MAX); self.db .backupid_algorithm - .rev_raw_stream_from(&last_possible_key) + .rev_stream_from(&last_possible_key) .ignore_err() - .ready_take_while(move |(key, _)| key.starts_with(&prefix)) + .ready_take_while(|((user_id_, _), _): &KeyVal<'_>| *user_id_ == user_id) + .map(|((_, version), algorithm): KeyVal<'_>| 
(version.to_owned(), algorithm)) .next() .await .ok_or_else(|| err!(Request(NotFound("No backup found")))) - .and_then(|(key, val)| { - let version = utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; - - let algorithm = serde_json::from_slice(val) - .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?; - - Ok((version, algorithm)) - }) } #[implement(Service)] @@ -223,7 +201,8 @@ pub async fn get_etag(&self, user_id: &UserId, version: &str) -> String { #[implement(Service)] pub async fn get_all(&self, user_id: &UserId, version: &str) -> BTreeMap { - type KeyVal<'a> = ((Ignore, Ignore, &'a RoomId, &'a str), &'a [u8]); + type Key<'a> = (Ignore, Ignore, &'a RoomId, &'a str); + type KeyVal<'a> = (Key<'a>, Raw); let mut rooms = BTreeMap::::new(); let default = || RoomKeyBackup { @@ -235,13 +214,12 @@ pub async fn get_all(&self, user_id: &UserId, version: &str) -> BTreeMap| { - let key_data = serde_json::from_slice(value).expect("Invalid KeyBackupData JSON"); + .ready_for_each(|((_, _, room_id, session_id), key_backup_data): KeyVal<'_>| { rooms .entry(room_id.into()) .or_insert_with(default) .sessions - .insert(session_id.into(), key_data); + .insert(session_id.into(), key_backup_data); }) .await; @@ -252,18 +230,14 @@ pub async fn get_all(&self, user_id: &UserId, version: &str) -> BTreeMap BTreeMap> { - type KeyVal<'a> = ((Ignore, Ignore, Ignore, &'a str), &'a [u8]); + type KeyVal<'a> = ((Ignore, Ignore, Ignore, &'a str), Raw); let prefix = (user_id, version, room_id, Interfix); self.db .backupkeyid_backup .stream_prefix(&prefix) .ignore_err() - .map(|((.., session_id), value): KeyVal<'_>| { - let session_id = session_id.to_owned(); - let key_backup_data = serde_json::from_slice(value).expect("Invalid KeyBackupData JSON"); - (session_id, key_backup_data) - }) + .map(|((.., session_id), key_backup_data): 
KeyVal<'_>| (session_id.to_owned(), key_backup_data)) .collect() .await } diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 8d8b553fe..e7b1824ad 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -99,7 +99,7 @@ impl Service { .senderkey_pusher .stream_prefix(&prefix) .ignore_err() - .map(|(_, val): (Ignore, &[u8])| serde_json::from_slice(val).expect("Invalid Pusher in db.")) + .map(|(_, pusher): (Ignore, Pusher)| pusher) .collect() .await } diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index f3ccaf102..6e01e49df 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -3,7 +3,7 @@ use std::{ sync::{Arc, RwLock}, }; -use conduit::{utils, utils::stream::TryIgnore, Error, Result}; +use conduit::{utils::stream::TryIgnore, Result}; use database::{Deserialized, Interfix, Map}; use futures::{Stream, StreamExt}; use ruma::{ @@ -135,20 +135,31 @@ impl Data { pub(super) fn rooms_invited<'a>( &'a self, user_id: &'a UserId, ) -> impl Stream + Send + 'a { + type Key<'a> = (&'a UserId, &'a RoomId); + type KeyVal<'a> = (Key<'a>, Raw>); + let prefix = (user_id, Interfix); self.userroomid_invitestate - .stream_raw_prefix(&prefix) + .stream_prefix(&prefix) + .ignore_err() + .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) + .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) + .ignore_err() + } + + /// Returns an iterator over all rooms a user left. 
+ #[inline] + pub(super) fn rooms_left<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { + type Key<'a> = (&'a UserId, &'a RoomId); + type KeyVal<'a> = (Key<'a>, Raw>>); + + let prefix = (user_id, Interfix); + self.userroomid_leftstate + .stream_prefix(&prefix) + .ignore_err() + .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) + .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) .ignore_err() - .map(|(key, val)| { - let room_id = key.rsplit(|&b| b == 0xFF).next().unwrap(); - let room_id = utils::string_from_bytes(room_id).unwrap(); - let room_id = RoomId::parse(room_id).unwrap(); - let state = serde_json::from_slice(val) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate.")) - .unwrap(); - - (room_id, state) - }) } #[tracing::instrument(skip(self), level = "debug")] @@ -156,7 +167,11 @@ impl Data { &self, user_id: &UserId, room_id: &RoomId, ) -> Result>> { let key = (user_id, room_id); - self.userroomid_invitestate.qry(&key).await.deserialized() + self.userroomid_invitestate + .qry(&key) + .await + .deserialized() + .and_then(|val: Raw>| val.deserialize_as().map_err(Into::into)) } #[tracing::instrument(skip(self), level = "debug")] @@ -164,25 +179,10 @@ impl Data { &self, user_id: &UserId, room_id: &RoomId, ) -> Result>> { let key = (user_id, room_id); - self.userroomid_leftstate.qry(&key).await.deserialized() - } - - /// Returns an iterator over all rooms a user left. 
- #[inline] - pub(super) fn rooms_left<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { - let prefix = (user_id, Interfix); self.userroomid_leftstate - .stream_raw_prefix(&prefix) - .ignore_err() - .map(|(key, val)| { - let room_id = key.rsplit(|&b| b == 0xFF).next().unwrap(); - let room_id = utils::string_from_bytes(room_id).unwrap(); - let room_id = RoomId::parse(room_id).unwrap(); - let state = serde_json::from_slice(val) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate.")) - .unwrap(); - - (room_id, state) - }) + .qry(&key) + .await + .deserialized() + .and_then(|val: Raw>| val.deserialize_as().map_err(Into::into)) } } diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 3ab6b3c33..71a93666f 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -2,7 +2,7 @@ use std::{collections::BTreeMap, mem, mem::size_of, sync::Arc}; use conduit::{ debug_warn, err, utils, - utils::{stream::TryIgnore, string::Unquoted, ReadyExt, TryReadyExt}, + utils::{stream::TryIgnore, string::Unquoted, ReadyExt}, warn, Err, Error, Result, Server, }; use database::{Deserialized, Ignore, Interfix, Map}; @@ -749,9 +749,9 @@ impl Service { let prefix = (user_id, device_id, Interfix); self.db .todeviceid_events - .stream_raw_prefix(&prefix) - .ready_and_then(|(_, val)| serde_json::from_slice(val).map_err(Into::into)) + .stream_prefix(&prefix) .ignore_err() + .map(|(_, val): (Ignore, Raw)| val) } pub async fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) { @@ -812,11 +812,12 @@ impl Service { } pub fn all_devices_metadata<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { + let key = (user_id, Interfix); self.db .userdeviceid_metadata - .stream_raw_prefix(&(user_id, Interfix)) - .ready_and_then(|(_, val)| serde_json::from_slice::(val).map_err(Into::into)) + .stream_prefix(&key) .ignore_err() + .map(|(_, val): (Ignore, Device)| val) } /// Creates a new sync filter. 
Returns the filter id. From f503ed918c90720c28f978c2851d252e21920a29 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 3 Oct 2024 10:03:31 +0000 Subject: [PATCH 064/245] misc cleanup Signed-off-by: Jason Volk --- src/api/client/membership.rs | 33 ++++++++----------- src/api/client/state.rs | 2 +- src/api/server/invite.rs | 28 ++++------------ src/api/server/send.rs | 22 ++++++------- src/service/admin/mod.rs | 2 +- .../rooms/event_handler/parse_incoming_pdu.rs | 30 ++++++++--------- src/service/rooms/timeline/mod.rs | 3 +- 7 files changed, 49 insertions(+), 71 deletions(-) diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index fde6099a4..f21f3d7d0 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -1333,10 +1333,8 @@ pub async fn validate_and_add_event_id( services: &Services, pdu: &RawJsonValue, room_version: &RoomVersionId, pub_key_map: &RwLock>>, ) -> Result<(OwnedEventId, CanonicalJsonObject)> { - let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - debug_error!("Invalid PDU in server response: {pdu:#?}"); - err!(BadServerResponse("Invalid PDU in server response: {e:?}")) - })?; + let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()) + .map_err(|e| err!(BadServerResponse(debug_error!("Invalid PDU in server response: {e:?}"))))?; let event_id = EventId::parse(format!( "${}", ruma::signatures::reference_hash(&value, room_version).expect("ruma can calculate reference hashes") @@ -1478,10 +1476,8 @@ pub(crate) async fn invite_helper( if *pdu.event_id != *event_id { warn!( - "Server {} changed invite event, that's not allowed in the spec: ours: {:?}, theirs: {:?}", + "Server {} changed invite event, that's not allowed in the spec: ours: {pdu_json:?}, theirs: {value:?}", user_id.server_name(), - pdu_json, - value ); } @@ -1564,20 +1560,19 @@ pub(crate) async fn invite_helper( // Make a user leave all their joined rooms, forgets all rooms, and ignores // errors pub 
async fn leave_all_rooms(services: &Services, user_id: &UserId) { - let all_rooms: Vec<_> = services + let rooms_joined = services .rooms .state_cache .rooms_joined(user_id) - .map(ToOwned::to_owned) - .chain( - services - .rooms - .state_cache - .rooms_invited(user_id) - .map(|(r, _)| r), - ) - .collect() - .await; + .map(ToOwned::to_owned); + + let rooms_invited = services + .rooms + .state_cache + .rooms_invited(user_id) + .map(|(r, _)| r); + + let all_rooms: Vec<_> = rooms_joined.chain(rooms_invited).collect().await; for room_id in all_rooms { // ignore errors @@ -1601,7 +1596,7 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, .await { if let Err(e) = remote_leave_room(services, user_id, room_id).await { - warn!("Failed to leave room {} remotely: {}", user_id, e); + warn!("Failed to leave room {user_id} remotely: {e}"); // Don't tell the client about this error } diff --git a/src/api/client/state.rs b/src/api/client/state.rs index 2a13ba1f4..1396ae778 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -176,7 +176,7 @@ async fn send_state_event_for_key_helper( .build_and_append_pdu( PduBuilder { event_type: event_type.to_string().into(), - content: serde_json::from_str(json.json().get()).expect("content is valid json"), + content: serde_json::from_str(json.json().get())?, unsigned: None, state_key: Some(state_key), redacts: None, diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index dd2374b6d..447e54be0 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -47,10 +47,7 @@ pub(crate) async fn create_invite_route( .forbidden_remote_server_names .contains(&server.to_owned()) { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); } } @@ -64,15 +61,13 @@ pub(crate) async fn create_invite_route( "Received federated/remote invite from banned server 
{origin} for room ID {}. Rejecting.", body.room_id ); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server is banned on this homeserver.", - )); + + return Err!(Request(Forbidden("Server is banned on this homeserver."))); } if let Some(via) = &body.via { if via.is_empty() { - return Err(Error::BadRequest(ErrorKind::InvalidParam, "via field must not be empty.")); + return Err!(Request(InvalidParam("via field must not be empty."))); } } @@ -86,10 +81,7 @@ pub(crate) async fn create_invite_route( .map_err(|e| err!(Request(InvalidParam("Invalid state_key property: {e}"))))?; if !services.globals.server_is_ours(invited_user.server_name()) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "User does not belong to this homeserver.", - )); + return Err!(Request(InvalidParam("User does not belong to this homeserver."))); } // Make sure we're not ACL'ed from their room. @@ -124,17 +116,11 @@ pub(crate) async fn create_invite_route( .map_err(|e| err!(Request(InvalidParam("Invalid sender property: {e}"))))?; if services.rooms.metadata.is_banned(&body.room_id).await && !services.users.is_admin(&invited_user).await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "This room is banned on this homeserver.", - )); + return Err!(Request(Forbidden("This room is banned on this homeserver."))); } if services.globals.block_non_admin_invites() && !services.users.is_admin(&invited_user).await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "This server does not allow room invites.", - )); + return Err!(Request(Forbidden("This server does not allow room invites."))); } let mut invite_state = body.invite_room_state.clone(); diff --git a/src/api/server/send.rs b/src/api/server/send.rs index 50a79e002..f6916ccfa 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -16,9 +16,11 @@ use ruma::{ }, }, events::receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType}, + serde::Raw, to_device::DeviceIdOrAllDevices, OwnedEventId, 
ServerName, }; +use serde_json::value::RawValue as RawJsonValue; use tokio::sync::RwLock; use crate::{ @@ -70,8 +72,8 @@ pub(crate) async fn send_transaction_message_route( "Starting txn", ); - let resolved_map = handle_pdus(&services, &client, &body, origin, &txn_start_time).await; - handle_edus(&services, &client, &body, origin).await; + let resolved_map = handle_pdus(&services, &client, &body.pdus, origin, &txn_start_time).await; + handle_edus(&services, &client, &body.edus, origin).await; debug!( pdus = ?body.pdus.len(), @@ -91,11 +93,10 @@ pub(crate) async fn send_transaction_message_route( } async fn handle_pdus( - services: &Services, _client: &IpAddr, body: &Ruma, origin: &ServerName, - txn_start_time: &Instant, + services: &Services, _client: &IpAddr, pdus: &[Box], origin: &ServerName, txn_start_time: &Instant, ) -> ResolvedMap { - let mut parsed_pdus = Vec::with_capacity(body.pdus.len()); - for pdu in &body.pdus { + let mut parsed_pdus = Vec::with_capacity(pdus.len()); + for pdu in pdus { parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu).await { Ok(t) => t, Err(e) => { @@ -162,11 +163,8 @@ async fn handle_pdus( resolved_map } -async fn handle_edus( - services: &Services, client: &IpAddr, body: &Ruma, origin: &ServerName, -) { - for edu in body - .edus +async fn handle_edus(services: &Services, client: &IpAddr, edus: &[Raw], origin: &ServerName) { + for edu in edus .iter() .filter_map(|edu| serde_json::from_str::(edu.json().get()).ok()) { @@ -178,7 +176,7 @@ async fn handle_edus( Edu::DirectToDevice(content) => handle_edu_direct_to_device(services, client, origin, content).await, Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(services, client, origin, content).await, Edu::_Custom(ref _custom) => { - debug_warn!(?body.edus, "received custom/unknown EDU"); + debug_warn!(?edus, "received custom/unknown EDU"); }, } } diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 12eacc8fa..da7f3cf4c 100644 --- 
a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -198,7 +198,6 @@ impl Service { Ok(None) => debug!("Command successful with no response"), Ok(Some(output)) | Err(output) => self .handle_response(output) - .boxed() .await .unwrap_or_else(default_log), } @@ -277,6 +276,7 @@ impl Service { }; self.respond_to_room(content, &pdu.room_id, response_sender) + .boxed() .await } diff --git a/src/service/rooms/event_handler/parse_incoming_pdu.rs b/src/service/rooms/event_handler/parse_incoming_pdu.rs index 9081fcbca..39920219a 100644 --- a/src/service/rooms/event_handler/parse_incoming_pdu.rs +++ b/src/service/rooms/event_handler/parse_incoming_pdu.rs @@ -1,29 +1,29 @@ -use conduit::{debug_warn, err, pdu::gen_event_id_canonical_json, Err, Result}; -use ruma::{CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId}; +use conduit::{err, pdu::gen_event_id_canonical_json, result::FlatOk, Result}; +use ruma::{CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedRoomId, RoomId}; use serde_json::value::RawValue as RawJsonValue; impl super::Service { pub async fn parse_incoming_pdu( &self, pdu: &RawJsonValue, ) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - debug_warn!("Error parsing incoming event {pdu:#?}"); - err!(BadServerResponse("Error parsing incoming event {e:?}")) - })?; + let value = serde_json::from_str::(pdu.get()) + .map_err(|e| err!(BadServerResponse(debug_warn!("Error parsing incoming event {e:?}"))))?; let room_id: OwnedRoomId = value .get("room_id") - .and_then(|id| RoomId::parse(id.as_str()?).ok()) - .ok_or_else(|| err!(Request(InvalidParam("Invalid room id in pdu"))))?; + .and_then(CanonicalJsonValue::as_str) + .map(RoomId::parse) + .flat_ok_or(err!(Request(InvalidParam("Invalid room_id in pdu"))))?; - let Ok(room_version_id) = self.services.state.get_room_version(&room_id).await else { - return Err!("Server is not in room {room_id}"); - }; + let 
room_version_id = self + .services + .state + .get_room_version(&room_id) + .await + .map_err(|_| err!("Server is not in room {room_id}"))?; - let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { - // Event could not be converted to canonical json - return Err!(Request(InvalidParam("Could not convert event to canonical json."))); - }; + let (event_id, value) = gen_event_id_canonical_json(pdu, &room_version_id) + .map_err(|e| err!(Request(InvalidParam("Could not convert event to canonical json: {e}"))))?; Ok((event_id, value, room_id)) } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 487262e68..21e5395da 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -661,8 +661,7 @@ impl Service { .await .or_else(|_| { if event_type == TimelineEventType::RoomCreate { - let content = serde_json::from_str::(content.get()) - .expect("Invalid content in RoomCreate pdu."); + let content: RoomCreateEventContent = serde_json::from_str(content.get())?; Ok(content.room_version) } else { Err(Error::InconsistentRoomState( From e482c0646f58ae0fe58abc12dff4be7cb1fd8e8f Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 4 Oct 2024 20:25:32 +0000 Subject: [PATCH 065/245] Add constructions and Default for PduBuilder simplify various RoomMemberEventContent constructions Signed-off-by: Jason Volk --- src/admin/user/commands.rs | 21 +-- src/api/client/account.rs | 12 +- src/api/client/membership.rs | 157 +++++++--------------- src/api/client/message.rs | 17 +-- src/api/client/profile.rs | 33 ++--- src/api/client/redact.rs | 18 +-- src/api/client/room.rs | 159 +++++++--------------- src/api/client/state.rs | 3 +- src/api/server/make_join.rs | 29 ++-- src/api/server/make_leave.rs | 25 +--- src/core/pdu/builder.rs | 59 +++++++- src/core/pdu/mod.rs | 5 +- src/service/admin/create.rs | 171 +++++++----------------- src/service/admin/grant.rs | 83 +++--------- src/service/admin/mod.rs | 31 
+---- src/service/rooms/state_accessor/mod.rs | 20 +-- 16 files changed, 279 insertions(+), 564 deletions(-) diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index 562bb9c74..df3938331 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -12,11 +12,10 @@ use ruma::{ redaction::RoomRedactionEventContent, }, tag::{TagEvent, TagEventContent, TagInfo}, - RoomAccountDataEventType, StateEventType, TimelineEventType, + RoomAccountDataEventType, StateEventType, }, EventId, OwnedRoomId, OwnedRoomOrAliasId, OwnedUserId, RoomId, }; -use serde_json::value::to_raw_value; use crate::{ admin_command, get_room_info, @@ -461,14 +460,7 @@ pub(super) async fn force_demote( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&power_levels_content).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(String::new(), &power_levels_content), &user_id, &room_id, &state_lock, @@ -623,16 +615,11 @@ pub(super) async fn redact_event(&self, event_id: Box) -> Result(room_id, &StateEventType::RoomMember, user_id.as_str()) @@ -1651,21 +1601,18 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, return Ok(()); }; - event.membership = MembershipState::Leave; - event.reason = reason; - services .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, + PduBuilder::state( + user_id.to_string(), + &RoomMemberEventContent { + membership: MembershipState::Leave, + reason, + ..event + }, + ), user_id, room_id, &state_lock, diff --git a/src/api/client/message.rs b/src/api/client/message.rs index d577e3c83..578b675b5 100644 --- 
a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -9,7 +9,6 @@ use conduit::{ use futures::{FutureExt, StreamExt}; use ruma::{ api::client::{ - error::ErrorKind, filter::{RoomEventFilter, UrlFilter}, message::{get_message_events, send_message_event}, }, @@ -21,7 +20,7 @@ use service::rooms::timeline::PdusIterItem; use crate::{ service::{pdu::PduBuilder, Services}, - utils, Error, Result, Ruma, + utils, Result, Ruma, }; /// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}` @@ -77,27 +76,25 @@ pub(crate) async fn send_message_event_route( let mut unsigned = BTreeMap::new(); unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); - let content = from_str(body.body.body.json().get()) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?; + let content = + from_str(body.body.body.json().get()).map_err(|e| err!(Request(BadJson("Invalid JSON body: {e}"))))?; let event_id = services .rooms .timeline .build_and_append_pdu( PduBuilder { - event_type: body.event_type.to_string().into(), + event_type: body.event_type.clone().into(), content, unsigned: Some(unsigned), - state_key: None, - redacts: None, timestamp: appservice_info.and(body.timestamp), + ..Default::default() }, sender_user, &body.room_id, &state_lock, ) - .await - .map(|event_id| (*event_id).to_owned())?; + .await?; services .transaction_ids @@ -106,7 +103,7 @@ pub(crate) async fn send_message_event_route( drop(state_lock); Ok(send_message_event::v3::Response { - event_id, + event_id: event_id.into(), }) } diff --git a/src/api/client/profile.rs b/src/api/client/profile.rs index cdc047f07..32f7a7236 100644 --- a/src/api/client/profile.rs +++ b/src/api/client/profile.rs @@ -13,11 +13,10 @@ use ruma::{ }, federation, }, - events::{room::member::RoomMemberEventContent, StateEventType, TimelineEventType}, + events::{room::member::RoomMemberEventContent, StateEventType}, presence::PresenceState, OwnedMxcUri, OwnedRoomId, UserId, }; -use 
serde_json::value::to_raw_value; use service::Services; use crate::Ruma; @@ -310,19 +309,14 @@ pub async fn update_displayname( continue; }; - let pdu = PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { + let pdu = PduBuilder::state( + user_id.to_string(), + &RoomMemberEventContent { displayname: displayname.clone(), join_authorized_via_users_server: None, ..content - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }; + }, + ); joined_rooms.push((pdu, room_id)); } @@ -360,20 +354,15 @@ pub async fn update_avatar_url( .room_state_get_content(room_id, &StateEventType::RoomMember, user_id.as_str()) .await?; - let pdu = PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { + let pdu = PduBuilder::state( + user_id.to_string(), + &RoomMemberEventContent { avatar_url: avatar_url.clone(), blurhash: blurhash.clone(), join_authorized_via_users_server: None, ..content - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }; + }, + ); Ok((pdu, room_id)) }) diff --git a/src/api/client/redact.rs b/src/api/client/redact.rs index 2102f6cd5..a986dc18b 100644 --- a/src/api/client/redact.rs +++ b/src/api/client/redact.rs @@ -1,9 +1,5 @@ use axum::extract::State; -use ruma::{ - api::client::redact::redact_event, - events::{room::redaction::RoomRedactionEventContent, TimelineEventType}, -}; -use serde_json::value::to_raw_value; +use ruma::{api::client::redact::redact_event, events::room::redaction::RoomRedactionEventContent}; use crate::{service::pdu::PduBuilder, Result, Ruma}; @@ -25,16 +21,11 @@ pub(crate) async fn redact_event_route( .timeline .build_and_append_pdu( PduBuilder { - event_type: TimelineEventType::RoomRedaction, - content: to_raw_value(&RoomRedactionEventContent { + 
redacts: Some(body.event_id.clone().into()), + ..PduBuilder::timeline(&RoomRedactionEventContent { redacts: Some(body.event_id.clone()), reason: body.reason.clone(), }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: Some(body.event_id.into()), - timestamp: None, }, sender_user, &body.room_id, @@ -44,8 +35,7 @@ pub(crate) async fn redact_event_route( drop(state_lock); - let event_id = (*event_id).to_owned(); Ok(redact_event::v3::Response { - event_id, + event_id: event_id.into(), }) } diff --git a/src/api/client/room.rs b/src/api/client/room.rs index e22ad7963..daadb7242 100644 --- a/src/api/client/room.rs +++ b/src/api/client/room.rs @@ -150,8 +150,7 @@ pub(crate) async fn create_room_route( None => services.globals.default_room_version(), }; - #[allow(clippy::single_match_else)] - let content = match &body.creation_content { + let create_content = match &body.creation_content { Some(content) => { use RoomVersionId::*; @@ -213,11 +212,9 @@ pub(crate) async fn create_room_route( .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomCreate, - content: to_raw_value(&content).expect("event is valid, we just created it"), - unsigned: None, + content: to_raw_value(&create_content).expect("create event content serialization"), state_key: Some(String::new()), - redacts: None, - timestamp: None, + ..Default::default() }, sender_user, &room_id, @@ -231,24 +228,16 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, + PduBuilder::state( + sender_user.to_string(), + &RoomMemberEventContent { displayname: services.users.displayname(sender_user).await.ok(), avatar_url: services.users.avatar_url(sender_user).await.ok(), - is_direct: Some(body.is_direct), - third_party_invite: None, blurhash: services.users.blurhash(sender_user).await.ok(), - 
reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(sender_user.to_string()), - redacts: None, - timestamp: None, - }, + is_direct: Some(body.is_direct), + ..RoomMemberEventContent::new(MembershipState::Join) + }, + ), sender_user, &room_id, &state_lock, @@ -289,11 +278,9 @@ pub(crate) async fn create_room_route( .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&power_levels_content).expect("to_raw_value always works on serde_json::Value"), - unsigned: None, + content: to_raw_value(&power_levels_content).expect("serialized power_levels event content"), state_key: Some(String::new()), - redacts: None, - timestamp: None, + ..Default::default() }, sender_user, &room_id, @@ -308,18 +295,13 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomCanonicalAlias, - content: to_raw_value(&RoomCanonicalAliasEventContent { + PduBuilder::state( + String::new(), + &RoomCanonicalAliasEventContent { alias: Some(room_alias_id.to_owned()), alt_aliases: vec![], - }) - .expect("We checked that alias earlier, it must be fine"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }, + ), sender_user, &room_id, &state_lock, @@ -335,19 +317,14 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomJoinRules, - content: to_raw_value(&RoomJoinRulesEventContent::new(match preset { + PduBuilder::state( + String::new(), + &RoomJoinRulesEventContent::new(match preset { RoomPreset::PublicChat => JoinRule::Public, // according to spec "invite" is the default _ => JoinRule::Invite, - })) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }), + ), 
sender_user, &room_id, &state_lock, @@ -360,15 +337,10 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomHistoryVisibility, - content: to_raw_value(&RoomHistoryVisibilityEventContent::new(HistoryVisibility::Shared)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state( + String::new(), + &RoomHistoryVisibilityEventContent::new(HistoryVisibility::Shared), + ), sender_user, &room_id, &state_lock, @@ -381,18 +353,13 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomGuestAccess, - content: to_raw_value(&RoomGuestAccessEventContent::new(match preset { + PduBuilder::state( + String::new(), + &RoomGuestAccessEventContent::new(match preset { RoomPreset::PublicChat => GuestAccess::Forbidden, _ => GuestAccess::CanJoin, - })) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }), + ), sender_user, &room_id, &state_lock, @@ -440,15 +407,7 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomName, - content: to_raw_value(&RoomNameEventContent::new(name.clone())) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(String::new(), &RoomNameEventContent::new(name.clone())), sender_user, &room_id, &state_lock, @@ -462,17 +421,12 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomTopic, - content: to_raw_value(&RoomTopicEventContent { + PduBuilder::state( + String::new(), + &RoomTopicEventContent { topic: topic.clone(), - }) - .expect("event is valid, 
we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }, + ), sender_user, &room_id, &state_lock, @@ -641,18 +595,13 @@ pub(crate) async fn upgrade_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomTombstone, - content: to_raw_value(&RoomTombstoneEventContent { + PduBuilder::state( + String::new(), + &RoomTombstoneEventContent { body: "This room has been replaced".to_owned(), replacement_room: replacement_room.clone(), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }, + ), sender_user, &body.room_id, &state_lock, @@ -788,10 +737,8 @@ pub(crate) async fn upgrade_room_route( PduBuilder { event_type: event_type.to_string().into(), content: event_content, - unsigned: None, state_key: Some(String::new()), - redacts: None, - timestamp: None, + ..Default::default() }, sender_user, &replacement_room, @@ -821,7 +768,7 @@ pub(crate) async fn upgrade_room_route( } // Get the old room power levels - let mut power_levels_event_content: RoomPowerLevelsEventContent = services + let power_levels_event_content: RoomPowerLevelsEventContent = services .rooms .state_accessor .room_state_get_content(&body.room_id, &StateEventType::RoomPowerLevels, "") @@ -836,8 +783,6 @@ pub(crate) async fn upgrade_room_route( .checked_add(int!(1)) .ok_or_else(|| err!(Request(BadJson("users_default power levels event content is not valid"))))?, ); - power_levels_event_content.events_default = new_level; - power_levels_event_content.invite = new_level; // Modify the power levels in the old room to prevent sending of events and // inviting new users @@ -845,14 +790,14 @@ pub(crate) async fn upgrade_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&power_levels_event_content).expect("event is valid, 
we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state( + String::new(), + &RoomPowerLevelsEventContent { + events_default: new_level, + invite: new_level, + ..power_levels_event_content + }, + ), sender_user, &body.room_id, &state_lock, diff --git a/src/api/client/state.rs b/src/api/client/state.rs index 1396ae778..5090d5575 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -177,10 +177,9 @@ async fn send_state_event_for_key_helper( PduBuilder { event_type: event_type.to_string().into(), content: serde_json::from_str(json.json().get())?, - unsigned: None, state_key: Some(state_key), - redacts: None, timestamp, + ..Default::default() }, sender, room_id, diff --git a/src/api/server/make_join.rs b/src/api/server/make_join.rs index ba081aade..856680382 100644 --- a/src/api/server/make_join.rs +++ b/src/api/server/make_join.rs @@ -8,7 +8,7 @@ use ruma::{ join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent}, member::{MembershipState, RoomMemberEventContent}, }, - StateEventType, TimelineEventType, + StateEventType, }, CanonicalJsonObject, RoomId, RoomVersionId, UserId, }; @@ -125,30 +125,17 @@ pub(crate) async fn create_join_event_template_route( )); } - let content = to_raw_value(&RoomMemberEventContent { - avatar_url: None, - blurhash: None, - displayname: None, - is_direct: None, - membership: MembershipState::Join, - third_party_invite: None, - reason: None, - join_authorized_via_users_server, - }) - .expect("member event is valid value"); - let (_pdu, mut pdu_json) = services .rooms .timeline .create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - timestamp: None, - }, + PduBuilder::state( + body.user_id.to_string(), + &RoomMemberEventContent { + join_authorized_via_users_server, + ..RoomMemberEventContent::new(MembershipState::Join) + 
}, + ), &body.user_id, &body.room_id, &state_lock, diff --git a/src/api/server/make_leave.rs b/src/api/server/make_leave.rs index 41ea1c80d..81a32c865 100644 --- a/src/api/server/make_leave.rs +++ b/src/api/server/make_leave.rs @@ -2,10 +2,7 @@ use axum::extract::State; use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::prepare_leave_event}, - events::{ - room::member::{MembershipState, RoomMemberEventContent}, - TimelineEventType, - }, + events::room::member::{MembershipState, RoomMemberEventContent}, }; use serde_json::value::to_raw_value; @@ -39,30 +36,12 @@ pub(crate) async fn create_leave_event_template_route( let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; - let content = to_raw_value(&RoomMemberEventContent { - avatar_url: None, - blurhash: None, - displayname: None, - is_direct: None, - membership: MembershipState::Leave, - third_party_invite: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("member event is valid value"); let (_pdu, mut pdu_json) = services .rooms .timeline .create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(body.user_id.to_string(), &RoomMemberEventContent::new(MembershipState::Leave)), &body.user_id, &body.room_id, &state_lock, diff --git a/src/core/pdu/builder.rs b/src/core/pdu/builder.rs index ba4c19e57..80ff07130 100644 --- a/src/core/pdu/builder.rs +++ b/src/core/pdu/builder.rs @@ -1,20 +1,67 @@ use std::{collections::BTreeMap, sync::Arc}; -use ruma::{events::TimelineEventType, EventId, MilliSecondsSinceUnixEpoch}; +use ruma::{ + events::{EventContent, MessageLikeEventType, StateEventType, TimelineEventType}, + EventId, MilliSecondsSinceUnixEpoch, +}; use serde::Deserialize; -use 
serde_json::value::RawValue as RawJsonValue; +use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; /// Build the start of a PDU in order to add it to the Database. #[derive(Debug, Deserialize)] -pub struct PduBuilder { +pub struct Builder { #[serde(rename = "type")] pub event_type: TimelineEventType, + pub content: Box, - pub unsigned: Option>, + + pub unsigned: Option, + pub state_key: Option, + pub redacts: Option>, - /// For timestamped messaging, should only be used for appservices - /// + + /// For timestamped messaging, should only be used for appservices. /// Will be set to current time if None pub timestamp: Option, } + +type Unsigned = BTreeMap; + +impl Builder { + pub fn state(state_key: String, content: &T) -> Self + where + T: EventContent, + { + Self { + event_type: content.event_type().into(), + content: to_raw_value(content).expect("Builder failed to serialize state event content to RawValue"), + state_key: Some(state_key), + ..Self::default() + } + } + + pub fn timeline(content: &T) -> Self + where + T: EventContent, + { + Self { + event_type: content.event_type().into(), + content: to_raw_value(content).expect("Builder failed to serialize timeline event content to RawValue"), + ..Self::default() + } + } +} + +impl Default for Builder { + fn default() -> Self { + Self { + event_type: "m.room.message".into(), + content: Box::::default(), + unsigned: None, + state_key: None, + redacts: None, + timestamp: None, + } + } +} diff --git a/src/core/pdu/mod.rs b/src/core/pdu/mod.rs index a94e2bdc6..5f50fe5b1 100644 --- a/src/core/pdu/mod.rs +++ b/src/core/pdu/mod.rs @@ -21,7 +21,10 @@ use serde_json::{ value::{to_raw_value, RawValue as RawJsonValue, Value as JsonValue}, }; -pub use self::{builder::PduBuilder, count::PduCount}; +pub use self::{ + builder::{Builder, Builder as PduBuilder}, + count::PduCount, +}; use crate::{err, is_true, warn, Error, Result}; #[derive(Deserialize)] diff --git a/src/service/admin/create.rs 
b/src/service/admin/create.rs index 3dd5aea35..1631f1cbb 100644 --- a/src/service/admin/create.rs +++ b/src/service/admin/create.rs @@ -2,24 +2,20 @@ use std::collections::BTreeMap; use conduit::{pdu::PduBuilder, Result}; use ruma::{ - events::{ - room::{ - canonical_alias::RoomCanonicalAliasEventContent, - create::RoomCreateEventContent, - guest_access::{GuestAccess, RoomGuestAccessEventContent}, - history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, - join_rules::{JoinRule, RoomJoinRulesEventContent}, - member::{MembershipState, RoomMemberEventContent}, - name::RoomNameEventContent, - power_levels::RoomPowerLevelsEventContent, - preview_url::RoomPreviewUrlsEventContent, - topic::RoomTopicEventContent, - }, - TimelineEventType, + events::room::{ + canonical_alias::RoomCanonicalAliasEventContent, + create::RoomCreateEventContent, + guest_access::{GuestAccess, RoomGuestAccessEventContent}, + history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + join_rules::{JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, RoomMemberEventContent}, + name::RoomNameEventContent, + power_levels::RoomPowerLevelsEventContent, + preview_url::RoomPreviewUrlsEventContent, + topic::RoomTopicEventContent, }, RoomId, RoomVersionId, }; -use serde_json::value::to_raw_value; use crate::Services; @@ -44,7 +40,7 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { let room_version = services.globals.default_room_version(); - let mut content = { + let create_content = { use RoomVersionId::*; match room_version { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => RoomCreateEventContent::new_v1(server_user.clone()), @@ -52,23 +48,20 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { } }; - content.federate = true; - content.predecessor = None; - content.room_version = room_version; - // 1. 
The room create event services .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomCreate, - content: to_raw_value(&content).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state( + String::new(), + &RoomCreateEventContent { + federate: true, + predecessor: None, + room_version, + ..create_content + }, + ), server_user, &room_id, &state_lock, @@ -80,24 +73,7 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(server_user.to_string()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(server_user.to_string(), &RoomMemberEventContent::new(MembershipState::Join)), server_user, &room_id, &state_lock, @@ -111,18 +87,13 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&RoomPowerLevelsEventContent { + PduBuilder::state( + String::new(), + &RoomPowerLevelsEventContent { users, ..Default::default() - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }, + ), server_user, &room_id, &state_lock, @@ -134,15 +105,7 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomJoinRules, - content: 
to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(String::new(), &RoomJoinRulesEventContent::new(JoinRule::Invite)), server_user, &room_id, &state_lock, @@ -154,15 +117,10 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomHistoryVisibility, - content: to_raw_value(&RoomHistoryVisibilityEventContent::new(HistoryVisibility::Shared)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state( + String::new(), + &RoomHistoryVisibilityEventContent::new(HistoryVisibility::Shared), + ), server_user, &room_id, &state_lock, @@ -174,15 +132,7 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomGuestAccess, - content: to_raw_value(&RoomGuestAccessEventContent::new(GuestAccess::Forbidden)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(String::new(), &RoomGuestAccessEventContent::new(GuestAccess::Forbidden)), server_user, &room_id, &state_lock, @@ -195,15 +145,7 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomName, - content: to_raw_value(&RoomNameEventContent::new(room_name)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(String::new(), &RoomNameEventContent::new(room_name)), server_user, &room_id, &state_lock, @@ -214,17 +156,12 @@ pub async 
fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomTopic, - content: to_raw_value(&RoomTopicEventContent { + PduBuilder::state( + String::new(), + &RoomTopicEventContent { topic: format!("Manage {}", services.globals.server_name()), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }, + ), server_user, &room_id, &state_lock, @@ -238,18 +175,13 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomCanonicalAlias, - content: to_raw_value(&RoomCanonicalAliasEventContent { + PduBuilder::state( + String::new(), + &RoomCanonicalAliasEventContent { alias: Some(alias.clone()), alt_aliases: Vec::new(), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }, + ), server_user, &room_id, &state_lock, @@ -266,17 +198,12 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPreviewUrls, - content: to_raw_value(&RoomPreviewUrlsEventContent { + PduBuilder::state( + String::new(), + &RoomPreviewUrlsEventContent { disabled: true, - }) - .expect("event is valid we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }, + ), server_user, &room_id, &state_lock, diff --git a/src/service/admin/grant.rs b/src/service/admin/grant.rs index 6e266ca9b..405da982e 100644 --- a/src/service/admin/grant.rs +++ b/src/service/admin/grant.rs @@ -9,11 +9,10 @@ use ruma::{ power_levels::RoomPowerLevelsEventContent, }, tag::{TagEvent, TagEventContent, TagInfo}, - RoomAccountDataEventType, TimelineEventType, + RoomAccountDataEventType, }, 
RoomId, UserId, }; -use serde_json::value::to_raw_value; use crate::pdu::PduBuilder; @@ -35,24 +34,7 @@ pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> { self.services .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Invite, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(user_id.to_string(), &RoomMemberEventContent::new(MembershipState::Invite)), server_user, &room_id, &state_lock, @@ -61,24 +43,7 @@ pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> { self.services .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(user_id.to_string(), &RoomMemberEventContent::new(MembershipState::Join)), user_id, &room_id, &state_lock, @@ -91,18 +56,13 @@ pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> { self.services .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&RoomPowerLevelsEventContent { + PduBuilder::state( + String::new(), + &RoomPowerLevelsEventContent { users, ..Default::default() - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: 
Some(String::new()), - redacts: None, - timestamp: None, - }, + }, + ), server_user, &room_id, &state_lock, @@ -117,23 +77,18 @@ pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> { } } + let welcome_message = String::from("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `!admin --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`"); + // Send welcome message - self.services.timeline.build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMessage, - content: to_raw_value(&RoomMessageEventContent::text_markdown( - String::from("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. 
This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `!admin --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`"), - )) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - timestamp: None, - }, - server_user, - &room_id, - &state_lock, - ).await?; + self.services + .timeline + .build_and_append_pdu( + PduBuilder::timeline(&RoomMessageEventContent::text_markdown(welcome_message)), + server_user, + &room_id, + &state_lock, + ) + .await?; Ok(()) } diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index da7f3cf4c..58cc012c2 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -15,13 +15,9 @@ pub use create::create_admin_room; use futures::{FutureExt, TryFutureExt}; use loole::{Receiver, Sender}; use ruma::{ - events::{ - room::message::{Relation, RoomMessageEventContent}, - TimelineEventType, - }, + events::room::message::{Relation, RoomMessageEventContent}, OwnedEventId, OwnedRoomId, RoomId, UserId, }; -use serde_json::value::to_raw_value; use tokio::sync::{Mutex, RwLock}; use crate::{account_data, globals, rooms, rooms::state::RoomMutexGuard, Dep}; @@ -285,20 +281,12 @@ impl Service { ) -> Result<()> { assert!(self.user_is_admin(user_id).await, "sender is not admin"); - let response_pdu = PduBuilder { - event_type: TimelineEventType::RoomMessage, - content: to_raw_value(&content).expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - timestamp: None, - }; - let state_lock = 
self.services.state.mutex.lock(room_id).await; + if let Err(e) = self .services .timeline - .build_and_append_pdu(response_pdu, user_id, room_id, &state_lock) + .build_and_append_pdu(PduBuilder::timeline(&content), user_id, room_id, &state_lock) .await { self.handle_response_error(e, room_id, user_id, &state_lock) @@ -313,23 +301,14 @@ impl Service { &self, e: Error, room_id: &RoomId, user_id: &UserId, state_lock: &RoomMutexGuard, ) -> Result<()> { error!("Failed to build and append admin room response PDU: \"{e}\""); - let error_room_message = RoomMessageEventContent::text_plain(format!( + let content = RoomMessageEventContent::text_plain(format!( "Failed to build and append admin room PDU: \"{e}\"\n\nThe original admin command may have finished \ successfully, but we could not return the output." )); - let response_pdu = PduBuilder { - event_type: TimelineEventType::RoomMessage, - content: to_raw_value(&error_room_message).expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - timestamp: None, - }; - self.services .timeline - .build_and_append_pdu(response_pdu, user_id, room_id, state_lock) + .build_and_append_pdu(PduBuilder::timeline(&content), user_id, room_id, state_lock) .await?; Ok(()) diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 3855d92a2..19f1f1413 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -37,7 +37,6 @@ use ruma::{ ServerName, UserId, }; use serde::Deserialize; -use serde_json::value::to_raw_value; use self::data::Data; use crate::{rooms, rooms::state::RoomMutexGuard, Dep}; @@ -353,21 +352,14 @@ impl Service { pub async fn user_can_invite( &self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &RoomMutexGuard, ) -> bool { - let content = to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite)) - .expect("Event content always serializes"); - - let new_event = 
PduBuilder { - event_type: ruma::events::TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(target_user.into()), - redacts: None, - timestamp: None, - }; - self.services .timeline - .create_hash_and_sign_event(new_event, sender, room_id, state_lock) + .create_hash_and_sign_event( + PduBuilder::state(target_user.into(), &RoomMemberEventContent::new(MembershipState::Invite)), + sender, + room_id, + state_lock, + ) .await .is_ok() } From 8ea2dccc9ad72df70555c8dc04ee85d6ed49f1a7 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 9 Jul 2024 17:23:00 +0000 Subject: [PATCH 066/245] sort rustfmt Signed-off-by: Jason Volk --- rustfmt.toml | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/rustfmt.toml b/rustfmt.toml index 114677d49..fd912a193 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,28 +1,27 @@ -edition = "2021" - +array_width = 80 +chain_width = 60 +comment_width = 80 condense_wildcard_suffixes = true +edition = "2021" +fn_call_width = 80 +fn_params_layout = "Compressed" +fn_single_line = true format_code_in_doc_comments = true format_macro_bodies = true format_macro_matchers = true format_strings = true -hex_literal_case = "Upper" -max_width = 120 -tab_spaces = 4 -array_width = 80 -comment_width = 80 -wrap_comments = true -fn_params_layout = "Compressed" -fn_call_width = 80 -fn_single_line = true +group_imports = "StdExternalCrate" hard_tabs = true -match_block_trailing_comma = true +hex_literal_case = "Upper" imports_granularity = "Crate" +match_block_trailing_comma = true +max_width = 120 +newline_style = "Unix" normalize_comments = false reorder_impl_items = true reorder_imports = true -group_imports = "StdExternalCrate" -newline_style = "Unix" +tab_spaces = 4 use_field_init_shorthand = true use_small_heuristics = "Off" use_try_shorthand = true -chain_width = 60 +wrap_comments = true From c9c405facfcfd30c76e3b830929a4e4c90b930c2 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 29 
Sep 2024 09:20:17 +0000 Subject: [PATCH 067/245] relax Sized bound for debug::type_name Signed-off-by: Jason Volk --- src/core/debug.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/debug.rs b/src/core/debug.rs index 1e36ca8e2..85574a2f3 100644 --- a/src/core/debug.rs +++ b/src/core/debug.rs @@ -86,11 +86,11 @@ pub fn panic_str(p: &Box) -> &'static str { p.downcast_ref::<&st #[inline(always)] #[must_use] -pub fn rttype_name(_: &T) -> &'static str { type_name::() } +pub fn rttype_name(_: &T) -> &'static str { type_name::() } #[inline(always)] #[must_use] -pub fn type_name() -> &'static str { std::any::type_name::() } +pub fn type_name() -> &'static str { std::any::type_name::() } #[must_use] #[inline] From 16f82b02a07110ae3f4133758d3a7e20ca2401ea Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 29 Sep 2024 04:18:47 +0000 Subject: [PATCH 068/245] add util to restore state on scope exit Signed-off-by: Jason Volk --- src/core/utils/defer.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/core/utils/defer.rs b/src/core/utils/defer.rs index 08477b6f5..29199700b 100644 --- a/src/core/utils/defer.rs +++ b/src/core/utils/defer.rs @@ -15,8 +15,14 @@ macro_rules! defer { }; ($body:expr) => { - $crate::defer! {{ - $body - }} + $crate::defer! {{ $body }} + }; +} + +#[macro_export] +macro_rules! scope_restore { + ($val:ident, $ours:expr) => { + let theirs = $crate::utils::exchange($val, $ours); + $crate::defer! 
{{ *$val = theirs; }}; }; } From a5e85727b5d1447a67e6ef970f5cc9d54f866f87 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 29 Sep 2024 09:01:57 +0000 Subject: [PATCH 069/245] add tuple access functor-macro Signed-off-by: Jason Volk --- src/core/utils/mod.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index c34691d2d..4dbecf91a 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -39,3 +39,10 @@ pub use self::{ #[inline] pub fn exchange(state: &mut T, source: T) -> T { std::mem::replace(state, source) } + +#[macro_export] +macro_rules! at { + ($idx:tt) => { + |t| t.$idx + }; +} From 43b0bb6a5e62a9262abcad63431bf9ac0c2d60cc Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 7 Oct 2024 19:19:53 +0000 Subject: [PATCH 070/245] add non-allocating fixed-size random string generator Signed-off-by: Jason Volk --- src/core/utils/rand.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/core/utils/rand.rs b/src/core/utils/rand.rs index b80671eb9..d717c4bdc 100644 --- a/src/core/utils/rand.rs +++ b/src/core/utils/rand.rs @@ -3,6 +3,7 @@ use std::{ time::{Duration, SystemTime}, }; +use arrayvec::ArrayString; use rand::{thread_rng, Rng}; pub fn string(length: usize) -> String { @@ -13,6 +14,18 @@ pub fn string(length: usize) -> String { .collect() } +#[inline] +pub fn string_array() -> ArrayString { + let mut ret = ArrayString::::new(); + thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(LENGTH) + .map(char::from) + .for_each(|c| ret.push(c)); + + ret +} + #[inline] #[must_use] pub fn timepoint_secs(range: Range) -> SystemTime { From c40d20cb95283c1e03c72fec437c48b8debee678 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 25 Sep 2024 05:04:25 +0000 Subject: [PATCH 071/245] add macro util to determine if cargo build or check/clippy. 
Signed-off-by: Jason Volk --- src/macros/utils.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/macros/utils.rs b/src/macros/utils.rs index 197dd90e9..e4ffc622d 100644 --- a/src/macros/utils.rs +++ b/src/macros/utils.rs @@ -2,6 +2,16 @@ use syn::{parse_str, Expr, Generics, Lit, Meta}; use crate::Result; +pub(crate) fn is_cargo_build() -> bool { + std::env::args() + .find(|flag| flag.starts_with("--emit")) + .as_ref() + .and_then(|flag| flag.split_once('=')) + .map(|val| val.1.split(',')) + .and_then(|mut vals| vals.find(|elem| *elem == "link")) + .is_some() +} + pub(crate) fn get_named_generics(args: &[Meta], name: &str) -> Result { const DEFAULT: &str = "<>"; From 2a59a56eaa6d63c7db6634b1c1662d7f34dd7598 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 8 Sep 2024 22:17:02 +0000 Subject: [PATCH 072/245] initial example-config generator Signed-off-by: Jason Volk --- src/core/config/mod.rs | 2 + src/macros/config.rs | 98 ++++++++++++++++++++++++++++++++++++++++++ src/macros/mod.rs | 8 +++- 3 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 src/macros/config.rs diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 126b3123e..64e1c9ba5 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -5,6 +5,7 @@ use std::{ path::PathBuf, }; +use conduit_macros::config_example_generator; use either::{ Either, Either::{Left, Right}, @@ -27,6 +28,7 @@ pub mod check; pub mod proxy; /// all the config options for conduwuit +#[config_example_generator] #[derive(Clone, Debug, Deserialize)] #[allow(clippy::struct_excessive_bools)] pub struct Config { diff --git a/src/macros/config.rs b/src/macros/config.rs new file mode 100644 index 000000000..6d29c21fa --- /dev/null +++ b/src/macros/config.rs @@ -0,0 +1,98 @@ +use std::fmt::Write; + +use proc_macro::TokenStream; +use quote::ToTokens; +use syn::{Expr, ExprLit, Field, Fields, FieldsNamed, ItemStruct, Lit, Meta, MetaNameValue, Type, TypePath}; + +use 
crate::{utils::is_cargo_build, Result}; + +#[allow(clippy::needless_pass_by_value)] +pub(super) fn example_generator(input: ItemStruct, args: &[Meta]) -> Result { + if is_cargo_build() { + generate_example(&input, args)?; + } + + Ok(input.to_token_stream().into()) +} + +#[allow(clippy::needless_pass_by_value)] +#[allow(unused_variables)] +fn generate_example(input: &ItemStruct, _args: &[Meta]) -> Result<()> { + if let Fields::Named(FieldsNamed { + named, + .. + }) = &input.fields + { + for field in named { + let Some(ident) = &field.ident else { + continue; + }; + + let Some(doc) = get_doc_comment(field) else { + continue; + }; + + let Some(type_name) = get_type_name(field) else { + continue; + }; + + //println!("{:?} {type_name:?}\n{doc}", ident.to_string()); + } + } + + Ok(()) +} + +fn get_doc_comment(field: &Field) -> Option { + let mut out = String::new(); + for attr in &field.attrs { + let Meta::NameValue(MetaNameValue { + path, + value, + .. + }) = &attr.meta + else { + continue; + }; + + if !path + .segments + .iter() + .next() + .is_some_and(|s| s.ident == "doc") + { + continue; + } + + let Expr::Lit(ExprLit { + lit, + .. + }) = &value + else { + continue; + }; + + let Lit::Str(token) = &lit else { + continue; + }; + + writeln!(&mut out, "# {}", token.value()).expect("wrote to output string buffer"); + } + + (!out.is_empty()).then_some(out) +} + +fn get_type_name(field: &Field) -> Option { + let Type::Path(TypePath { + path, + .. 
+ }) = &field.ty + else { + return None; + }; + + path.segments + .iter() + .next() + .map(|segment| segment.ident.to_string()) +} diff --git a/src/macros/mod.rs b/src/macros/mod.rs index d32cda71c..1aa1e24fd 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -1,5 +1,6 @@ mod admin; mod cargo; +mod config; mod debug; mod implement; mod refutable; @@ -9,7 +10,7 @@ mod utils; use proc_macro::TokenStream; use syn::{ parse::{Parse, Parser}, - parse_macro_input, Error, Item, ItemConst, ItemEnum, ItemFn, Meta, + parse_macro_input, Error, Item, ItemConst, ItemEnum, ItemFn, ItemStruct, Meta, }; pub(crate) type Result = std::result::Result; @@ -47,6 +48,11 @@ pub fn implement(args: TokenStream, input: TokenStream) -> TokenStream { attribute_macro::(args, input, implement::implement) } +#[proc_macro_attribute] +pub fn config_example_generator(args: TokenStream, input: TokenStream) -> TokenStream { + attribute_macro::(args, input, config::example_generator) +} + fn attribute_macro(args: TokenStream, input: TokenStream, func: F) -> TokenStream where F: Fn(I, &[Meta]) -> Result, From f67cfcd5353bf112760f89a9451aafc2ba2d9fde Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 4 Oct 2024 19:10:20 +0000 Subject: [PATCH 073/245] cleanup Config::load() Signed-off-by: Jason Volk --- src/core/config/mod.rs | 42 +++++++++++++++++------------------------- src/main/server.rs | 2 +- 2 files changed, 18 insertions(+), 26 deletions(-) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 64e1c9ba5..40c900e56 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -1,3 +1,6 @@ +pub mod check; +pub mod proxy; + use std::{ collections::{BTreeMap, BTreeSet}, fmt, @@ -22,10 +25,7 @@ use url::Url; pub use self::check::check; use self::proxy::ProxyConfig; -use crate::{error::Error, utils::sys, Err, Result}; - -pub mod check; -pub mod proxy; +use crate::{err, error::Error, utils::sys, Result}; /// all the config options for conduwuit #[config_example_generator] @@ 
-441,34 +441,26 @@ const DEPRECATED_KEYS: &[&str; 9] = &[ impl Config { /// Pre-initialize config - pub fn load(paths: &Option>) -> Result { - let raw_config = if let Some(config_file_env) = Env::var("CONDUIT_CONFIG") { - Figment::new().merge(Toml::file(config_file_env).nested()) - } else if let Some(config_file_arg) = Env::var("CONDUWUIT_CONFIG") { - Figment::new().merge(Toml::file(config_file_arg).nested()) - } else if let Some(config_file_args) = paths { - let mut figment = Figment::new(); - - for config in config_file_args { - figment = figment.merge(Toml::file(config).nested()); - } + pub fn load(paths: Option<&[PathBuf]>) -> Result { + let paths_files = paths.into_iter().flatten().map(Toml::file); - figment - } else { - Figment::new() - }; + let envs = [Env::var("CONDUIT_CONFIG"), Env::var("CONDUWUIT_CONFIG")]; + let envs_files = envs.into_iter().flatten().map(Toml::file); - Ok(raw_config + let config = envs_files + .chain(paths_files) + .fold(Figment::new(), |config, file| config.merge(file.nested())) .merge(Env::prefixed("CONDUIT_").global().split("__")) - .merge(Env::prefixed("CONDUWUIT_").global().split("__"))) + .merge(Env::prefixed("CONDUWUIT_").global().split("__")); + + Ok(config) } /// Finalize config pub fn new(raw_config: &Figment) -> Result { - let config = match raw_config.extract::() { - Err(e) => return Err!("There was a problem with your configuration file: {e}"), - Ok(config) => config, - }; + let config = raw_config + .extract::() + .map_err(|e| err!("There was a problem with your configuration file: {e}"))?; // don't start if we're listening on both UNIX sockets and TCP at same time check::is_dual_listening(raw_config)?; diff --git a/src/main/server.rs b/src/main/server.rs index e435b2f44..4813d586c 100644 --- a/src/main/server.rs +++ b/src/main/server.rs @@ -24,7 +24,7 @@ pub(crate) struct Server { impl Server { pub(crate) fn build(args: &Args, runtime: Option<&runtime::Handle>) -> Result, Error> { - let raw_config = 
Config::load(&args.config)?; + let raw_config = Config::load(args.config.as_deref())?; let raw_config = crate::clap::update(raw_config, args)?; let config = Config::new(&raw_config)?; From fc4d109f35d2cfb54ae3a463cb66e318e5947510 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 5 Oct 2024 19:39:27 +0000 Subject: [PATCH 074/245] add document comments to config items Signed-off-by: Jason Volk --- src/core/config/mod.rs | 845 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 842 insertions(+), 3 deletions(-) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 40c900e56..b5e07da23 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -31,221 +31,768 @@ use crate::{err, error::Error, utils::sys, Result}; #[config_example_generator] #[derive(Clone, Debug, Deserialize)] #[allow(clippy::struct_excessive_bools)] +#[allow(rustdoc::broken_intra_doc_links, rustdoc::bare_urls)] pub struct Config { - /// [`IpAddr`] conduwuit will listen on (can be IPv4 or IPv6) + /// The server_name is the pretty name of this server. It is used as a + /// suffix for user and room ids. Examples: matrix.org, conduit.rs + /// + /// The Conduit server needs all /_matrix/ requests to be reachable at + /// https://your.server.name/ on port 443 (client-server) and 8448 (federation). + /// + /// If that's not possible for you, you can create /.well-known files to + /// redirect requests (delegation). See + /// https://spec.matrix.org/latest/client-server-api/#getwell-knownmatrixclient + /// and + /// https://spec.matrix.org/v1.9/server-server-api/#getwell-knownmatrixserver + /// for more information. + /// + /// YOU NEED TO EDIT THIS + pub server_name: OwnedServerName, + + /// Database backend: Only rocksdb is supported. + /// default address (IPv4 or IPv6) conduwuit will listen on. Generally you + /// want this to be localhost (127.0.0.1 / ::1). If you are using Docker or + /// a container NAT networking setup, you likely need this to be 0.0.0.0. 
+ /// To listen multiple addresses, specify a vector e.g. ["127.0.0.1", "::1"] + /// + /// default if unspecified is both IPv4 and IPv6 localhost: ["127.0.0.1", + /// "::1"] #[serde(default = "default_address")] address: ListeningAddr, - /// default TCP port(s) conduwuit will listen on + + /// The port(s) conduwuit will be running on. You need to set up a reverse + /// proxy such as Caddy or Nginx so all requests to /_matrix on port 443 + /// and 8448 will be forwarded to the conduwuit instance running on this + /// port Docker users: Don't change this, you'll need to map an external + /// port to this. To listen on multiple ports, specify a vector e.g. [8080, + /// 8448] + /// + /// default if unspecified is 8008 #[serde(default = "default_port")] port: ListeningPort, + pub tls: Option, + + /// Uncomment unix_socket_path to listen on a UNIX socket at the specified + /// path. If listening on a UNIX socket, you must remove/comment the + /// 'address' key if defined and add your reverse proxy to the 'conduwuit' + /// group, unless world RW permissions are specified with unix_socket_perms + /// (666 minimum). pub unix_socket_path: Option, + #[serde(default = "default_unix_socket_perms")] pub unix_socket_perms: u32, - pub server_name: OwnedServerName, + #[serde(default = "default_database_backend")] pub database_backend: String, + + /// This is the only directory where conduwuit will save its data, including + /// media. Note: this was previously "/var/lib/matrix-conduit" pub database_path: PathBuf, + pub database_backup_path: Option, + #[serde(default = "default_database_backups_to_keep")] pub database_backups_to_keep: i16, + + /// Set this to any float value in megabytes for conduwuit to tell the + /// database engine that this much memory is available for database-related + /// caches. May be useful if you have significant memory to spare to + /// increase performance. 
Defaults to 256.0 #[serde(default = "default_db_cache_capacity_mb")] pub db_cache_capacity_mb: f64, + + /// Option to control adding arbitrary text to the end of the user's + /// displayname upon registration with a space before the text. This was the + /// lightning bolt emoji option, just replaced with support for adding your + /// own custom text or emojis. To disable, set this to "" (an empty string) + /// Defaults to "🏳️⚧️" (trans pride flag) #[serde(default = "default_new_user_displayname_suffix")] pub new_user_displayname_suffix: String, + + /// If enabled, conduwuit will send a simple GET request periodically to + /// `https://pupbrain.dev/check-for-updates/stable` for any new + /// announcements made. Despite the name, this is not an update check + /// endpoint, it is simply an announcement check endpoint. Defaults to + /// false. #[serde(default)] pub allow_check_for_updates: bool, #[serde(default = "default_pdu_cache_capacity")] pub pdu_cache_capacity: u32, + + /// Set this to any float value to multiply conduwuit's in-memory LRU caches + /// with. May be useful if you have significant memory to spare to increase + /// performance. + /// + /// This was previously called `conduit_cache_capacity_modifier` + /// + /// Defaults to 1.0. 
#[serde(default = "default_cache_capacity_modifier", alias = "conduit_cache_capacity_modifier")] pub cache_capacity_modifier: f64, + #[serde(default = "default_auth_chain_cache_capacity")] pub auth_chain_cache_capacity: u32, + #[serde(default = "default_shorteventid_cache_capacity")] pub shorteventid_cache_capacity: u32, + #[serde(default = "default_eventidshort_cache_capacity")] pub eventidshort_cache_capacity: u32, + #[serde(default = "default_shortstatekey_cache_capacity")] pub shortstatekey_cache_capacity: u32, + #[serde(default = "default_statekeyshort_cache_capacity")] pub statekeyshort_cache_capacity: u32, + #[serde(default = "default_server_visibility_cache_capacity")] pub server_visibility_cache_capacity: u32, + #[serde(default = "default_user_visibility_cache_capacity")] pub user_visibility_cache_capacity: u32, + #[serde(default = "default_stateinfo_cache_capacity")] pub stateinfo_cache_capacity: u32, + #[serde(default = "default_roomid_spacehierarchy_cache_capacity")] pub roomid_spacehierarchy_cache_capacity: u32, + /// Maximum entries stored in DNS memory-cache. The size of an entry may + /// vary so please take care if raising this value excessively. Only + /// decrease this when using an external DNS cache. Please note + /// that systemd does *not* count as an external cache, even when configured + /// to do so. #[serde(default = "default_dns_cache_entries")] pub dns_cache_entries: u32, + + /// Minimum time-to-live in seconds for entries in the DNS cache. The + /// default may appear high to most administrators; this is by design. Only + /// decrease this if you are using an external DNS cache. #[serde(default = "default_dns_min_ttl")] pub dns_min_ttl: u64, + + /// Minimum time-to-live in seconds for NXDOMAIN entries in the DNS cache. + /// This value is critical for the server to federate efficiently. + /// NXDOMAIN's are assumed to not be returning to the federation + /// and aggressively cached rather than constantly rechecked. 
+ /// + /// Defaults to 3 days as these are *very rarely* false negatives. #[serde(default = "default_dns_min_ttl_nxdomain")] pub dns_min_ttl_nxdomain: u64, + + /// Number of retries after a timeout. #[serde(default = "default_dns_attempts")] pub dns_attempts: u16, + + /// The number of seconds to wait for a reply to a DNS query. Please note + /// that recursive queries can take up to several seconds for some domains, + /// so this value should not be too low. #[serde(default = "default_dns_timeout")] pub dns_timeout: u64, + + /// Fallback to TCP on DNS errors. Set this to false if unsupported by + /// nameserver. #[serde(default = "true_fn")] pub dns_tcp_fallback: bool, + + /// Enable to query all nameservers until the domain is found. Referred to + /// as "trust_negative_responses" in hickory_reso> This can avoid useless + /// DNS queries if the first nameserver responds with NXDOMAIN or an empty + /// NOERROR response. + /// + /// The default is to query one nameserver and stop (false). #[serde(default = "true_fn")] pub query_all_nameservers: bool, + + /// Enables using *only* TCP for querying your specified nameservers instead + /// of UDP. + /// + /// You very likely do *not* want this. hickory-resolver already falls back + /// to TCP on UDP errors. 
Defaults to false #[serde(default)] pub query_over_tcp_only: bool, + + /// DNS A/AAAA record lookup strategy + /// + /// Takes a number of one of the following options: + /// 1 - Ipv4Only (Only query for A records, no AAAA/IPv6) + /// 2 - Ipv6Only (Only query for AAAA records, no A/IPv4) + /// 3 - Ipv4AndIpv6 (Query for A and AAAA records in parallel, uses whatever + /// returns a successful response first) 4 - Ipv6thenIpv4 (Query for AAAA + /// record, if that fails then query the A record) 5 - Ipv4thenIpv6 (Query + /// for A record, if that fails then query the AAAA record) + /// + /// If you don't have IPv6 networking, then for better performance it may be + /// suitable to set this to Ipv4Only (1) as you will never ever use the + /// AAAA record contents even if the AAAA record is successful instead of + /// the A record. + /// + /// Defaults to 5 - Ipv4ThenIpv6 as this is the most compatible and IPv4 + /// networking is currently the most prevalent. #[serde(default = "default_ip_lookup_strategy")] pub ip_lookup_strategy: u8, + /// Max request size for file uploads #[serde(default = "default_max_request_size")] pub max_request_size: usize, + #[serde(default = "default_max_fetch_prev_events")] pub max_fetch_prev_events: u16, + /// Default/base connection timeout. + /// This is used only by URL previews and update/news endpoint checks + /// + /// Defaults to 10 seconds #[serde(default = "default_request_conn_timeout")] pub request_conn_timeout: u64, + + /// Default/base request timeout. The time waiting to receive more data from + /// another server. This is used only by URL previews, update/news, and + /// misc endpoint checks + /// + /// Defaults to 35 seconds #[serde(default = "default_request_timeout")] pub request_timeout: u64, + + /// Default/base request total timeout. The time limit for a whole request. + /// This is set very high to not cancel healthy requests while serving as a + /// backstop. 
This is used only by URL previews and update/news endpoint + /// checks + /// + /// Defaults to 320 seconds #[serde(default = "default_request_total_timeout")] pub request_total_timeout: u64, + + /// Default/base idle connection pool timeout + /// This is used only by URL previews and update/news endpoint checks + /// + /// Defaults to 5 seconds #[serde(default = "default_request_idle_timeout")] pub request_idle_timeout: u64, + + /// Default/base max idle connections per host + /// This is used only by URL previews and update/news endpoint checks + /// + /// Defaults to 1 as generally the same open connection can be re-used #[serde(default = "default_request_idle_per_host")] pub request_idle_per_host: u16, + + /// Federation well-known resolution connection timeout + /// + /// Defaults to 6 seconds #[serde(default = "default_well_known_conn_timeout")] pub well_known_conn_timeout: u64, + + /// Federation HTTP well-known resolution request timeout + /// + /// Defaults to 10 seconds #[serde(default = "default_well_known_timeout")] pub well_known_timeout: u64, + + /// Federation client request timeout + /// You most definitely want this to be high to account for extremely large + /// room joins, slow homeservers, your own resources etc. + /// + /// Defaults to 300 seconds #[serde(default = "default_federation_timeout")] pub federation_timeout: u64, + + /// Federation client idle connection pool timeout + /// + /// Defaults to 25 seconds #[serde(default = "default_federation_idle_timeout")] pub federation_idle_timeout: u64, + + /// Federation client max idle connections per host + /// + /// Defaults to 1 as generally the same open connection can be re-used #[serde(default = "default_federation_idle_per_host")] pub federation_idle_per_host: u16, + + /// Federation sender request timeout + /// The time it takes for the remote server to process sent transactions can + /// take a while. 
+ /// + /// Defaults to 180 seconds #[serde(default = "default_sender_timeout")] pub sender_timeout: u64, + + /// Federation sender idle connection pool timeout + /// + /// Defaults to 180 seconds #[serde(default = "default_sender_idle_timeout")] pub sender_idle_timeout: u64, + + /// Federation sender transaction retry backoff limit + /// + /// Defaults to 86400 seconds #[serde(default = "default_sender_retry_backoff_limit")] pub sender_retry_backoff_limit: u64, + + /// Appservice URL request connection timeout + /// + /// Defaults to 35 seconds as generally appservices are hosted within the + /// same network #[serde(default = "default_appservice_timeout")] pub appservice_timeout: u64, + + /// Appservice URL idle connection pool timeout + /// + /// Defaults to 300 seconds #[serde(default = "default_appservice_idle_timeout")] pub appservice_idle_timeout: u64, + + /// Notification gateway pusher idle connection pool timeout + /// + /// Defaults to 15 seconds #[serde(default = "default_pusher_idle_timeout")] pub pusher_idle_timeout: u64, + /// Enables registration. If set to false, no users can register on this + /// server. + /// + /// If set to true without a token configured, users can register with no + /// form of 2nd- step only if you set + /// `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` to + /// true in your config. + /// + /// If you would like registration only via token reg, please configure + /// `registration_token` or `registration_token_file`. #[serde(default)] pub allow_registration: bool, + #[serde(default)] pub yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse: bool, + + /// A static registration token that new users will have to provide when + /// creating an account. If unset and `allow_registration` is true, + /// registration is open without any condition. YOU NEED TO EDIT THIS. 
pub registration_token: Option, + + /// Path to a file on the system that gets read for the registration token + /// + /// conduwuit must be able to access the file, and it must not be empty + /// + /// no default pub registration_token_file: Option, + + /// controls whether encrypted rooms and events are allowed (default true) #[serde(default = "true_fn")] pub allow_encryption: bool, + + /// controls whether federation is allowed or not + /// defaults to true #[serde(default = "true_fn")] pub allow_federation: bool, + #[serde(default)] pub federation_loopback: bool, + + /// Set this to true to allow your server's public room directory to be + /// federated. Set this to false to protect against /publicRooms spiders, + /// but will forbid external users from viewing your server's public room + /// directory. If federation is disabled entirely (`allow_federation`), + /// this is inherently false. #[serde(default)] pub allow_public_room_directory_over_federation: bool, + + /// Set this to true to allow your server's public room directory to be + /// queried without client authentication (access token) through the Client + /// APIs. Set this to false to protect against /publicRooms spiders. #[serde(default)] pub allow_public_room_directory_without_auth: bool, + + /// allow guests/unauthenticated users to access TURN credentials + /// + /// this is the equivalent of Synapse's `turn_allow_guests` config option. + /// this allows any unauthenticated user to call + /// `/_matrix/client/v3/voip/turnServer`. + /// + /// defaults to false #[serde(default)] pub turn_allow_guests: bool, + + /// Set this to true to lock down your server's public room directory and + /// only allow admins to publish rooms to the room directory. Unpublishing + /// is still allowed by all users with this enabled. 
+ /// + /// Defaults to false #[serde(default)] pub lockdown_public_room_directory: bool, + + /// Set this to true to allow federating device display names / allow + /// external users to see your device display name. If federation is + /// disabled entirely (`allow_federation`), this is inherently false. For + /// privacy, this is best disabled. #[serde(default)] pub allow_device_name_federation: bool, + + /// Config option to allow or disallow incoming federation requests that + /// obtain the profiles of our local users from + /// `/_matrix/federation/v1/query/profile` + /// + /// This is inherently false if `allow_federation` is disabled + /// + /// Defaults to true #[serde(default = "true_fn")] pub allow_profile_lookup_federation_requests: bool, + + /// controls whether users are allowed to create rooms. + /// appservices and admins are always allowed to create rooms + /// defaults to true #[serde(default = "true_fn")] pub allow_room_creation: bool, + + /// Set to false to disable users from joining or creating room versions + /// that aren't 100% officially supported by conduwuit. + /// conduwuit officially supports room versions 6 - 10. conduwuit has + /// experimental/unstable support for 3 - 5, and 11. Defaults to true. #[serde(default = "true_fn")] pub allow_unstable_room_versions: bool, + #[serde(default = "default_default_room_version")] pub default_room_version: RoomVersionId, + #[serde(default)] pub well_known: WellKnownConfig, + #[serde(default)] pub allow_jaeger: bool, + #[serde(default = "default_jaeger_filter")] pub jaeger_filter: String, + + /// If the 'perf_measurements' feature is enabled, enables collecting folded + /// stack trace profile of tracing spans using tracing_flame. The resulting + /// profile can be visualized with inferno[1], speedscope[2], or a number of + /// other tools. 
[1]: https://github.com/jonhoo/inferno + /// [2]: www.speedscope.app #[serde(default)] pub tracing_flame: bool, + #[serde(default = "default_tracing_flame_filter")] pub tracing_flame_filter: String, + #[serde(default = "default_tracing_flame_output_path")] pub tracing_flame_output_path: String, + #[serde(default)] pub proxy: ProxyConfig, + pub jwt_secret: Option, + + /// Servers listed here will be used to gather public keys of other servers + /// (notary trusted key servers). + /// + /// (Currently, conduwuit doesn't support batched key requests, so this list + /// should only contain other Synapse servers) Defaults to `matrix.org` #[serde(default = "default_trusted_servers")] pub trusted_servers: Vec, + + /// Option to control whether conduwuit will query your list of trusted + /// notary key servers (`trusted_servers`) for remote homeserver signing + /// keys it doesn't know *first*, or query the individual servers first + /// before falling back to the trusted key servers. + /// + /// The former/default behaviour makes federated/remote rooms joins + /// generally faster because we're querying a single (or list of) server + /// that we know works, is reasonably fast, and is reliable for just about + /// all the homeserver signing keys in the room. Querying individual + /// servers may take longer depending on the general infrastructure of + /// everyone in there, how many dead servers there are, etc. + /// + /// However, this does create an increased reliance on one single or + /// multiple large entities as `trusted_servers` should generally + /// contain long-term and large servers who know a very large number of + /// homeservers. + /// + /// If you don't know what any of this means, leave this and + /// `trusted_servers` alone to their defaults. + /// + /// Defaults to true as this is the fastest option for federation. #[serde(default = "true_fn")] pub query_trusted_key_servers_first: bool, + + /// max log level for conduwuit. 
allows debug, info, warn, or error + /// see also: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives + /// **Caveat**: + /// For release builds, the tracing crate is configured to only implement + /// levels higher than error to avoid unnecessary overhead in the compiled + /// binary from trace macros. For debug builds, this restriction is not + /// applied. + /// + /// Defaults to "info" #[serde(default = "default_log")] pub log: String, + + /// controls whether logs will be outputted with ANSI colours + /// + /// defaults to true #[serde(default = "true_fn", alias = "log_colours")] pub log_colors: bool, + + /// OpenID token expiration/TTL in seconds + /// + /// These are the OpenID tokens that are primarily used for Matrix account + /// integrations, *not* OIDC/OpenID Connect/etc + /// + /// Defaults to 3600 (1 hour) #[serde(default = "default_openid_token_ttl")] pub openid_token_ttl: u64, + + /// TURN username to provide the client + /// + /// no default #[serde(default)] pub turn_username: String, + + /// TURN password to provide the client + /// + /// no default #[serde(default)] pub turn_password: String, + + /// vector list of TURN URIs/servers to use + /// + /// replace "example.turn.uri" with your TURN domain, such as the coturn + /// "realm". if using TURN over TLS, replace "turn:" with "turns:" + /// + /// No default #[serde(default = "Vec::new")] pub turn_uris: Vec, + + /// TURN secret to use for generating the HMAC-SHA1 hash apart of username + /// and password generation + /// + /// this is more secure, but if needed you can use traditional + /// username/password below. + /// + /// no default #[serde(default)] pub turn_secret: String, + + /// TURN secret to use that's read from the file path specified + /// + /// this takes priority over "turn_secret" first, and falls back to + /// "turn_secret" if invalid or failed to open. 
+ /// + /// no default pub turn_secret_file: Option, + + /// TURN TTL + /// + /// Default is 86400 seconds #[serde(default = "default_turn_ttl")] pub turn_ttl: u64, + /// List/vector of room **IDs** that conduwuit will make newly registered + /// users join. The room IDs specified must be rooms that you have joined + /// at least once on the server, and must be public. + /// + /// No default. #[serde(default = "Vec::new")] pub auto_join_rooms: Vec, + + /// Config option to automatically deactivate the account of any user who + /// attempts to join a: + /// - banned room + /// - forbidden room alias + /// - room alias or ID with a forbidden server name + /// + /// This may be useful if all your banned lists consist of toxic rooms or + /// servers that no good faith user would ever attempt to join, and + /// to automatically remediate the problem without any admin user + /// intervention. + /// + /// This will also make the user leave all rooms. Federation (e.g. remote + /// room invites) are ignored here. + /// + /// Defaults to false as rooms can be banned for non-moderation-related + /// reasons #[serde(default)] pub auto_deactivate_banned_room_attempts: bool, + /// RocksDB log level. This is not the same as conduwuit's log level. This + /// is the log level for the RocksDB engine/library which show up in your + /// database folder/path as `LOG` files. Defaults to error. conduwuit will + /// typically log RocksDB errors as normal. #[serde(default = "default_rocksdb_log_level")] pub rocksdb_log_level: String, + #[serde(default)] pub rocksdb_log_stderr: bool, + + /// Max RocksDB `LOG` file size before rotating in bytes. Defaults to 4MB. #[serde(default = "default_rocksdb_max_log_file_size")] pub rocksdb_max_log_file_size: usize, + + /// Time in seconds before RocksDB will forcibly rotate logs. Defaults to 0. 
#[serde(default = "default_rocksdb_log_time_to_roll")] pub rocksdb_log_time_to_roll: usize, + + /// Set this to true to use RocksDB config options that are tailored to HDDs + /// (slower device storage) + /// + /// It is worth noting that by default, conduwuit will use RocksDB with + /// Direct IO enabled. *Generally* speaking this improves performance as it + /// bypasses buffered I/O (system page cache). However there is a potential + /// chance that Direct IO may cause issues with database operations if your + /// setup is uncommon. This has been observed with FUSE filesystems, and + /// possibly ZFS filesystem. RocksDB generally deals/corrects these issues + /// but it cannot account for all setups. If you experience any weird + /// RocksDB issues, try enabling this option as it turns off Direct IO and + /// feel free to report in the conduwuit Matrix room if this option fixes + /// your DB issues. See https://github.com/facebook/rocksdb/wiki/Direct-IO for more information. + /// + /// Defaults to false #[serde(default)] pub rocksdb_optimize_for_spinning_disks: bool, + + /// Enables direct-io to increase database performance. This is enabled by + /// default. Set this option to false if the database resides on a + /// filesystem which does not support direct-io. #[serde(default = "true_fn")] pub rocksdb_direct_io: bool, + + /// Amount of threads that RocksDB will use for parallelism on database + /// operations such as cleanup, sync, flush, compaction, etc. Set to 0 to use + /// all your logical threads. + /// + /// Defaults to your CPU logical thread count. #[serde(default = "default_rocksdb_parallelism_threads")] pub rocksdb_parallelism_threads: usize, + + /// Maximum number of LOG files RocksDB will keep. This must *not* be set to + /// 0. It must be at least 1. Defaults to 3 as these are not very useful. #[serde(default = "default_rocksdb_max_log_files")] pub rocksdb_max_log_files: usize, + + /// Type of RocksDB database compression to use. 
+ /// Available options are "zstd", "zlib", "bz2", "lz4", or "none" + /// It is best to use ZSTD as an overall good balance between + /// speed/performance, storage, IO amplification, and CPU usage. + /// For more performance but less compression (more storage used) and less + /// CPU usage, use LZ4. See https://github.com/facebook/rocksdb/wiki/Compression for more details. + /// + /// "none" will disable compression. + /// + /// Defaults to "zstd" #[serde(default = "default_rocksdb_compression_algo")] pub rocksdb_compression_algo: String, + + /// Level of compression the specified compression algorithm for RocksDB to + /// use. Default is 32767, which is internally read by RocksDB as the + /// default magic number and translated to the library's default + /// compression level as they all differ. + /// See their `kDefaultCompressionLevel`. #[serde(default = "default_rocksdb_compression_level")] pub rocksdb_compression_level: i32, + + /// Level of compression the specified compression algorithm for the + /// bottommost level/data for RocksDB to use. Default is 32767, which is + /// internally read by RocksDB as the default magic number and translated + /// to the library's default compression level as they all differ. + /// See their `kDefaultCompressionLevel`. + /// + /// Since this is the bottommost level (generally old and least used data), + /// it may be desirable to have a very high compression level here as it's + /// less likely for this data to be used. Research your chosen compression + /// algorithm. #[serde(default = "default_rocksdb_bottommost_compression_level")] pub rocksdb_bottommost_compression_level: i32, + + /// Whether to enable RocksDB "bottommost_compression". + /// At the expense of more CPU usage, this will further compress the + /// database to reduce more storage. It is recommended to use ZSTD + /// compression with this for best compression results. See https://github.com/facebook/rocksdb/wiki/Compression for more details. 
+ /// + /// Defaults to false as this uses more CPU when compressing. #[serde(default)] pub rocksdb_bottommost_compression: bool, + + /// Database recovery mode (for RocksDB WAL corruption) + /// + /// Use this option when the server reports corruption and refuses to start. + /// Set mode 2 (PointInTime) to cleanly recover from this corruption. The + /// server will continue from the last good state, several seconds or + /// minutes prior to the crash. Clients may have to run "clear-cache & + /// reload" to account for the rollback. Upon success, you may reset the + /// mode back to default and restart again. Please note in some cases the + /// corruption error may not be cleared for at least 30 minutes of + /// operation in PointInTime mode. + /// + /// As a very last ditch effort, if PointInTime does not fix or resolve + /// anything, you can try mode 3 (SkipAnyCorruptedRecord) but this will + /// leave the server in a potentially inconsistent state. + /// + /// The default mode 1 (TolerateCorruptedTailRecords) will automatically + /// drop the last entry in the database if corrupted during shutdown, but + /// nothing more. It is extraordinarily unlikely this will desynchronize + /// clients. To disable any form of silent rollback set mode 0 + /// (AbsoluteConsistency). + /// + /// The options are: + /// 0 = AbsoluteConsistency + /// 1 = TolerateCorruptedTailRecords (default) + /// 2 = PointInTime (use me if trying to recover) + /// 3 = SkipAnyCorruptedRecord (you now voided your Conduwuit warranty) + /// + /// See https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes for more information + /// + /// Defaults to 1 (TolerateCorruptedTailRecords) #[serde(default = "default_rocksdb_recovery_mode")] pub rocksdb_recovery_mode: u8, + + /// Database repair mode (for RocksDB SST corruption) + /// + /// Use this option when the server reports corruption while running or + /// panics. If the server refuses to start use the recovery mode options + /// first. 
Corruption errors containing the acronym 'SST' which occur after + /// startup will likely require this option. + /// + /// - Backing up your database directory is recommended prior to running the + /// repair. + /// - Disabling repair mode and restarting the server is recommended after + /// running the repair. + /// + /// Defaults to false #[serde(default)] pub rocksdb_repair: bool, + #[serde(default)] pub rocksdb_read_only: bool, + #[serde(default)] pub rocksdb_secondary: bool, + + /// Enables idle CPU priority for compaction thread. This is not enabled by + /// default to prevent compaction from falling too far behind on busy + /// systems. #[serde(default)] pub rocksdb_compaction_prio_idle: bool, + + /// Enables idle IO priority for compaction thread. This prevents any + /// unexpected lag in the server's operation and is usually a good idea. + /// Enabled by default. #[serde(default = "true_fn")] pub rocksdb_compaction_ioprio_idle: bool, + #[serde(default = "true_fn")] pub rocksdb_compaction: bool, + + /// Level of statistics collection. Some admin commands to display database + /// statistics may require this option to be set. Database performance may + /// be impacted by higher settings. + /// + /// Option is a number ranging from 0 to 6: + /// 0 = No statistics. + /// 1 = No statistics in release mode (default). + /// 2 to 3 = Statistics with no performance impact. + /// 3 to 5 = Statistics with possible performance impact. + /// 6 = All statistics. + /// + /// Defaults to 1 (No statistics, except in debug-mode) #[serde(default = "default_rocksdb_stats_level")] pub rocksdb_stats_level: u8, @@ -254,128 +801,420 @@ pub struct Config { #[serde(default = "default_notification_push_path")] pub notification_push_path: String, + /// Config option to control local (your server only) presence + /// updates/requests. Defaults to true. Note that presence on conduwuit is + /// very fast unlike Synapse's. If using outgoing presence, this MUST be + /// enabled. 
#[serde(default = "true_fn")] pub allow_local_presence: bool, + + /// Config option to control incoming federated presence updates/requests. + /// Defaults to true. This option receives presence updates from other + /// servers, but does not send any unless `allow_outgoing_presence` is true. + /// Note that presence on conduwuit is very fast unlike Synapse's. #[serde(default = "true_fn")] pub allow_incoming_presence: bool, + + /// Config option to control outgoing presence updates/requests. Defaults to + /// true. This option sends presence updates to other servers, but does not + /// receive any unless `allow_incoming_presence` is true. + /// Note that presence on conduwuit is very fast unlike Synapse's. + /// If using outgoing presence, you MUST enable `allow_local_presence` as + /// well. #[serde(default = "true_fn")] pub allow_outgoing_presence: bool, + + /// Config option to control how many seconds before presence updates that + /// you are idle. Defaults to 5 minutes. #[serde(default = "default_presence_idle_timeout_s")] pub presence_idle_timeout_s: u64, + + /// Config option to control how many seconds before presence updates that + /// you are offline. Defaults to 30 minutes. #[serde(default = "default_presence_offline_timeout_s")] pub presence_offline_timeout_s: u64, + + /// Config option to enable the presence idle timer for remote users. + /// Disabling is offered as an optimization for servers participating in + /// many large rooms or when resources are limited. Disabling it may cause + /// incorrect presence states (i.e. stuck online) to be seen for some + /// remote users. Defaults to true. #[serde(default = "true_fn")] pub presence_timeout_remote_users: bool, + /// Config option to control whether we should receive remote incoming read + /// receipts. Defaults to true. #[serde(default = "true_fn")] pub allow_incoming_read_receipts: bool, + + /// Config option to control whether we should send read receipts to remote + /// servers. Defaults to true. 
#[serde(default = "true_fn")] pub allow_outgoing_read_receipts: bool, + /// Config option to control outgoing typing updates to federation. Defaults + /// to true. #[serde(default = "true_fn")] pub allow_outgoing_typing: bool, + + /// Config option to control incoming typing updates from federation. + /// Defaults to true. #[serde(default = "true_fn")] pub allow_incoming_typing: bool, + + /// Config option to control maximum time federation user can indicate + /// typing. #[serde(default = "default_typing_federation_timeout_s")] pub typing_federation_timeout_s: u64, + + /// Config option to control minimum time local client can indicate typing. + /// This does not override a client's request to stop typing. It only + /// enforces a minimum value in case of no stop request. #[serde(default = "default_typing_client_timeout_min_s")] pub typing_client_timeout_min_s: u64, + + /// Config option to control maximum time local client can indicate typing. #[serde(default = "default_typing_client_timeout_max_s")] pub typing_client_timeout_max_s: u64, + /// Set this to true for conduwuit to compress HTTP response bodies using + /// zstd. This option does nothing if conduwuit was not built with + /// `zstd_compression` feature. Please be aware that enabling HTTP + /// compression may weaken TLS. Most users should not need to enable this. + /// See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH + /// before deciding to enable this. #[serde(default)] pub zstd_compression: bool, + + /// Set this to true for conduwuit to compress HTTP response bodies using + /// gzip. This option does nothing if conduwuit was not built with + /// `gzip_compression` feature. Please be aware that enabling HTTP + /// compression may weaken TLS. Most users should not need to enable this. + /// See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH before + /// deciding to enable this. 
#[serde(default)] pub gzip_compression: bool, + + /// Set this to true for conduwuit to compress HTTP response bodies using + /// brotli. This option does nothing if conduwuit was not built with + /// `brotli_compression` feature. Please be aware that enabling HTTP + /// compression may weaken TLS. Most users should not need to enable this. + /// See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH before + /// deciding to enable this. #[serde(default)] pub brotli_compression: bool, + /// Set to true to allow user type "guest" registrations. Element attempts + /// to register guest users automatically. Defaults to false #[serde(default)] pub allow_guest_registration: bool, + + /// Set to true to log guest registrations in the admin room. + /// Defaults to false as it may be noisy or unnecessary. #[serde(default)] pub log_guest_registrations: bool, + + /// Set to true to allow guest registrations/users to auto join any rooms + /// specified in `auto_join_rooms` Defaults to false #[serde(default)] pub allow_guests_auto_join_rooms: bool, + /// Config option to control whether the legacy unauthenticated Matrix media + /// repository endpoints will be enabled. These endpoints consist of: + /// - /_matrix/media/*/config + /// - /_matrix/media/*/upload + /// - /_matrix/media/*/preview_url + /// - /_matrix/media/*/download/* + /// - /_matrix/media/*/thumbnail/* + /// + /// The authenticated equivalent endpoints are always enabled. + /// + /// Defaults to true for now, but this is highly subject to change, likely + /// in the next release. #[serde(default = "true_fn")] pub allow_legacy_media: bool, + #[serde(default = "true_fn")] pub freeze_legacy_media: bool, + + /// Checks consistency of the media directory at startup: + /// 1. When `media_compat_file_link` is enabled, this check will upgrade + /// media when switching back and forth between Conduit and Conduwuit. + /// Both options must be enabled to handle this. + /// 2. 
When media is deleted from the directory, this check will also delete + /// its database entry. + /// + /// If none of these checks apply to your use cases, and your media + /// directory is significantly large setting this to false may reduce + /// startup time. + /// + /// Enabled by default. #[serde(default = "true_fn")] pub media_startup_check: bool, + + /// Enable backward-compatibility with Conduit's media directory by creating + /// symlinks of media. This option is only necessary if you plan on using + /// Conduit again. Otherwise setting this to false reduces filesystem + /// clutter and overhead for managing these symlinks in the directory. This + /// is now disabled by default. You may still return to upstream Conduit + /// but you have to run Conduwuit at least once with this set to true and + /// allow the media_startup_check to take place before shutting + /// down to return to Conduit. + /// + /// Disabled by default. #[serde(default)] pub media_compat_file_link: bool, + + /// Prunes missing media from the database as part of the media startup + /// checks. This means if you delete files from the media directory the + /// corresponding entries will be removed from the database. This is + /// disabled by default because if the media directory is accidentally moved + /// or inaccessible the metadata entries in the database will be lost with + /// sadness. + /// + /// Disabled by default. #[serde(default)] pub prune_missing_media: bool, + + /// Vector list of servers that conduwuit will refuse to download remote + /// media from. No default. #[serde(default = "Vec::new")] pub prevent_media_downloads_from: Vec, + /// List of forbidden server names that we will block incoming AND outgoing + /// federation with, and block client room joins / remote user invites. + /// + /// This check is applied on the room ID, room alias, sender server name, + /// sender user's server name, inbound federation X-Matrix origin, and + /// outbound federation handler. 
+ /// + /// Basically "global" ACLs. No default. #[serde(default = "Vec::new")] pub forbidden_remote_server_names: Vec, + + /// List of forbidden server names that we will block all outgoing federated + /// room directory requests for. Useful for preventing our users from + /// wandering into bad servers or spaces. No default. #[serde(default = "Vec::new")] pub forbidden_remote_room_directory_server_names: Vec, + /// Vector list of IPv4 and IPv6 CIDR ranges / subnets *in quotes* that you + /// do not want conduwuit to send outbound requests to. Defaults to + /// RFC1918, unroutable, loopback, multicast, and testnet addresses for + /// security. + /// + /// To disable, set this to be an empty vector (`[]`). + /// Please be aware that this is *not* a guarantee. You should be using a + /// firewall with zones as doing this on the application layer may have + /// bypasses. + /// + /// Currently this does not account for proxies in use like Synapse does. #[serde(default = "default_ip_range_denylist")] pub ip_range_denylist: Vec, + /// Vector list of domains allowed to send requests to for URL previews. + /// Defaults to none. Note: this is a *contains* match, not an explicit + /// match. Putting "google.com" will match "https://google.com" and + /// "http://mymaliciousdomainexamplegoogle.com" Setting this to "*" will + /// allow all URL previews. Please note that this opens up significant + /// attack surface to your server, you are expected to be aware of the + /// risks by doing so. #[serde(default = "Vec::new")] pub url_preview_domain_contains_allowlist: Vec, + + /// Vector list of explicit domains allowed to send requests to for URL + /// previews. Defaults to none. Note: This is an *explicit* match, not a + /// contains match. Putting "google.com" will match "https://google.com", + /// "http://google.com", but not + /// "https://mymaliciousdomainexamplegoogle.com". Setting this to "*" will + /// allow all URL previews. 
Please note that this opens up significant + /// attack surface to your server, you are expected to be aware of the + /// risks by doing so. #[serde(default = "Vec::new")] pub url_preview_domain_explicit_allowlist: Vec, + + /// Vector list of explicit domains not allowed to send requests to for URL + /// previews. Defaults to none. Note: This is an *explicit* match, not a + /// contains match. Putting "google.com" will match "https://google.com", + /// "http://google.com", but not + /// "https://mymaliciousdomainexamplegoogle.com". The denylist is checked + /// first before allowlist. Setting this to "*" will not do anything. #[serde(default = "Vec::new")] pub url_preview_domain_explicit_denylist: Vec, + + /// Vector list of URLs allowed to send requests to for URL previews. + /// Defaults to none. Note that this is a *contains* match, not an + /// explicit match. Putting "google.com" will match + /// "https://google.com/", + /// "https://google.com/url?q=https://mymaliciousdomainexample.com", and + /// "https://mymaliciousdomainexample.com/hi/google.com" Setting this to + /// "*" will allow all URL previews. Please note that this opens up + /// significant attack surface to your server, you are expected to be + /// aware of the risks by doing so. #[serde(default = "Vec::new")] pub url_preview_url_contains_allowlist: Vec, + + /// Maximum amount of bytes allowed in a URL preview body size when + /// spidering. Defaults to 384KB (384_000 bytes) #[serde(default = "default_url_preview_max_spider_size")] pub url_preview_max_spider_size: usize, + + /// Option to decide whether you would like to run the domain allowlist + /// checks (contains and explicit) on the root domain or not. Does not apply + /// to URL contains allowlist. Defaults to false. 
Example: If this is + /// enabled and you have "wikipedia.org" allowed in the explicit and/or + /// contains domain allowlist, it will allow all subdomains under + /// "wikipedia.org" such as "en.m.wikipedia.org" as the root domain is + /// checked and matched. Useful if the domain contains allowlist is still + /// too broad for you but you still want to allow all the subdomains under a + /// root domain. #[serde(default)] pub url_preview_check_root_domain: bool, + /// List of forbidden room aliases and room IDs as patterns/strings. Values + /// in this list are matched as *contains*. This is checked upon room alias + /// creation, custom room ID creation if used, and startup as warnings if + /// any room aliases in your database have a forbidden room alias/ID. + /// No default. #[serde(default = "RegexSet::empty")] #[serde(with = "serde_regex")] pub forbidden_alias_names: RegexSet, + /// List of forbidden username patterns/strings. Values in this list are + /// matched as *contains*. This is checked upon username availability + /// check, registration, and startup as warnings if any local users in your + /// database have a forbidden username. + /// No default. #[serde(default = "RegexSet::empty")] #[serde(with = "serde_regex")] pub forbidden_usernames: RegexSet, + /// Retry failed and incomplete messages to remote servers immediately upon + /// startup. This is called bursting. If this is disabled, said messages + /// may not be delivered until more messages are queued for that server. Do + /// not change this option unless server resources are extremely limited or + /// the scale of the server's deployment is huge. Do not disable this + /// unless you know what you are doing. #[serde(default = "true_fn")] pub startup_netburst: bool, + + /// messages are dropped and not reattempted. The `startup_netburst` option + /// must be enabled for this value to have any effect. Do not change this + /// value unless you know what you are doing. 
Set this value to -1 to + /// reattempt every message without trimming the queues; this may consume + /// significant disk. Set this value to 0 to drop all messages without any + /// attempt at redelivery. #[serde(default = "default_startup_netburst_keep")] pub startup_netburst_keep: i64, + /// controls whether non-admin local users are forbidden from sending room + /// invites (local and remote), and if non-admin users can receive remote + /// room invites. admins are always allowed to send and receive all room + /// invites. defaults to false #[serde(default)] pub block_non_admin_invites: bool, + + /// Allows admins to enter commands in rooms other than #admins by prefixing + /// with \!admin. The reply will be publicly visible to the room, + /// originating from the sender. defaults to true #[serde(default = "true_fn")] pub admin_escape_commands: bool, + + /// Controls whether the conduwuit admin room console / CLI will immediately + /// activate on startup. This option can also be enabled with `--console` + /// conduwuit argument + /// + /// Defaults to false #[serde(default)] pub admin_console_automatic: bool, + + /// Controls what admin commands will be executed on startup. This is a + /// vector list of strings of admin commands to run. + /// + /// An example of this can be: `admin_execute = ["debug ping puppygock.gay", + /// "debug echo hi"]` + /// + /// This option can also be configured with the `--execute` conduwuit + /// argument and can take standard shell commands and environment variables + /// + /// Such example could be: `./conduwuit --execute "server admin-notice + /// conduwuit has started up at $(date)"` + /// + /// Defaults to nothing. 
#[serde(default)] pub admin_execute: Vec, + + /// Controls whether conduwuit should error and fail to start if an admin + /// execute command (`--execute` / `admin_execute`) fails + /// + /// Defaults to false #[serde(default)] pub admin_execute_errors_ignore: bool, + + /// Controls the max log level for admin command log captures (logs + /// generated from running admin commands) + /// + /// Defaults to "info" on release builds, else "debug" on debug builds #[serde(default = "default_admin_log_capture")] pub admin_log_capture: String, + #[serde(default = "default_admin_room_tag")] pub admin_room_tag: String, + /// Sentry.io crash/panic reporting, performance monitoring/metrics, etc. + /// This is NOT enabled by default. conduwuit's default Sentry reporting + /// endpoint is o4506996327251968.ingest.us.sentry.io + /// + /// Defaults to *false* #[serde(default)] pub sentry: bool, + + /// Sentry reporting URL if a custom one is desired + /// + /// Defaults to conduwuit's default Sentry endpoint: + /// "https://fe2eb4536aa04949e28eff3128d64757@o4506996327251968.ingest.us.sentry.io/4506996334657536" #[serde(default = "default_sentry_endpoint")] pub sentry_endpoint: Option, + + /// Report your Conduwuit server_name in Sentry.io crash reports and metrics + /// + /// Defaults to false #[serde(default)] pub sentry_send_server_name: bool, + + /// Performance monitoring/tracing sample rate for Sentry.io + /// + /// Note that too high values may impact performance, and can be disabled by + /// setting it to 0.0 (0%) This value is read as a percentage to Sentry, + /// represented as a decimal + /// + /// Defaults to 15% of traces (0.15) #[serde(default = "default_sentry_traces_sample_rate")] pub sentry_traces_sample_rate: f32, + + /// Whether to attach a stacktrace to Sentry reports. #[serde(default)] pub sentry_attach_stacktrace: bool, + + /// Send panics to sentry. This is true by default, but sentry has to be + /// enabled. 
#[serde(default = "true_fn")] pub sentry_send_panic: bool, + + /// Send errors to sentry. This is true by default, but sentry has to be + /// enabled. This option is only effective in release-mode; forced to false + /// in debug-mode. #[serde(default = "true_fn")] pub sentry_send_error: bool, + + /// Controls the tracing log level for Sentry to send things like + /// breadcrumbs and transactions Defaults to "info" #[serde(default = "default_sentry_filter")] pub sentry_filter: String, + /// Enable the tokio-console. This option is only relevant to developers. + /// See: docs/development.md#debugging-with-tokio-console for more + /// information. #[serde(default)] pub tokio_console: bool, From 2f24d7117a4f493bb90ea20b2b780486ef40272c Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 6 Oct 2024 05:15:24 +0000 Subject: [PATCH 075/245] further develop serializer for insertions add JSON delegator to db serializer consolidate writes through memfun; simplifications Signed-off-by: Jason Volk --- src/database/mod.rs | 3 +- src/database/ser.rs | 216 ++++++++++++++++++++++----------------- src/database/tests.rs | 232 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 354 insertions(+), 97 deletions(-) create mode 100644 src/database/tests.rs diff --git a/src/database/mod.rs b/src/database/mod.rs index e66abf682..c39b2b2f2 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -10,6 +10,7 @@ pub mod maps; mod opts; mod ser; mod stream; +mod tests; mod util; mod watchers; @@ -28,7 +29,7 @@ pub use self::{ handle::Handle, keyval::{KeyVal, Slice}, map::Map, - ser::{Interfix, Separator}, + ser::{serialize, serialize_to_array, serialize_to_vec, Interfix, Json, Separator}, }; conduit::mod_ctor! 
{} diff --git a/src/database/ser.rs b/src/database/ser.rs index bd4bbd9ad..742f1e345 100644 --- a/src/database/ser.rs +++ b/src/database/ser.rs @@ -1,12 +1,24 @@ use std::io::Write; -use conduit::{err, result::DebugInspect, utils::exchange, Error, Result}; +use arrayvec::ArrayVec; +use conduit::{debug::type_name, err, result::DebugInspect, utils::exchange, Error, Result}; use serde::{ser, Serialize}; #[inline] -pub(crate) fn serialize_to_vec(val: &T) -> Result> +pub fn serialize_to_array(val: T) -> Result> where - T: Serialize + ?Sized, + T: Serialize, +{ + let mut buf = ArrayVec::::new(); + serialize(&mut buf, val)?; + + Ok(buf) +} + +#[inline] +pub fn serialize_to_vec(val: T) -> Result> +where + T: Serialize, { let mut buf = Vec::with_capacity(64); serialize(&mut buf, val)?; @@ -15,10 +27,10 @@ where } #[inline] -pub(crate) fn serialize<'a, W, T>(out: &'a mut W, val: &'a T) -> Result<&'a [u8]> +pub fn serialize<'a, W, T>(out: &'a mut W, val: T) -> Result<&'a [u8]> where - W: Write + AsRef<[u8]>, - T: Serialize + ?Sized, + W: Write + AsRef<[u8]> + 'a, + T: Serialize, { let mut serializer = Serializer { out, @@ -43,6 +55,10 @@ pub(crate) struct Serializer<'a, W: Write> { fin: bool, } +/// Newtype for JSON serialization. +#[derive(Debug, Serialize)] +pub struct Json(pub T); + /// Directive to force separator serialization specifically for prefix keying /// use. This is a quirk of the database schema and prefix iterations. 
#[derive(Debug, Serialize)] @@ -56,38 +72,43 @@ pub struct Separator; impl Serializer<'_, W> { const SEP: &'static [u8] = b"\xFF"; + fn tuple_start(&mut self) { + debug_assert!(!self.sep, "Tuple start with separator set"); + self.sequence_start(); + } + + fn tuple_end(&mut self) -> Result { + self.sequence_end()?; + Ok(()) + } + fn sequence_start(&mut self) { debug_assert!(!self.is_finalized(), "Sequence start with finalization set"); - debug_assert!(!self.sep, "Sequence start with separator set"); - if cfg!(debug_assertions) { - self.depth = self.depth.saturating_add(1); - } + cfg!(debug_assertions).then(|| self.depth = self.depth.saturating_add(1)); } - fn sequence_end(&mut self) { - self.sep = false; - if cfg!(debug_assertions) { - self.depth = self.depth.saturating_sub(1); - } + fn sequence_end(&mut self) -> Result { + cfg!(debug_assertions).then(|| self.depth = self.depth.saturating_sub(1)); + Ok(()) } - fn record_start(&mut self) -> Result<()> { + fn record_start(&mut self) -> Result { debug_assert!(!self.is_finalized(), "Starting a record after serialization finalized"); exchange(&mut self.sep, true) .then(|| self.separator()) .unwrap_or(Ok(())) } - fn separator(&mut self) -> Result<()> { + fn separator(&mut self) -> Result { debug_assert!(!self.is_finalized(), "Writing a separator after serialization finalized"); self.out.write_all(Self::SEP).map_err(Into::into) } + fn write(&mut self, buf: &[u8]) -> Result { self.out.write_all(buf).map_err(Into::into) } + fn set_finalized(&mut self) { debug_assert!(!self.is_finalized(), "Finalization already set"); - if cfg!(debug_assertions) { - self.fin = true; - } + cfg!(debug_assertions).then(|| self.fin = true); } fn is_finalized(&self) -> bool { self.fin } @@ -104,53 +125,65 @@ impl ser::Serializer for &mut Serializer<'_, W> { type SerializeTupleStruct = Self; type SerializeTupleVariant = Self; - fn serialize_map(self, _len: Option) -> Result { - unimplemented!("serialize Map not implemented") - } - fn 
serialize_seq(self, _len: Option) -> Result { self.sequence_start(); - self.record_start()?; Ok(self) } fn serialize_tuple(self, _len: usize) -> Result { - self.sequence_start(); + self.tuple_start(); Ok(self) } fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { - self.sequence_start(); + self.tuple_start(); Ok(self) } fn serialize_tuple_variant( self, _name: &'static str, _idx: u32, _var: &'static str, _len: usize, ) -> Result { - self.sequence_start(); - Ok(self) + unimplemented!("serialize Tuple Variant not implemented") + } + + fn serialize_map(self, _len: Option) -> Result { + unimplemented!( + "serialize Map not implemented; did you mean to use database::Json() around your serde_json::Value?" + ) } fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { - self.sequence_start(); - Ok(self) + unimplemented!( + "serialize Struct not implemented at this time; did you mean to use database::Json() around your struct?" + ) } fn serialize_struct_variant( self, _name: &'static str, _idx: u32, _var: &'static str, _len: usize, ) -> Result { - self.sequence_start(); - Ok(self) + unimplemented!("serialize Struct Variant not implemented") } - fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result { - unimplemented!("serialize New Type Struct not implemented") + #[allow(clippy::needless_borrows_for_generic_args)] // buggy + fn serialize_newtype_struct(self, name: &'static str, value: &T) -> Result + where + T: Serialize + ?Sized, + { + debug_assert!( + name != "Json" || type_name::() != "alloc::boxed::Box", + "serializing a Json(RawValue); you can skip serialization instead" + ); + + match name { + "Json" => serde_json::to_writer(&mut self.out, value).map_err(Into::into), + _ => unimplemented!("Unrecognized serialization Newtype {name:?}"), + } } fn serialize_newtype_variant( self, _name: &'static str, _idx: u32, _var: &'static str, _value: &T, ) -> Result { - unimplemented!("serialize New Type Variant not 
implemented") + unimplemented!("serialize Newtype Variant not implemented") } fn serialize_unit_struct(self, name: &'static str) -> Result { @@ -180,136 +213,127 @@ impl ser::Serializer for &mut Serializer<'_, W> { self.serialize_str(v.encode_utf8(&mut buf)) } - fn serialize_str(self, v: &str) -> Result { self.serialize_bytes(v.as_bytes()) } + fn serialize_str(self, v: &str) -> Result { + debug_assert!( + self.depth > 0, + "serializing string at the top-level; you can skip serialization instead" + ); + + self.serialize_bytes(v.as_bytes()) + } + + fn serialize_bytes(self, v: &[u8]) -> Result { + debug_assert!( + self.depth > 0, + "serializing byte array at the top-level; you can skip serialization instead" + ); - fn serialize_bytes(self, v: &[u8]) -> Result { self.out.write_all(v).map_err(Error::Io) } + self.write(v) + } fn serialize_f64(self, _v: f64) -> Result { unimplemented!("serialize f64 not implemented") } fn serialize_f32(self, _v: f32) -> Result { unimplemented!("serialize f32 not implemented") } - fn serialize_i64(self, v: i64) -> Result { self.out.write_all(&v.to_be_bytes()).map_err(Error::Io) } + fn serialize_i64(self, v: i64) -> Result { self.write(&v.to_be_bytes()) } - fn serialize_i32(self, _v: i32) -> Result { unimplemented!("serialize i32 not implemented") } + fn serialize_i32(self, v: i32) -> Result { self.write(&v.to_be_bytes()) } fn serialize_i16(self, _v: i16) -> Result { unimplemented!("serialize i16 not implemented") } fn serialize_i8(self, _v: i8) -> Result { unimplemented!("serialize i8 not implemented") } - fn serialize_u64(self, v: u64) -> Result { self.out.write_all(&v.to_be_bytes()).map_err(Error::Io) } + fn serialize_u64(self, v: u64) -> Result { self.write(&v.to_be_bytes()) } - fn serialize_u32(self, _v: u32) -> Result { unimplemented!("serialize u32 not implemented") } + fn serialize_u32(self, v: u32) -> Result { self.write(&v.to_be_bytes()) } fn serialize_u16(self, _v: u16) -> Result { unimplemented!("serialize u16 not implemented") 
} - fn serialize_u8(self, v: u8) -> Result { self.out.write_all(&[v]).map_err(Error::Io) } + fn serialize_u8(self, v: u8) -> Result { self.write(&[v]) } fn serialize_bool(self, _v: bool) -> Result { unimplemented!("serialize bool not implemented") } fn serialize_unit(self) -> Result { unimplemented!("serialize unit not implemented") } } -impl ser::SerializeMap for &mut Serializer<'_, W> { +impl ser::SerializeSeq for &mut Serializer<'_, W> { type Error = Error; type Ok = (); - fn serialize_key(&mut self, _key: &T) -> Result { - unimplemented!("serialize Map Key not implemented") - } - - fn serialize_value(&mut self, _val: &T) -> Result { - unimplemented!("serialize Map Val not implemented") - } + fn serialize_element(&mut self, val: &T) -> Result { val.serialize(&mut **self) } - fn end(self) -> Result { - self.sequence_end(); - Ok(()) - } + fn end(self) -> Result { self.sequence_end() } } -impl ser::SerializeSeq for &mut Serializer<'_, W> { +impl ser::SerializeTuple for &mut Serializer<'_, W> { type Error = Error; type Ok = (); - fn serialize_element(&mut self, val: &T) -> Result { val.serialize(&mut **self) } - - fn end(self) -> Result { - self.sequence_end(); - Ok(()) + fn serialize_element(&mut self, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) } + + fn end(self) -> Result { self.tuple_end() } } -impl ser::SerializeStruct for &mut Serializer<'_, W> { +impl ser::SerializeTupleStruct for &mut Serializer<'_, W> { type Error = Error; type Ok = (); - fn serialize_field(&mut self, _key: &'static str, val: &T) -> Result { + fn serialize_field(&mut self, val: &T) -> Result { self.record_start()?; val.serialize(&mut **self) } - fn end(self) -> Result { - self.sequence_end(); - Ok(()) - } + fn end(self) -> Result { self.tuple_end() } } -impl ser::SerializeStructVariant for &mut Serializer<'_, W> { +impl ser::SerializeTupleVariant for &mut Serializer<'_, W> { type Error = Error; type Ok = (); - fn serialize_field(&mut self, _key: &'static str, 
val: &T) -> Result { + fn serialize_field(&mut self, val: &T) -> Result { self.record_start()?; val.serialize(&mut **self) } - fn end(self) -> Result { - self.sequence_end(); - Ok(()) - } + fn end(self) -> Result { self.tuple_end() } } -impl ser::SerializeTuple for &mut Serializer<'_, W> { +impl ser::SerializeMap for &mut Serializer<'_, W> { type Error = Error; type Ok = (); - fn serialize_element(&mut self, val: &T) -> Result { - self.record_start()?; - val.serialize(&mut **self) + fn serialize_key(&mut self, _key: &T) -> Result { + unimplemented!("serialize Map Key not implemented") } - fn end(self) -> Result { - self.sequence_end(); - Ok(()) + fn serialize_value(&mut self, _val: &T) -> Result { + unimplemented!("serialize Map Val not implemented") } + + fn end(self) -> Result { unimplemented!("serialize Map End not implemented") } } -impl ser::SerializeTupleStruct for &mut Serializer<'_, W> { +impl ser::SerializeStruct for &mut Serializer<'_, W> { type Error = Error; type Ok = (); - fn serialize_field(&mut self, val: &T) -> Result { - self.record_start()?; - val.serialize(&mut **self) + fn serialize_field(&mut self, _key: &'static str, _val: &T) -> Result { + unimplemented!("serialize Struct Field not implemented") } - fn end(self) -> Result { - self.sequence_end(); - Ok(()) - } + fn end(self) -> Result { unimplemented!("serialize Struct End not implemented") } } -impl ser::SerializeTupleVariant for &mut Serializer<'_, W> { +impl ser::SerializeStructVariant for &mut Serializer<'_, W> { type Error = Error; type Ok = (); - fn serialize_field(&mut self, val: &T) -> Result { - self.record_start()?; - val.serialize(&mut **self) + fn serialize_field(&mut self, _key: &'static str, _val: &T) -> Result { + unimplemented!("serialize Struct Variant Field not implemented") } - fn end(self) -> Result { - self.sequence_end(); - Ok(()) - } + fn end(self) -> Result { unimplemented!("serialize Struct Variant End not implemented") } } diff --git a/src/database/tests.rs 
b/src/database/tests.rs new file mode 100644 index 000000000..47dfb32c3 --- /dev/null +++ b/src/database/tests.rs @@ -0,0 +1,232 @@ +#![cfg(test)] +#![allow(clippy::needless_borrows_for_generic_args)] + +use std::fmt::Debug; + +use arrayvec::ArrayVec; +use conduit::ruma::{serde::Raw, RoomId, UserId}; +use serde::Serialize; + +use crate::{ + de, ser, + ser::{serialize_to_vec, Json}, + Interfix, +}; + +#[test] +#[should_panic(expected = "serializing string at the top-level")] +fn ser_str() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let s = serialize_to_vec(&user_id).expect("failed to serialize user_id"); + assert_eq!(&s, user_id.as_bytes()); +} + +#[test] +fn ser_tuple() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let room_id: &RoomId = "!room:example.com".try_into().unwrap(); + + let mut a = user_id.as_bytes().to_vec(); + a.push(0xFF); + a.extend_from_slice(room_id.as_bytes()); + + let b = (user_id, room_id); + let b = serialize_to_vec(&b).expect("failed to serialize tuple"); + + assert_eq!(a, b); +} + +#[test] +#[should_panic(expected = "I/O error: failed to write whole buffer")] +fn ser_overflow() { + const BUFSIZE: usize = 10; + + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let room_id: &RoomId = "!room:example.com".try_into().unwrap(); + + assert!(BUFSIZE < user_id.as_str().len() + room_id.as_str().len()); + let mut buf = ArrayVec::::new(); + + let val = (user_id, room_id); + _ = ser::serialize(&mut buf, val).unwrap(); +} + +#[test] +fn ser_complex() { + use conduit::ruma::Mxc; + + #[derive(Debug, Serialize)] + struct Dim { + width: u32, + height: u32, + } + + let mxc = Mxc { + server_name: "example.com".try_into().unwrap(), + media_id: "AbCdEfGhIjK", + }; + + let dim = Dim { + width: 123, + height: 456, + }; + + let mut a = Vec::new(); + a.extend_from_slice(b"mxc://"); + a.extend_from_slice(mxc.server_name.as_bytes()); + a.extend_from_slice(b"/"); + 
a.extend_from_slice(mxc.media_id.as_bytes()); + a.push(0xFF); + a.extend_from_slice(&dim.width.to_be_bytes()); + a.extend_from_slice(&dim.height.to_be_bytes()); + a.push(0xFF); + + let d: &[u32] = &[dim.width, dim.height]; + let b = (mxc, d, Interfix); + let b = serialize_to_vec(b).expect("failed to serialize complex"); + + assert_eq!(a, b); +} + +#[test] +fn ser_json() { + use conduit::ruma::api::client::filter::FilterDefinition; + + let filter = FilterDefinition { + event_fields: Some(vec!["content.body".to_owned()]), + ..Default::default() + }; + + let serialized = serialize_to_vec(Json(&filter)).expect("failed to serialize value"); + + let s = String::from_utf8_lossy(&serialized); + assert_eq!(&s, r#"{"event_fields":["content.body"]}"#); +} + +#[test] +fn ser_json_value() { + use conduit::ruma::api::client::filter::FilterDefinition; + + let filter = FilterDefinition { + event_fields: Some(vec!["content.body".to_owned()]), + ..Default::default() + }; + + let value = serde_json::to_value(filter).expect("failed to serialize to serde_json::value"); + let serialized = serialize_to_vec(Json(value)).expect("failed to serialize value"); + + let s = String::from_utf8_lossy(&serialized); + assert_eq!(&s, r#"{"event_fields":["content.body"]}"#); +} + +#[test] +fn ser_json_macro() { + use serde_json::json; + + #[derive(Serialize)] + struct Foo { + foo: String, + } + + let content = Foo { + foo: "bar".to_owned(), + }; + let content = serde_json::to_value(content).expect("failed to serialize content"); + let sender: &UserId = "@foo:example.com".try_into().unwrap(); + let serialized = serialize_to_vec(Json(json!({ + "sender": sender, + "content": content, + }))) + .expect("failed to serialize value"); + + let s = String::from_utf8_lossy(&serialized); + assert_eq!(&s, r#"{"content":{"foo":"bar"},"sender":"@foo:example.com"}"#); +} + +#[test] +#[should_panic(expected = "serializing string at the top-level")] +fn ser_json_raw() { + use 
conduit::ruma::api::client::filter::FilterDefinition; + + let filter = FilterDefinition { + event_fields: Some(vec!["content.body".to_owned()]), + ..Default::default() + }; + + let value = serde_json::value::to_raw_value(&filter).expect("failed to serialize to raw value"); + let a = serialize_to_vec(value.get()).expect("failed to serialize raw value"); + let s = String::from_utf8_lossy(&a); + assert_eq!(&s, r#"{"event_fields":["content.body"]}"#); +} + +#[test] +#[should_panic(expected = "you can skip serialization instead")] +fn ser_json_raw_json() { + use conduit::ruma::api::client::filter::FilterDefinition; + + let filter = FilterDefinition { + event_fields: Some(vec!["content.body".to_owned()]), + ..Default::default() + }; + + let value = serde_json::value::to_raw_value(&filter).expect("failed to serialize to raw value"); + let a = serialize_to_vec(Json(value)).expect("failed to serialize json value"); + let s = String::from_utf8_lossy(&a); + assert_eq!(&s, r#"{"event_fields":["content.body"]}"#); +} + +#[test] +fn de_tuple() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let room_id: &RoomId = "!room:example.com".try_into().unwrap(); + + let raw: &[u8] = b"@user:example.com\xFF!room:example.com"; + let (a, b): (&UserId, &RoomId) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); + assert_eq!(b, room_id, "deserialized room_id does not match"); +} + +#[test] +fn de_json_array() { + let a = &["foo", "bar", "baz"]; + let s = serde_json::to_vec(a).expect("failed to serialize to JSON array"); + + let b: Raw>> = de::from_slice(&s).expect("failed to deserialize"); + + let d: Vec = serde_json::from_str(b.json().get()).expect("failed to deserialize JSON"); + + for (i, a) in a.iter().enumerate() { + assert_eq!(*a, d[i]); + } +} + +#[test] +fn de_json_raw_array() { + let a = &["foo", "bar", "baz"]; + let s = serde_json::to_vec(a).expect("failed to serialize to JSON array"); + + 
let b: Raw>> = de::from_slice(&s).expect("failed to deserialize"); + + let c: Vec> = serde_json::from_str(b.json().get()).expect("failed to deserialize JSON"); + + for (i, a) in a.iter().enumerate() { + let c = serde_json::to_value(c[i].json()).expect("failed to deserialize JSON to string"); + assert_eq!(*a, c); + } +} + +#[test] +fn ser_array() { + let a: u64 = 123_456; + let b: u64 = 987_654; + + let arr: &[u64] = &[a, b]; + + let mut v = Vec::new(); + v.extend_from_slice(&a.to_be_bytes()); + v.extend_from_slice(&b.to_be_bytes()); + + let s = serialize_to_vec(arr).expect("failed to serialize"); + assert_eq!(&s, &v, "serialization does not match"); +} From d3d11356ee59858dcf26fa66ad2b6c9c4ac13a61 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 6 Oct 2024 22:15:19 +0000 Subject: [PATCH 076/245] add serialized insert interface Signed-off-by: Jason Volk --- src/database/map/insert.rs | 187 +++++++++++++++++++++++++++++++++++-- src/database/map/remove.rs | 12 +-- 2 files changed, 186 insertions(+), 13 deletions(-) diff --git a/src/database/map/insert.rs b/src/database/map/insert.rs index 953c9c94c..39a0c422e 100644 --- a/src/database/map/insert.rs +++ b/src/database/map/insert.rs @@ -1,21 +1,194 @@ -use std::{convert::AsRef, fmt::Debug}; +//! Insert a Key+Value into the database. +//! +//! Overloads are provided for the user to choose the most efficient +//! serialization or bypass for pre=serialized (raw) inputs. 
+use std::{convert::AsRef, fmt::Debug, io::Write}; + +use arrayvec::ArrayVec; use conduit::implement; use rocksdb::WriteBatchWithTransaction; +use serde::Serialize; + +use crate::{ser, util::or_else}; + +/// Insert Key/Value +/// +/// - Key is serialized +/// - Val is serialized +#[implement(super::Map)] +pub fn put(&self, key: K, val: V) +where + K: Serialize + Debug, + V: Serialize, +{ + let mut key_buf = Vec::new(); + let mut val_buf = Vec::new(); + self.bput(key, val, (&mut key_buf, &mut val_buf)); +} + +/// Insert Key/Value +/// +/// - Key is serialized +/// - Val is raw +#[implement(super::Map)] +pub fn put_raw(&self, key: K, val: V) +where + K: Serialize + Debug, + V: AsRef<[u8]>, +{ + let mut key_buf = Vec::new(); + self.bput_raw(key, val, &mut key_buf); +} + +/// Insert Key/Value +/// +/// - Key is raw +/// - Val is serialized +#[implement(super::Map)] +pub fn raw_put(&self, key: K, val: V) +where + K: AsRef<[u8]>, + V: Serialize, +{ + let mut val_buf = Vec::new(); + self.raw_bput(key, val, &mut val_buf); +} + +/// Insert Key/Value +/// +/// - Key is serialized +/// - Val is serialized to stack-buffer +#[implement(super::Map)] +pub fn put_aput(&self, key: K, val: V) +where + K: Serialize + Debug, + V: Serialize, +{ + let mut key_buf = Vec::new(); + let mut val_buf = ArrayVec::::new(); + self.bput(key, val, (&mut key_buf, &mut val_buf)); +} + +/// Insert Key/Value +/// +/// - Key is serialized to stack-buffer +/// - Val is serialized +#[implement(super::Map)] +pub fn aput_put(&self, key: K, val: V) +where + K: Serialize + Debug, + V: Serialize, +{ + let mut key_buf = ArrayVec::::new(); + let mut val_buf = Vec::new(); + self.bput(key, val, (&mut key_buf, &mut val_buf)); +} + +/// Insert Key/Value +/// +/// - Key is serialized to stack-buffer +/// - Val is serialized to stack-buffer +#[implement(super::Map)] +pub fn aput(&self, key: K, val: V) +where + K: Serialize + Debug, + V: Serialize, +{ + let mut key_buf = ArrayVec::::new(); + let mut val_buf = 
ArrayVec::::new(); + self.bput(key, val, (&mut key_buf, &mut val_buf)); +} + +/// Insert Key/Value +/// +/// - Key is serialized to stack-buffer +/// - Val is raw +#[implement(super::Map)] +pub fn aput_raw(&self, key: K, val: V) +where + K: Serialize + Debug, + V: AsRef<[u8]>, +{ + let mut key_buf = ArrayVec::::new(); + self.bput_raw(key, val, &mut key_buf); +} + +/// Insert Key/Value +/// +/// - Key is raw +/// - Val is serialized to stack-buffer +#[implement(super::Map)] +pub fn raw_aput(&self, key: K, val: V) +where + K: AsRef<[u8]>, + V: Serialize, +{ + let mut val_buf = ArrayVec::::new(); + self.raw_bput(key, val, &mut val_buf); +} -use crate::util::or_else; +/// Insert Key/Value +/// +/// - Key is serialized to supplied buffer +/// - Val is serialized to supplied buffer +#[implement(super::Map)] +pub fn bput(&self, key: K, val: V, mut buf: (Bk, Bv)) +where + K: Serialize + Debug, + V: Serialize, + Bk: Write + AsRef<[u8]>, + Bv: Write + AsRef<[u8]>, +{ + let val = ser::serialize(&mut buf.1, val).expect("failed to serialize insertion val"); + self.bput_raw(key, val, &mut buf.0); +} + +/// Insert Key/Value +/// +/// - Key is serialized to supplied buffer +/// - Val is raw +#[implement(super::Map)] +pub fn bput_raw(&self, key: K, val: V, mut buf: Bk) +where + K: Serialize + Debug, + V: AsRef<[u8]>, + Bk: Write + AsRef<[u8]>, +{ + let key = ser::serialize(&mut buf, key).expect("failed to serialize insertion key"); + self.insert(&key, val); +} + +/// Insert Key/Value +/// +/// - Key is raw +/// - Val is serialized to supplied buffer +#[implement(super::Map)] +pub fn raw_bput(&self, key: K, val: V, mut buf: Bv) +where + K: AsRef<[u8]>, + V: Serialize, + Bv: Write + AsRef<[u8]>, +{ + let val = ser::serialize(&mut buf, val).expect("failed to serialize insertion val"); + self.insert(&key, val); +} +/// Insert Key/Value +/// +/// - Key is raw +/// - Val is raw #[implement(super::Map)] -#[tracing::instrument(skip(self, value), fields(%self), level = "trace")] -pub fn 
insert(&self, key: &K, value: &V) +#[tracing::instrument(skip_all, fields(%self), level = "trace")] +pub fn insert(&self, key: &K, val: V) where - K: AsRef<[u8]> + ?Sized + Debug, - V: AsRef<[u8]> + ?Sized, + K: AsRef<[u8]> + ?Sized, + V: AsRef<[u8]>, { let write_options = &self.write_options; self.db .db - .put_cf_opt(&self.cf(), key, value, write_options) + .put_cf_opt(&self.cf(), key, val, write_options) .or_else(or_else) .expect("database insert error"); diff --git a/src/database/map/remove.rs b/src/database/map/remove.rs index 10bb2ff01..42eaa477d 100644 --- a/src/database/map/remove.rs +++ b/src/database/map/remove.rs @@ -7,18 +7,18 @@ use serde::Serialize; use crate::{ser, util::or_else}; #[implement(super::Map)] -pub fn del(&self, key: &K) +pub fn del(&self, key: K) where - K: Serialize + ?Sized + Debug, + K: Serialize + Debug, { let mut buf = Vec::::with_capacity(64); self.bdel(key, &mut buf); } #[implement(super::Map)] -pub fn adel(&self, key: &K) +pub fn adel(&self, key: K) where - K: Serialize + ?Sized + Debug, + K: Serialize + Debug, { let mut buf = ArrayVec::::new(); self.bdel(key, &mut buf); @@ -26,9 +26,9 @@ where #[implement(super::Map)] #[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] -pub fn bdel(&self, key: &K, buf: &mut B) +pub fn bdel(&self, key: K, buf: &mut B) where - K: Serialize + ?Sized + Debug, + K: Serialize + Debug, B: Write + AsRef<[u8]>, { let key = ser::serialize(buf, key).expect("failed to serialize deletion key"); From 19880ce12bf3bf79bcfa8cb21223de48ab268686 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 9 Oct 2024 19:41:52 +0000 Subject: [PATCH 077/245] add IgnoreAll directive to deserializer Signed-off-by: Jason Volk --- src/database/de.rs | 27 +++++++++++++++++++++++++-- src/database/mod.rs | 2 +- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/database/de.rs b/src/database/de.rs index 9ee52267e..e5fdf7cb2 100644 --- a/src/database/de.rs +++ b/src/database/de.rs @@ -5,6 +5,7 @@ 
use serde::{ Deserialize, }; +/// Deserialize into T from buffer. pub(crate) fn from_slice<'a, T>(buf: &'a [u8]) -> Result where T: Deserialize<'a>, @@ -22,6 +23,7 @@ where }) } +/// Deserialization state. pub(crate) struct Deserializer<'de> { buf: &'de [u8], pos: usize, @@ -33,6 +35,11 @@ pub(crate) struct Deserializer<'de> { #[derive(Debug, Deserialize)] pub struct Ignore; +/// Directive to ignore all remaining records. This can be used in a sequence to +/// ignore the rest of the sequence. +#[derive(Debug, Deserialize)] +pub struct IgnoreAll; + impl<'de> Deserializer<'de> { /// Record separator; an intentionally invalid-utf8 byte. const SEP: u8 = b'\xFF'; @@ -53,6 +60,13 @@ impl<'de> Deserializer<'de> { ))) } + /// Called at the start of arrays and tuples + #[inline] + fn sequence_start(&mut self) { + debug_assert!(!self.seq, "Nested sequences are not handled at this time"); + self.seq = true; + } + /// Consume the current record to ignore it. Inside a sequence the next /// record is skipped but at the top-level all records are skipped such that /// deserialization completes with self.finished() == Ok. @@ -61,10 +75,16 @@ impl<'de> Deserializer<'de> { if self.seq { self.record_next(); } else { - self.record_trail(); + self.record_ignore_all(); } } + /// Consume the current and all remaining records to ignore them. Similar to + /// Ignore at the top-level, but it can be provided in a sequence to Ignore + /// all remaining elements. + #[inline] + fn record_ignore_all(&mut self) { self.record_trail(); } + /// Consume the current record. The position pointer is moved to the start /// of the next record. Slice of the current record is returned. 
#[inline] @@ -101,7 +121,6 @@ impl<'de> Deserializer<'de> { ); self.inc_pos(started.into()); - self.seq = true; } /// Consume all remaining bytes, which may include record separators, @@ -128,6 +147,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { + self.sequence_start(); visitor.visit_seq(self) } @@ -135,6 +155,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { + self.sequence_start(); visitor.visit_seq(self) } @@ -142,6 +163,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { + self.sequence_start(); visitor.visit_seq(self) } @@ -170,6 +192,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { { match name { "Ignore" => self.record_ignore(), + "IgnoreAll" => self.record_ignore_all(), _ => unimplemented!("Unrecognized deserialization Directive {name:?}"), }; diff --git a/src/database/mod.rs b/src/database/mod.rs index c39b2b2f2..6d3b2079b 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -24,7 +24,7 @@ extern crate rust_rocksdb as rocksdb; pub use self::{ database::Database, - de::Ignore, + de::{Ignore, IgnoreAll}, deserialized::Deserialized, handle::Handle, keyval::{KeyVal, Slice}, From 8258d16a94855dae3df68ae8dcdea1bde0601f4e Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 9 Oct 2024 05:08:22 +0000 Subject: [PATCH 078/245] re-scheme naming of stream iterator overloads Signed-off-by: Jason Volk --- src/database/map/count.rs | 8 +++----- src/database/map/keys_from.rs | 6 +++--- src/database/map/keys_prefix.rs | 6 +++--- src/database/map/rev_keys_from.rs | 6 +++--- src/database/map/rev_keys_prefix.rs | 6 +++--- src/database/map/rev_stream_from.rs | 6 +++--- src/database/map/rev_stream_prefix.rs | 6 +++--- src/database/map/stream_from.rs | 6 +++--- src/database/map/stream_prefix.rs | 6 +++--- src/service/key_backups/mod.rs | 10 +++++----- src/service/rooms/alias/mod.rs | 2 +- 
src/service/rooms/lazy_loading/mod.rs | 2 +- src/service/rooms/metadata/mod.rs | 2 +- src/service/rooms/read_receipt/data.rs | 2 +- src/service/rooms/state/data.rs | 2 +- src/service/rooms/state_cache/mod.rs | 4 ++-- src/service/sending/data.rs | 4 ++-- src/service/users/mod.rs | 2 +- 18 files changed, 42 insertions(+), 44 deletions(-) diff --git a/src/database/map/count.rs b/src/database/map/count.rs index 4356b71f5..dab45b7a9 100644 --- a/src/database/map/count.rs +++ b/src/database/map/count.rs @@ -4,12 +4,10 @@ use conduit::implement; use futures::stream::StreamExt; use serde::Serialize; -use crate::de::Ignore; - /// Count the total number of entries in the map. #[implement(super::Map)] #[inline] -pub fn count(&self) -> impl Future + Send + '_ { self.keys::().count() } +pub fn count(&self) -> impl Future + Send + '_ { self.raw_keys().count() } /// Count the number of entries in the map starting from a lower-bound. /// @@ -20,7 +18,7 @@ pub fn count_from<'a, P>(&'a self, from: &P) -> impl Future + Se where P: Serialize + ?Sized + Debug + 'a, { - self.keys_from::(from).count() + self.keys_from_raw(from).count() } /// Count the number of entries in the map matching a prefix. @@ -32,5 +30,5 @@ pub fn count_prefix<'a, P>(&'a self, prefix: &P) -> impl Future where P: Serialize + ?Sized + Debug + 'a, { - self.keys_prefix::(prefix).count() + self.keys_prefix_raw(prefix).count() } diff --git a/src/database/map/keys_from.rs b/src/database/map/keys_from.rs index 1993750ab..4eb3b12e5 100644 --- a/src/database/map/keys_from.rs +++ b/src/database/map/keys_from.rs @@ -13,13 +13,13 @@ where P: Serialize + ?Sized + Debug, K: Deserialize<'a> + Send, { - self.keys_raw_from(from) + self.keys_from_raw(from) .map(keyval::result_deserialize_key::) } #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn keys_raw_from

    (&self, from: &P) -> impl Stream>> + Send +pub fn keys_from_raw

    (&self, from: &P) -> impl Stream>> + Send where P: Serialize + ?Sized + Debug, { @@ -29,7 +29,7 @@ where #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn keys_from_raw<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send +pub fn keys_raw_from<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send where P: AsRef<[u8]> + ?Sized + Debug + Sync, K: Deserialize<'a> + Send, diff --git a/src/database/map/keys_prefix.rs b/src/database/map/keys_prefix.rs index d6c0927b9..0ff755f35 100644 --- a/src/database/map/keys_prefix.rs +++ b/src/database/map/keys_prefix.rs @@ -17,13 +17,13 @@ where P: Serialize + ?Sized + Debug, K: Deserialize<'a> + Send, { - self.keys_raw_prefix(prefix) + self.keys_prefix_raw(prefix) .map(keyval::result_deserialize_key::) } #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn keys_raw_prefix

    (&self, prefix: &P) -> impl Stream>> + Send +pub fn keys_prefix_raw

    (&self, prefix: &P) -> impl Stream>> + Send where P: Serialize + ?Sized + Debug, { @@ -34,7 +34,7 @@ where #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn keys_prefix_raw<'a, K, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +pub fn keys_raw_prefix<'a, K, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a where P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, K: Deserialize<'a> + Send + 'a, diff --git a/src/database/map/rev_keys_from.rs b/src/database/map/rev_keys_from.rs index e012e60af..b142718ce 100644 --- a/src/database/map/rev_keys_from.rs +++ b/src/database/map/rev_keys_from.rs @@ -13,13 +13,13 @@ where P: Serialize + ?Sized + Debug, K: Deserialize<'a> + Send, { - self.rev_keys_raw_from(from) + self.rev_keys_from_raw(from) .map(keyval::result_deserialize_key::) } #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn rev_keys_raw_from

    (&self, from: &P) -> impl Stream>> + Send +pub fn rev_keys_from_raw

    (&self, from: &P) -> impl Stream>> + Send where P: Serialize + ?Sized + Debug, { @@ -29,7 +29,7 @@ where #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn rev_keys_from_raw<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send +pub fn rev_keys_raw_from<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send where P: AsRef<[u8]> + ?Sized + Debug + Sync, K: Deserialize<'a> + Send, diff --git a/src/database/map/rev_keys_prefix.rs b/src/database/map/rev_keys_prefix.rs index 162c4f9b8..5297cecf9 100644 --- a/src/database/map/rev_keys_prefix.rs +++ b/src/database/map/rev_keys_prefix.rs @@ -17,13 +17,13 @@ where P: Serialize + ?Sized + Debug, K: Deserialize<'a> + Send, { - self.rev_keys_raw_prefix(prefix) + self.rev_keys_prefix_raw(prefix) .map(keyval::result_deserialize_key::) } #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn rev_keys_raw_prefix

    (&self, prefix: &P) -> impl Stream>> + Send +pub fn rev_keys_prefix_raw

    (&self, prefix: &P) -> impl Stream>> + Send where P: Serialize + ?Sized + Debug, { @@ -34,7 +34,7 @@ where #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn rev_keys_prefix_raw<'a, K, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +pub fn rev_keys_raw_prefix<'a, K, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a where P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, K: Deserialize<'a> + Send + 'a, diff --git a/src/database/map/rev_stream_from.rs b/src/database/map/rev_stream_from.rs index c48f406b2..78318a7fe 100644 --- a/src/database/map/rev_stream_from.rs +++ b/src/database/map/rev_stream_from.rs @@ -18,7 +18,7 @@ where K: Deserialize<'a> + Send, V: Deserialize<'a> + Send, { - self.rev_stream_raw_from(from) + self.rev_stream_from_raw(from) .map(keyval::result_deserialize::) } @@ -28,7 +28,7 @@ where /// - Result is raw #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn rev_stream_raw_from

    (&self, from: &P) -> impl Stream>> + Send +pub fn rev_stream_from_raw

    (&self, from: &P) -> impl Stream>> + Send where P: Serialize + ?Sized + Debug, { @@ -42,7 +42,7 @@ where /// - Result is deserialized #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn rev_stream_from_raw<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send +pub fn rev_stream_raw_from<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send where P: AsRef<[u8]> + ?Sized + Debug + Sync, K: Deserialize<'a> + Send, diff --git a/src/database/map/rev_stream_prefix.rs b/src/database/map/rev_stream_prefix.rs index 9ef89e9cb..601c3298c 100644 --- a/src/database/map/rev_stream_prefix.rs +++ b/src/database/map/rev_stream_prefix.rs @@ -22,7 +22,7 @@ where K: Deserialize<'a> + Send, V: Deserialize<'a> + Send, { - self.rev_stream_raw_prefix(prefix) + self.rev_stream_prefix_raw(prefix) .map(keyval::result_deserialize::) } @@ -32,7 +32,7 @@ where /// - Result is raw #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn rev_stream_raw_prefix

    (&self, prefix: &P) -> impl Stream>> + Send +pub fn rev_stream_prefix_raw

    (&self, prefix: &P) -> impl Stream>> + Send where P: Serialize + ?Sized + Debug, { @@ -47,7 +47,7 @@ where /// - Result is deserialized #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn rev_stream_prefix_raw<'a, K, V, P>( +pub fn rev_stream_raw_prefix<'a, K, V, P>( &'a self, prefix: &'a P, ) -> impl Stream>> + Send + 'a where diff --git a/src/database/map/stream_from.rs b/src/database/map/stream_from.rs index db8281250..0d3bb1e10 100644 --- a/src/database/map/stream_from.rs +++ b/src/database/map/stream_from.rs @@ -18,7 +18,7 @@ where K: Deserialize<'a> + Send, V: Deserialize<'a> + Send, { - self.stream_raw_from(from) + self.stream_from_raw(from) .map(keyval::result_deserialize::) } @@ -28,7 +28,7 @@ where /// - Result is raw #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn stream_raw_from

    (&self, from: &P) -> impl Stream>> + Send +pub fn stream_from_raw

    (&self, from: &P) -> impl Stream>> + Send where P: Serialize + ?Sized + Debug, { @@ -42,7 +42,7 @@ where /// - Result is deserialized #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn stream_from_raw<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send +pub fn stream_raw_from<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send where P: AsRef<[u8]> + ?Sized + Debug + Sync, K: Deserialize<'a> + Send, diff --git a/src/database/map/stream_prefix.rs b/src/database/map/stream_prefix.rs index 56154a8b3..cab3dd098 100644 --- a/src/database/map/stream_prefix.rs +++ b/src/database/map/stream_prefix.rs @@ -22,7 +22,7 @@ where K: Deserialize<'a> + Send, V: Deserialize<'a> + Send, { - self.stream_raw_prefix(prefix) + self.stream_prefix_raw(prefix) .map(keyval::result_deserialize::) } @@ -32,7 +32,7 @@ where /// - Result is raw #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn stream_raw_prefix

    (&self, prefix: &P) -> impl Stream>> + Send +pub fn stream_prefix_raw

    (&self, prefix: &P) -> impl Stream>> + Send where P: Serialize + ?Sized + Debug, { @@ -47,7 +47,7 @@ where /// - Result is deserialized #[implement(super::Map)] #[tracing::instrument(skip(self), fields(%self), level = "trace")] -pub fn stream_prefix_raw<'a, K, V, P>( +pub fn stream_raw_prefix<'a, K, V, P>( &'a self, prefix: &'a P, ) -> impl Stream>> + Send + 'a where diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 55263eeb1..4c3037571 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -79,7 +79,7 @@ pub async fn delete_backup(&self, user_id: &UserId, version: &str) { let key = (user_id, version, Interfix); self.db .backupkeyid_backup - .keys_raw_prefix(&key) + .keys_prefix_raw(&key) .ignore_err() .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) .await; @@ -181,7 +181,7 @@ pub async fn count_keys(&self, user_id: &UserId, version: &str) -> usize { let prefix = (user_id, version); self.db .backupkeyid_backup - .keys_raw_prefix(&prefix) + .keys_prefix_raw(&prefix) .count() .await } @@ -256,7 +256,7 @@ pub async fn delete_all_keys(&self, user_id: &UserId, version: &str) { let key = (user_id, version, Interfix); self.db .backupkeyid_backup - .keys_raw_prefix(&key) + .keys_prefix_raw(&key) .ignore_err() .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) .await; @@ -267,7 +267,7 @@ pub async fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: & let key = (user_id, version, room_id, Interfix); self.db .backupkeyid_backup - .keys_raw_prefix(&key) + .keys_prefix_raw(&key) .ignore_err() .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) .await; @@ -278,7 +278,7 @@ pub async fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &R let key = (user_id, version, room_id, session_id); self.db .backupkeyid_backup - .keys_raw_prefix(&key) + .keys_prefix_raw(&key) .ignore_err() 
.ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) .await; diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 7fac6be69..3f944729e 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -101,7 +101,7 @@ impl Service { let prefix = (&room_id, Interfix); self.db .aliasid_alias - .keys_raw_prefix(&prefix) + .keys_prefix_raw(&prefix) .ignore_err() .ready_for_each(|key| self.db.aliasid_alias.remove(key)) .await; diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index e0816d3f3..9493dcc49 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -99,7 +99,7 @@ pub async fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room let prefix = (user_id, device_id, room_id, Interfix); self.db .lazyloadedids - .keys_raw_prefix(&prefix) + .keys_prefix_raw(&prefix) .ignore_err() .ready_for_each(|key| self.db.lazyloadedids.remove(key)) .await; diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index d8be6aab6..8367eb72d 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -50,7 +50,7 @@ pub async fn exists(&self, room_id: &RoomId) -> bool { // Look for PDUs in that room. 
self.db .pduid_pdu - .keys_raw_prefix(&prefix) + .keys_prefix_raw(&prefix) .ignore_err() .next() .await diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index a2c0fabca..74b649ef3 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -84,7 +84,7 @@ impl Data { let prefix2 = prefix.clone(); self.readreceiptid_readreceipt - .stream_raw_from(&first_possible_edu) + .stream_from_raw(&first_possible_edu) .ignore_err() .ready_take_while(move |(k, _)| k.starts_with(&prefix2)) .map(move |(k, v)| { diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index 3072e3c65..7265038fd 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -52,7 +52,7 @@ impl Data { ) { let prefix = (room_id, Interfix); self.roomid_pduleaves - .keys_raw_prefix(&prefix) + .keys_prefix_raw(&prefix) .ignore_err() .ready_for_each(|key| self.roomid_pduleaves.remove(key)) .await; diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 8539c9402..edfae5291 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -408,7 +408,7 @@ impl Service { pub fn rooms_joined<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { self.db .userroomid_joined - .keys_prefix_raw(user_id) + .keys_raw_prefix(user_id) .ignore_err() .map(|(_, room_id): (Ignore, &RoomId)| room_id) } @@ -469,7 +469,7 @@ impl Service { self.db .roomid_inviteviaservers - .stream_prefix_raw(room_id) + .stream_raw_prefix(room_id) .ignore_err() .map(|(_, servers): KeyVal<'_>| *servers.last().expect("at least one server")) } diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 6f4b5b970..96d4a6a91 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -101,7 +101,7 @@ impl Data { pub fn active_requests_for(&self, destination: &Destination) -> impl Stream + Send + '_ { let 
prefix = destination.get_prefix(); self.servercurrentevent_data - .stream_raw_prefix(&prefix) + .stream_prefix_raw(&prefix) .ignore_err() .map(|(key, val)| { let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent"); @@ -136,7 +136,7 @@ impl Data { pub fn queued_requests(&self, destination: &Destination) -> impl Stream + Send + '_ { let prefix = destination.get_prefix(); self.servernameevent_data - .stream_raw_prefix(&prefix) + .stream_prefix_raw(&prefix) .ignore_err() .map(|(key, val)| { let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent"); diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 71a93666f..a99a7df4b 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -311,7 +311,7 @@ impl Service { let prefix = (user_id, device_id, Interfix); self.db .todeviceid_events - .keys_raw_prefix(&prefix) + .keys_prefix_raw(&prefix) .ignore_err() .ready_for_each(|key| self.db.todeviceid_events.remove(key)) .await; From 2ed0c267eb698c33befc4daa482811f0ae45707a Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 7 Oct 2024 17:54:27 +0000 Subject: [PATCH 079/245] Refactor for structured insertions Signed-off-by: Jason Volk --- Cargo.lock | 26 +-- Cargo.toml | 2 +- src/api/client/sync.rs | 5 +- src/api/server/invite.rs | 6 - src/service/account_data/mod.rs | 40 +--- src/service/globals/data.rs | 32 ++- src/service/globals/migrations.rs | 17 +- src/service/key_backups/mod.rs | 62 ++---- src/service/media/data.rs | 103 +++------ src/service/media/migrations.rs | 2 +- src/service/presence/data.rs | 18 +- src/service/presence/presence.rs | 4 - src/service/pusher/mod.rs | 16 +- src/service/rooms/directory/mod.rs | 4 +- src/service/rooms/lazy_loading/mod.rs | 12 +- src/service/rooms/metadata/mod.rs | 8 +- src/service/rooms/outlier/mod.rs | 7 +- src/service/rooms/pdu_metadata/data.rs | 16 +- src/service/rooms/read_receipt/data.rs | 40 +--- src/service/rooms/short/mod.rs | 47 ++-- 
src/service/rooms/state/data.rs | 13 +- src/service/rooms/state_accessor/mod.rs | 6 + src/service/rooms/state_cache/data.rs | 73 +++---- src/service/rooms/state_cache/mod.rs | 89 ++++---- src/service/rooms/timeline/data.rs | 18 +- src/service/rooms/user/data.rs | 26 +-- src/service/sending/data.rs | 3 +- src/service/sending/dest.rs | 2 +- src/service/uiaa/mod.rs | 17 +- src/service/updates/mod.rs | 5 +- src/service/users/mod.rs | 274 +++++++++--------------- 31 files changed, 368 insertions(+), 625 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cae6994c6..db1394ce6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2976,7 +2976,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" +source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" dependencies = [ "assign", "js_int", @@ -2998,7 +2998,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" +source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" dependencies = [ "js_int", "ruma-common", @@ -3010,7 +3010,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" +source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" dependencies = [ "as_variant", "assign", @@ -3033,7 +3033,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = 
"git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" +source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" dependencies = [ "as_variant", "base64 0.22.1", @@ -3063,7 +3063,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" +source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" dependencies = [ "as_variant", "indexmap 2.6.0", @@ -3087,7 +3087,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" +source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" dependencies = [ "bytes", "http", @@ -3105,7 +3105,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" +source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" dependencies = [ "js_int", "thiserror", @@ -3114,7 +3114,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" +source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" dependencies = [ "js_int", 
"ruma-common", @@ -3124,7 +3124,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" +source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" dependencies = [ "cfg-if", "once_cell", @@ -3140,7 +3140,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" +source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" dependencies = [ "js_int", "ruma-common", @@ -3152,7 +3152,7 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.3.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" +source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" dependencies = [ "headers", "http", @@ -3165,7 +3165,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" +source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -3181,7 +3181,7 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.11.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=f485a0265c67a59df75fc6686787538172fa4cac#f485a0265c67a59df75fc6686787538172fa4cac" +source = 
"git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" dependencies = [ "futures-util", "itertools 0.13.0", diff --git a/Cargo.toml b/Cargo.toml index 25d1001da..0a98befd8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -315,7 +315,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = "https://github.com/girlbossceo/ruwuma" #branch = "conduwuit-changes" -rev = "f485a0265c67a59df75fc6686787538172fa4cac" +rev = "3109496a1f91357c89cbb57cf86f179e2cb013e7" features = [ "compat", "rand", diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index 8c4c6a445..65af775d4 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -7,7 +7,7 @@ use std::{ use axum::extract::State; use conduit::{ debug, err, error, is_equal_to, - result::{FlatOk, IntoIsOk}, + result::FlatOk, utils::{ math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, BoolExt, IterStream, ReadyExt, TryFutureExtExt, @@ -1136,8 +1136,7 @@ async fn share_encrypted_room( services .rooms .state_accessor - .room_state_get(other_room_id, &StateEventType::RoomEncryption, "") - .map(Result::into_is_ok) + .is_encrypted_room(other_room_id) }) .await } diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index 447e54be0..f02655e65 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -65,12 +65,6 @@ pub(crate) async fn create_invite_route( return Err!(Request(Forbidden("Server is banned on this homeserver."))); } - if let Some(via) = &body.via { - if via.is_empty() { - return Err!(Request(InvalidParam("via field must not be empty."))); - } - } - let mut signed_event = utils::to_canonical_object(&body.event) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?; diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index 8065ac55b..ac3f5f83e 100644 --- a/src/service/account_data/mod.rs +++ 
b/src/service/account_data/mod.rs @@ -5,7 +5,7 @@ use conduit::{ utils::{stream::TryIgnore, ReadyExt}, Err, Error, Result, }; -use database::{Deserialized, Handle, Map}; +use database::{Deserialized, Handle, Json, Map}; use futures::{StreamExt, TryFutureExt}; use ruma::{ events::{ @@ -56,41 +56,19 @@ impl crate::Service for Service { pub async fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, data: &serde_json::Value, ) -> Result<()> { - let event_type = event_type.to_string(); - let count = self.services.globals.next_count()?; - - let mut prefix = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(user_id.as_bytes()); - prefix.push(0xFF); - - let mut roomuserdataid = prefix.clone(); - roomuserdataid.extend_from_slice(&count.to_be_bytes()); - roomuserdataid.push(0xFF); - roomuserdataid.extend_from_slice(event_type.as_bytes()); - - let mut key = prefix; - key.extend_from_slice(event_type.as_bytes()); - if data.get("type").is_none() || data.get("content").is_none() { return Err!(Request(InvalidParam("Account data doesn't have all required fields."))); } - self.db.roomuserdataid_accountdata.insert( - &roomuserdataid, - &serde_json::to_vec(&data).expect("to_vec always works on json values"), - ); - - let prev_key = (room_id, user_id, &event_type); - let prev = self.db.roomusertype_roomuserdataid.qry(&prev_key).await; - + let count = self.services.globals.next_count().unwrap(); + let roomuserdataid = (room_id, user_id, count, &event_type); self.db - .roomusertype_roomuserdataid - .insert(&key, &roomuserdataid); + .roomuserdataid_accountdata + .put(roomuserdataid, Json(data)); + + let key = (room_id, user_id, &event_type); + let prev = self.db.roomusertype_roomuserdataid.qry(&key).await; + self.db.roomusertype_roomuserdataid.put(key, roomuserdataid); // Remove old entry if let Ok(prev) = prev { diff --git a/src/service/globals/data.rs 
b/src/service/globals/data.rs index 57a295d99..3638cb56c 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -4,7 +4,7 @@ use std::{ }; use conduit::{trace, utils, utils::rand, Error, Result, Server}; -use database::{Database, Deserialized, Map}; +use database::{Database, Deserialized, Json, Map}; use futures::{pin_mut, stream::FuturesUnordered, FutureExt, StreamExt}; use ruma::{ api::federation::discovery::{ServerSigningKeys, VerifyKey}, @@ -83,7 +83,7 @@ impl Data { .checked_add(1) .expect("counter must not overflow u64"); - self.global.insert(COUNTER, &counter.to_be_bytes()); + self.global.insert(COUNTER, counter.to_be_bytes()); Ok(*counter) } @@ -259,29 +259,21 @@ impl Data { pub async fn add_signing_key( &self, origin: &ServerName, new_keys: ServerSigningKeys, ) -> BTreeMap { - // Not atomic, but this is not critical - let signingkeys = self.server_signingkeys.get(origin).await; - - let mut keys = signingkeys - .and_then(|keys| serde_json::from_slice(&keys).map_err(Into::into)) + // (timo) Not atomic, but this is not critical + let mut keys: ServerSigningKeys = self + .server_signingkeys + .get(origin) + .await + .deserialized() .unwrap_or_else(|_| { // Just insert "now", it doesn't matter ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) }); - let ServerSigningKeys { - verify_keys, - old_verify_keys, - .. 
- } = new_keys; + keys.verify_keys.extend(new_keys.verify_keys); + keys.old_verify_keys.extend(new_keys.old_verify_keys); - keys.verify_keys.extend(verify_keys); - keys.old_verify_keys.extend(old_verify_keys); - - self.server_signingkeys.insert( - origin.as_bytes(), - &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), - ); + self.server_signingkeys.raw_put(origin, Json(&keys)); let mut tree = keys.verify_keys; tree.extend( @@ -324,7 +316,7 @@ impl Data { #[inline] pub fn bump_database_version(&self, new_version: u64) -> Result<()> { - self.global.insert(b"version", &new_version.to_be_bytes()); + self.global.raw_put(b"version", new_version); Ok(()) } diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs index 334e71c6f..c953e7b1d 100644 --- a/src/service/globals/migrations.rs +++ b/src/service/globals/migrations.rs @@ -2,7 +2,7 @@ use conduit::{ debug_info, debug_warn, error, info, result::NotFound, utils::{stream::TryIgnore, IterStream, ReadyExt}, - warn, Err, Error, Result, + warn, Err, Result, }; use futures::{FutureExt, StreamExt}; use itertools::Itertools; @@ -37,10 +37,9 @@ pub(crate) async fn migrations(services: &Services) -> Result<()> { // requires recreating the database from scratch. 
if users_count > 0 { let conduit_user = &services.globals.server_user; - if !services.users.exists(conduit_user).await { - error!("The {} server user does not exist, and the database is not new.", conduit_user); - return Err(Error::bad_database( + error!("The {conduit_user} server user does not exist, and the database is not new."); + return Err!(Database( "Cannot reuse an existing database after changing the server name, please delete the old one first.", )); } @@ -62,9 +61,9 @@ async fn fresh(services: &Services) -> Result<()> { .db .bump_database_version(DATABASE_VERSION)?; - db["global"].insert(b"feat_sha256_media", &[]); - db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[]); - db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[]); + db["global"].insert(b"feat_sha256_media", []); + db["global"].insert(b"fix_bad_double_separator_in_state_cache", []); + db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []); // Create the admin room and server user on first run crate::admin::create_admin_room(services).await?; @@ -359,7 +358,7 @@ async fn fix_bad_double_separator_in_state_cache(services: &Services) -> Result< .await; db.db.cleanup()?; - db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[]); + db["global"].insert(b"fix_bad_double_separator_in_state_cache", []); info!("Finished fixing"); Ok(()) @@ -440,7 +439,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) } db.db.cleanup()?; - db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[]); + db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []); info!("Finished fixing"); Ok(()) diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 4c3037571..bae6f2144 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -5,7 +5,7 @@ use conduit::{ utils::stream::{ReadyExt, TryIgnore}, Err, Result, }; -use 
database::{Deserialized, Ignore, Interfix, Map}; +use database::{Deserialized, Ignore, Interfix, Json, Map}; use futures::StreamExt; use ruma::{ api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, @@ -50,31 +50,21 @@ impl crate::Service for Service { #[implement(Service)] pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { let version = self.services.globals.next_count()?.to_string(); + let count = self.services.globals.next_count()?; - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); + let key = (user_id, &version); + self.db.backupid_algorithm.put(key, Json(backup_metadata)); - self.db.backupid_algorithm.insert( - &key, - &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), - ); - - self.db - .backupid_etag - .insert(&key, &self.services.globals.next_count()?.to_be_bytes()); + self.db.backupid_etag.put(key, count); Ok(version) } #[implement(Service)] pub async fn delete_backup(&self, user_id: &UserId, version: &str) { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.db.backupid_algorithm.remove(&key); - self.db.backupid_etag.remove(&key); + let key = (user_id, version); + self.db.backupid_algorithm.del(key); + self.db.backupid_etag.del(key); let key = (user_id, version, Interfix); self.db @@ -86,26 +76,21 @@ pub async fn delete_backup(&self, user_id: &UserId, version: &str) { } #[implement(Service)] -pub async fn update_backup( - &self, user_id: &UserId, version: &str, backup_metadata: &Raw, -) -> Result { +pub async fn update_backup<'a>( + &self, user_id: &UserId, version: &'a str, backup_metadata: &Raw, +) -> Result<&'a str> { let key = (user_id, version); if self.db.backupid_algorithm.qry(&key).await.is_err() { return Err!(Request(NotFound("Tried to update nonexistent backup."))); } - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - 
key.extend_from_slice(version.as_bytes()); - + let count = self.services.globals.next_count().unwrap(); + self.db.backupid_etag.put(key, count); self.db .backupid_algorithm - .insert(&key, backup_metadata.json().get().as_bytes()); - self.db - .backupid_etag - .insert(&key, &self.services.globals.next_count()?.to_be_bytes()); + .put_raw(key, backup_metadata.json().get()); - Ok(version.to_owned()) + Ok(version) } #[implement(Service)] @@ -156,22 +141,13 @@ pub async fn add_key( return Err!(Request(NotFound("Tried to update nonexistent backup."))); } - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.db - .backupid_etag - .insert(&key, &self.services.globals.next_count()?.to_be_bytes()); - - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); + let count = self.services.globals.next_count().unwrap(); + self.db.backupid_etag.put(key, count); + let key = (user_id, version, room_id, session_id); self.db .backupkeyid_backup - .insert(&key, key_data.json().get().as_bytes()); + .put_raw(key, key_data.json().get()); Ok(()) } diff --git a/src/service/media/data.rs b/src/service/media/data.rs index b22718836..9afbd708f 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -1,13 +1,13 @@ -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use conduit::{ - debug, debug_info, trace, + debug, debug_info, err, utils::{str_from_bytes, stream::TryIgnore, string_from_bytes, ReadyExt}, Err, Error, Result, }; -use database::{Database, Map}; +use database::{Database, Interfix, Map}; use futures::StreamExt; -use ruma::{api::client::error::ErrorKind, http_headers::ContentDisposition, Mxc, OwnedMxcUri, UserId}; +use ruma::{http_headers::ContentDisposition, Mxc, OwnedMxcUri, UserId}; use super::{preview::UrlPreviewData, thumbnail::Dim}; @@ -37,39 +37,13 @@ impl Data { &self, mxc: &Mxc<'_>, user: Option<&UserId>, dim: &Dim, 
content_disposition: Option<&ContentDisposition>, content_type: Option<&str>, ) -> Result> { - let mut key: Vec = Vec::new(); - key.extend_from_slice(b"mxc://"); - key.extend_from_slice(mxc.server_name.as_bytes()); - key.extend_from_slice(b"/"); - key.extend_from_slice(mxc.media_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(&dim.width.to_be_bytes()); - key.extend_from_slice(&dim.height.to_be_bytes()); - key.push(0xFF); - key.extend_from_slice( - content_disposition - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes(), - ); - key.push(0xFF); - key.extend_from_slice( - content_type - .as_ref() - .map(|c| c.as_bytes()) - .unwrap_or_default(), - ); - - self.mediaid_file.insert(&key, &[]); - + let dim: &[u32] = &[dim.width, dim.height]; + let key = (mxc, dim, content_disposition, content_type); + let key = database::serialize_to_vec(key)?; + self.mediaid_file.insert(&key, []); if let Some(user) = user { - let mut key: Vec = Vec::new(); - key.extend_from_slice(b"mxc://"); - key.extend_from_slice(mxc.server_name.as_bytes()); - key.extend_from_slice(b"/"); - key.extend_from_slice(mxc.media_id.as_bytes()); - let user = user.as_bytes().to_vec(); - self.mediaid_user.insert(&key, &user); + let key = (mxc, user); + self.mediaid_user.put_raw(key, user); } Ok(key) @@ -78,33 +52,23 @@ impl Data { pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) { debug!("MXC URI: {mxc}"); - let mut prefix: Vec = Vec::new(); - prefix.extend_from_slice(b"mxc://"); - prefix.extend_from_slice(mxc.server_name.as_bytes()); - prefix.extend_from_slice(b"/"); - prefix.extend_from_slice(mxc.media_id.as_bytes()); - prefix.push(0xFF); - - trace!("MXC db prefix: {prefix:?}"); + let prefix = (mxc, Interfix); self.mediaid_file - .raw_keys_prefix(&prefix) + .keys_prefix_raw(&prefix) .ignore_err() - .ready_for_each(|key| { - debug!("Deleting key: {:?}", key); - self.mediaid_file.remove(key); - }) + .ready_for_each(|key| self.mediaid_file.remove(key)) .await; self.mediaid_user - 
.raw_stream_prefix(&prefix) + .stream_prefix_raw(&prefix) .ignore_err() .ready_for_each(|(key, val)| { - if key.starts_with(&prefix) { - let user = str_from_bytes(val).unwrap_or_default(); - debug_info!("Deleting key {key:?} which was uploaded by user {user}"); + debug_assert!(key.starts_with(mxc.to_string().as_bytes()), "key should start with the mxc"); - self.mediaid_user.remove(key); - } + let user = str_from_bytes(val).unwrap_or_default(); + debug_info!("Deleting key {key:?} which was uploaded by user {user}"); + + self.mediaid_user.remove(key); }) .await; } @@ -113,16 +77,10 @@ impl Data { pub(super) async fn search_mxc_metadata_prefix(&self, mxc: &Mxc<'_>) -> Result>> { debug!("MXC URI: {mxc}"); - let mut prefix: Vec = Vec::new(); - prefix.extend_from_slice(b"mxc://"); - prefix.extend_from_slice(mxc.server_name.as_bytes()); - prefix.extend_from_slice(b"/"); - prefix.extend_from_slice(mxc.media_id.as_bytes()); - prefix.push(0xFF); - + let prefix = (mxc, Interfix); let keys: Vec> = self .mediaid_file - .raw_keys_prefix(&prefix) + .keys_prefix_raw(&prefix) .ignore_err() .map(<[u8]>::to_vec) .collect() @@ -138,24 +96,17 @@ impl Data { } pub(super) async fn search_file_metadata(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result { - let mut prefix: Vec = Vec::new(); - prefix.extend_from_slice(b"mxc://"); - prefix.extend_from_slice(mxc.server_name.as_bytes()); - prefix.extend_from_slice(b"/"); - prefix.extend_from_slice(mxc.media_id.as_bytes()); - prefix.push(0xFF); - prefix.extend_from_slice(&dim.width.to_be_bytes()); - prefix.extend_from_slice(&dim.height.to_be_bytes()); - prefix.push(0xFF); + let dim: &[u32] = &[dim.width, dim.height]; + let prefix = (mxc, dim, Interfix); let key = self .mediaid_file - .raw_keys_prefix(&prefix) + .keys_prefix_raw(&prefix) .ignore_err() .map(ToOwned::to_owned) .next() .await - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Media not found"))?; + .ok_or_else(|| err!(Request(NotFound("Media not found"))))?; let mut parts = 
key.rsplit(|&b| b == 0xFF); @@ -215,9 +166,7 @@ impl Data { Ok(()) } - pub(super) fn set_url_preview( - &self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration, - ) -> Result<()> { + pub(super) fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: Duration) -> Result<()> { let mut value = Vec::::new(); value.extend_from_slice(×tamp.as_secs().to_be_bytes()); value.push(0xFF); diff --git a/src/service/media/migrations.rs b/src/service/media/migrations.rs index 2d1b39f9f..0e358d443 100644 --- a/src/service/media/migrations.rs +++ b/src/service/media/migrations.rs @@ -54,7 +54,7 @@ pub(crate) async fn migrate_sha256_media(services: &Services) -> Result<()> { services.globals.db.bump_database_version(13)?; } - db["global"].insert(b"feat_sha256_media", &[]); + db["global"].insert(b"feat_sha256_media", []); info!("Finished applying sha256_media"); Ok(()) } diff --git a/src/service/presence/data.rs b/src/service/presence/data.rs index 9c9d0ae3f..8522746fd 100644 --- a/src/service/presence/data.rs +++ b/src/service/presence/data.rs @@ -5,7 +5,7 @@ use conduit::{ utils::{stream::TryIgnore, ReadyExt}, Result, }; -use database::{Deserialized, Map}; +use database::{Deserialized, Json, Map}; use futures::Stream; use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; @@ -107,14 +107,12 @@ impl Data { last_active_ts, status_msg, ); + let count = self.services.globals.next_count()?; let key = presenceid_key(count, user_id); - self.presenceid_presence - .insert(&key, &presence.to_json_bytes()?); - - self.userid_presenceid - .insert(user_id.as_bytes(), &count.to_be_bytes()); + self.presenceid_presence.raw_put(key, Json(presence)); + self.userid_presenceid.raw_put(user_id, count); if let Ok((last_count, _)) = last_presence { let key = presenceid_key(last_count, user_id); @@ -136,7 +134,7 @@ impl Data { let key = presenceid_key(count, user_id); self.presenceid_presence.remove(&key); - 
self.userid_presenceid.remove(user_id.as_bytes()); + self.userid_presenceid.remove(user_id); } pub fn presence_since(&self, since: u64) -> impl Stream)> + Send + '_ { @@ -152,7 +150,11 @@ impl Data { #[inline] fn presenceid_key(count: u64, user_id: &UserId) -> Vec { - [count.to_be_bytes().to_vec(), user_id.as_bytes().to_vec()].concat() + let cap = size_of::().saturating_add(user_id.as_bytes().len()); + let mut key = Vec::with_capacity(cap); + key.extend_from_slice(&count.to_be_bytes()); + key.extend_from_slice(user_id.as_bytes()); + key } #[inline] diff --git a/src/service/presence/presence.rs b/src/service/presence/presence.rs index 0d5c226bf..c43720034 100644 --- a/src/service/presence/presence.rs +++ b/src/service/presence/presence.rs @@ -35,10 +35,6 @@ impl Presence { serde_json::from_slice(bytes).map_err(|_| Error::bad_database("Invalid presence data in database")) } - pub(super) fn to_json_bytes(&self) -> Result> { - serde_json::to_vec(self).map_err(|_| Error::bad_database("Could not serialize Presence to JSON")) - } - /// Creates a PresenceEvent from available data. 
pub(super) async fn to_presence_event(&self, user_id: &UserId, users: &users::Service) -> PresenceEvent { let now = utils::millis_since_unix_epoch(); diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index e7b1824ad..af15e332d 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -6,7 +6,7 @@ use conduit::{ utils::{stream::TryIgnore, string_from_bytes}, Err, PduEvent, Result, }; -use database::{Deserialized, Ignore, Interfix, Map}; +use database::{Deserialized, Ignore, Interfix, Json, Map}; use futures::{Stream, StreamExt}; use ipaddress::IPAddress; use ruma::{ @@ -68,18 +68,12 @@ impl Service { pub fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) { match pusher { set_pusher::v3::PusherAction::Post(data) => { - let mut key = sender.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(data.pusher.ids.pushkey.as_bytes()); - self.db - .senderkey_pusher - .insert(&key, &serde_json::to_vec(pusher).expect("Pusher is valid JSON value")); + let key = (sender, &data.pusher.ids.pushkey); + self.db.senderkey_pusher.put(key, Json(pusher)); }, set_pusher::v3::PusherAction::Delete(ids) => { - let mut key = sender.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(ids.pushkey.as_bytes()); - self.db.senderkey_pusher.remove(&key); + let key = (sender, &ids.pushkey); + self.db.senderkey_pusher.del(key); }, } } diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index 2112ecefb..f366ffe2d 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -26,10 +26,10 @@ impl crate::Service for Service { } #[implement(Service)] -pub fn set_public(&self, room_id: &RoomId) { self.db.publicroomids.insert(room_id.as_bytes(), &[]); } +pub fn set_public(&self, room_id: &RoomId) { self.db.publicroomids.insert(room_id, []); } #[implement(Service)] -pub fn set_not_public(&self, room_id: &RoomId) { self.db.publicroomids.remove(room_id.as_bytes()); } 
+pub fn set_not_public(&self, room_id: &RoomId) { self.db.publicroomids.remove(room_id); } #[implement(Service)] pub async fn is_public_room(&self, room_id: &RoomId) -> bool { self.db.publicroomids.get(room_id).await.is_ok() } diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index 9493dcc49..7a4da2a64 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -79,17 +79,9 @@ pub fn lazy_load_confirm_delivery(&self, user_id: &UserId, device_id: &DeviceId, return; }; - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xFF); - for ll_id in &user_ids { - let mut key = prefix.clone(); - key.extend_from_slice(ll_id.as_bytes()); - self.db.lazyloadedids.insert(&key, &[]); + let key = (user_id, device_id, room_id, ll_id); + self.db.lazyloadedids.put_raw(key, []); } } diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index 8367eb72d..4ee390a5c 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -64,9 +64,9 @@ pub fn iter_ids(&self) -> impl Stream + Send + '_ { self.db.room #[inline] pub fn disable_room(&self, room_id: &RoomId, disabled: bool) { if disabled { - self.db.disabledroomids.insert(room_id.as_bytes(), &[]); + self.db.disabledroomids.insert(room_id, []); } else { - self.db.disabledroomids.remove(room_id.as_bytes()); + self.db.disabledroomids.remove(room_id); } } @@ -74,9 +74,9 @@ pub fn disable_room(&self, room_id: &RoomId, disabled: bool) { #[inline] pub fn ban_room(&self, room_id: &RoomId, banned: bool) { if banned { - self.db.bannedroomids.insert(room_id.as_bytes(), &[]); + self.db.bannedroomids.insert(room_id, []); } else { - self.db.bannedroomids.remove(room_id.as_bytes()); + self.db.bannedroomids.remove(room_id); } } diff --git a/src/service/rooms/outlier/mod.rs 
b/src/service/rooms/outlier/mod.rs index b9d042638..03e778389 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use conduit::{implement, Result}; -use database::{Deserialized, Map}; +use database::{Deserialized, Json, Map}; use ruma::{CanonicalJsonObject, EventId}; use crate::PduEvent; @@ -50,8 +50,5 @@ pub async fn get_pdu_outlier(&self, event_id: &EventId) -> Result { #[implement(Service)] #[tracing::instrument(skip(self, pdu), level = "debug")] pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) { - self.db.eventid_outlierpdu.insert( - event_id.as_bytes(), - &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), - ); + self.db.eventid_outlierpdu.raw_put(event_id, Json(pdu)); } diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index 8e0456582..4d570e6db 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -39,9 +39,10 @@ impl Data { } pub(super) fn add_relation(&self, from: u64, to: u64) { - let mut key = to.to_be_bytes().to_vec(); - key.extend_from_slice(&from.to_be_bytes()); - self.tofrom_relation.insert(&key, &[]); + const BUFSIZE: usize = size_of::() * 2; + + let key: &[u64] = &[to, from]; + self.tofrom_relation.aput_raw::(key, []); } pub(super) fn relations_until<'a>( @@ -78,9 +79,8 @@ impl Data { pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) { for prev in event_ids { - let mut key = room_id.as_bytes().to_vec(); - key.extend_from_slice(prev.as_bytes()); - self.referencedevents.insert(&key, &[]); + let key = (room_id, prev); + self.referencedevents.put_raw(key, []); } } @@ -89,9 +89,7 @@ impl Data { self.referencedevents.qry(&key).await.is_ok() } - pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) { - self.softfailedeventids.insert(event_id.as_bytes(), &[]); - } + pub(super) fn mark_event_soft_failed(&self, event_id: 
&EventId) { self.softfailedeventids.insert(event_id, []); } pub(super) async fn is_event_soft_failed(&self, event_id: &EventId) -> bool { self.softfailedeventids.get(event_id).await.is_ok() diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 74b649ef3..80a35e881 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -5,7 +5,7 @@ use conduit::{ utils::{stream::TryIgnore, ReadyExt}, Error, Result, }; -use database::{Deserialized, Map}; +use database::{Deserialized, Json, Map}; use futures::{Stream, StreamExt}; use ruma::{ events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent}, @@ -44,33 +44,19 @@ impl Data { pub(super) async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) { type KeyVal<'a> = (&'a RoomId, u64, &'a UserId); - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - // Remove old entry + let last_possible_key = (room_id, u64::MAX); self.readreceiptid_readreceipt - .rev_keys_from_raw(&last_possible_key) + .rev_keys_from(&last_possible_key) .ignore_err() .ready_take_while(|(r, ..): &KeyVal<'_>| *r == room_id) .ready_filter_map(|(r, c, u): KeyVal<'_>| (u == user_id).then_some((r, c, u))) - .ready_for_each(|old: KeyVal<'_>| { - // This is the old room_latest - self.readreceiptid_readreceipt.del(&old); - }) + .ready_for_each(|old: KeyVal<'_>| self.readreceiptid_readreceipt.del(old)) .await; - let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes()); - room_latest_id.push(0xFF); - room_latest_id.extend_from_slice(user_id.as_bytes()); - - self.readreceiptid_readreceipt.insert( - &room_latest_id, - &serde_json::to_vec(event).expect("EduEvent::to_string always works"), - ); + let count = self.services.globals.next_count().unwrap(); + let 
latest_id = (room_id, count, user_id); + self.readreceiptid_readreceipt.put(latest_id, Json(event)); } pub(super) fn readreceipts_since<'a>( @@ -113,15 +99,11 @@ impl Data { } pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_privateread - .insert(&key, &count.to_be_bytes()); + let key = (room_id, user_id); + let next_count = self.services.globals.next_count().unwrap(); - self.roomuserid_lastprivatereadupdate - .insert(&key, &self.services.globals.next_count().unwrap().to_be_bytes()); + self.roomuserid_privateread.put(key, count); + self.roomuserid_lastprivatereadupdate.put(key, next_count); } pub(super) async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result { diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index bd8fdcc94..609c0e07e 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{mem::size_of_val, sync::Arc}; use conduit::{err, implement, utils, Result}; use database::{Deserialized, Map}; @@ -46,6 +46,8 @@ impl crate::Service for Service { #[implement(Service)] pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 { + const BUFSIZE: usize = size_of::(); + if let Ok(shorteventid) = self .db .eventid_shorteventid @@ -57,12 +59,15 @@ pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 { } let shorteventid = self.services.globals.next_count().unwrap(); + debug_assert!(size_of_val(&shorteventid) == BUFSIZE, "buffer requirement changed"); + self.db .eventid_shorteventid - .insert(event_id.as_bytes(), &shorteventid.to_be_bytes()); + .raw_aput::(event_id, shorteventid); + self.db .shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), event_id.as_bytes()); + .aput_raw::(shorteventid, event_id); shorteventid } @@ -77,13 +82,17 @@ 
pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> .map(|(i, result)| match result { Ok(ref short) => utils::u64_from_u8(short), Err(_) => { + const BUFSIZE: usize = size_of::(); + let short = self.services.globals.next_count().unwrap(); + debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed"); + self.db .eventid_shorteventid - .insert(event_ids[i], &short.to_be_bytes()); + .raw_aput::(event_ids[i], short); self.db .shorteventid_eventid - .insert(&short.to_be_bytes(), event_ids[i]); + .aput_raw::(short, event_ids[i]); short }, @@ -103,7 +112,9 @@ pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &s #[implement(Service)] pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> u64 { - let key = (event_type.to_string(), state_key); + const BUFSIZE: usize = size_of::(); + + let key = (event_type, state_key); if let Ok(shortstatekey) = self .db .statekey_shortstatekey @@ -114,17 +125,16 @@ pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, sta return shortstatekey; } - let mut key = event_type.to_string().as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(state_key.as_bytes()); - let shortstatekey = self.services.globals.next_count().unwrap(); + debug_assert!(size_of_val(&shortstatekey) == BUFSIZE, "buffer requirement changed"); + self.db .statekey_shortstatekey - .insert(&key, &shortstatekey.to_be_bytes()); + .put_aput::(key, shortstatekey); + self.db .shortstatekey_statekey - .insert(&shortstatekey.to_be_bytes(), &key); + .aput_put::(shortstatekey, key); shortstatekey } @@ -177,6 +187,8 @@ pub async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(State /// Returns (shortstatehash, already_existed) #[implement(Service)] pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, bool) { + const BUFSIZE: usize = size_of::(); + if let Ok(shortstatehash) = self .db 
.statehash_shortstatehash @@ -188,9 +200,11 @@ pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, boo } let shortstatehash = self.services.globals.next_count().unwrap(); + debug_assert!(size_of_val(&shortstatehash) == BUFSIZE, "buffer requirement changed"); + self.db .statehash_shortstatehash - .insert(state_hash, &shortstatehash.to_be_bytes()); + .raw_aput::(state_hash, shortstatehash); (shortstatehash, false) } @@ -208,10 +222,15 @@ pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> u64 { .await .deserialized() .unwrap_or_else(|_| { + const BUFSIZE: usize = size_of::(); + let short = self.services.globals.next_count().unwrap(); + debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed"); + self.db .roomid_shortroomid - .insert(room_id.as_bytes(), &short.to_be_bytes()); + .raw_aput::(room_id, short); + short }) } diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index 7265038fd..813f48aed 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -36,12 +36,12 @@ impl Data { _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) { self.roomid_shortstatehash - .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes()); + .raw_put(room_id, new_shortstatehash); } pub(super) fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) { self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes()); + .put(shorteventid, shortstatehash); } pub(super) async fn set_forward_extremities( @@ -57,12 +57,9 @@ impl Data { .ready_for_each(|key| self.roomid_pduleaves.remove(key)) .await; - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - for event_id in event_ids { - let mut key = prefix.clone(); - key.extend_from_slice(event_id.as_bytes()); - self.roomid_pduleaves.insert(&key, event_id.as_bytes()); + for event_id in &event_ids { + let key = (room_id, 
event_id); + self.roomid_pduleaves.put_raw(key, event_id); } } } diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 19f1f1413..561db18a5 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -481,4 +481,10 @@ impl Service { .await .map(|content: RoomEncryptionEventContent| content.algorithm) } + + pub async fn is_encrypted_room(&self, room_id: &RoomId) -> bool { + self.room_state_get(room_id, &StateEventType::RoomEncryption, "") + .await + .is_ok() + } } diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 6e01e49df..c06c8107f 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -4,7 +4,7 @@ use std::{ }; use conduit::{utils::stream::TryIgnore, Result}; -use database::{Deserialized, Interfix, Map}; +use database::{serialize_to_vec, Deserialized, Interfix, Json, Map}; use futures::{Stream, StreamExt}; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, @@ -63,71 +63,62 @@ impl Data { } pub(super) fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - self.roomuseroncejoinedids.insert(&userroom_id, &[]); + let key = (user_id, room_id); + + self.roomuseroncejoinedids.put_raw(key, []); } pub(super) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) { - let roomid = room_id.as_bytes().to_vec(); + let userroom_id = (user_id, room_id); + let userroom_id = serialize_to_vec(userroom_id).expect("failed to serialize userroom_id"); - let mut roomuser_id = roomid.clone(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); + let roomuser_id = (room_id, user_id); + let roomuser_id = serialize_to_vec(roomuser_id).expect("failed to serialize roomuser_id"); - let mut userroom_id = 
user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); + self.userroomid_joined.insert(&userroom_id, []); + self.roomuserid_joined.insert(&roomuser_id, []); - self.userroomid_joined.insert(&userroom_id, &[]); - self.roomuserid_joined.insert(&roomuser_id, &[]); self.userroomid_invitestate.remove(&userroom_id); self.roomuserid_invitecount.remove(&roomuser_id); + self.userroomid_leftstate.remove(&userroom_id); self.roomuserid_leftcount.remove(&roomuser_id); - self.roomid_inviteviaservers.remove(&roomid); + self.roomid_inviteviaservers.remove(room_id); } pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { - let roomid = room_id.as_bytes().to_vec(); - - let mut roomuser_id = roomid.clone(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_leftstate.insert( - &userroom_id, - &serde_json::to_vec(&Vec::>::new()).unwrap(), - ); // TODO - self.roomuserid_leftcount - .insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes()); + let userroom_id = (user_id, room_id); + let userroom_id = serialize_to_vec(userroom_id).expect("failed to serialize userroom_id"); + + let roomuser_id = (room_id, user_id); + let roomuser_id = serialize_to_vec(roomuser_id).expect("failed to serialize roomuser_id"); + + // (timo) TODO + let leftstate = Vec::>::new(); + let count = self.services.globals.next_count().unwrap(); + + self.userroomid_leftstate + .raw_put(&userroom_id, Json(leftstate)); + self.roomuserid_leftcount.raw_put(&roomuser_id, count); + self.userroomid_joined.remove(&userroom_id); self.roomuserid_joined.remove(&roomuser_id); + self.userroomid_invitestate.remove(&userroom_id); self.roomuserid_invitecount.remove(&roomuser_id); - self.roomid_inviteviaservers.remove(&roomid); + 
self.roomid_inviteviaservers.remove(room_id); } /// Makes a user forget a room. #[tracing::instrument(skip(self), level = "debug")] pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = (user_id, room_id); + let roomuser_id = (room_id, user_id); - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - self.userroomid_leftstate.remove(&userroom_id); - self.roomuserid_leftcount.remove(&roomuser_id); + self.userroomid_leftstate.del(userroom_id); + self.roomuserid_leftcount.del(roomuser_id); } /// Returns an iterator over all rooms a user was invited to. diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index edfae5291..077eee104 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -3,13 +3,13 @@ mod data; use std::{collections::HashSet, sync::Arc}; use conduit::{ - err, + err, is_not_empty, utils::{stream::TryIgnore, ReadyExt, StreamTools}, warn, Result, }; use data::Data; -use database::{Deserialized, Ignore, Interfix}; -use futures::{Stream, StreamExt}; +use database::{serialize_to_vec, Deserialized, Ignore, Interfix, Json}; +use futures::{stream::iter, Stream, StreamExt}; use itertools::Itertools; use ruma::{ events::{ @@ -547,50 +547,37 @@ impl Service { .unwrap_or(0), ); - self.db - .roomid_joinedcount - .insert(room_id.as_bytes(), &joinedcount.to_be_bytes()); - - self.db - .roomid_invitedcount - .insert(room_id.as_bytes(), &invitedcount.to_be_bytes()); + self.db.roomid_joinedcount.raw_put(room_id, joinedcount); + self.db.roomid_invitedcount.raw_put(room_id, invitedcount); self.room_servers(room_id) .ready_for_each(|old_joined_server| { - if !joined_servers.remove(old_joined_server) { - // Server not in room anymore - let mut roomserver_id = 
room_id.as_bytes().to_vec(); - roomserver_id.push(0xFF); - roomserver_id.extend_from_slice(old_joined_server.as_bytes()); - - let mut serverroom_id = old_joined_server.as_bytes().to_vec(); - serverroom_id.push(0xFF); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.db.roomserverids.remove(&roomserver_id); - self.db.serverroomids.remove(&serverroom_id); + if joined_servers.remove(old_joined_server) { + return; } + + // Server not in room anymore + let roomserver_id = (room_id, old_joined_server); + let serverroom_id = (old_joined_server, room_id); + + self.db.roomserverids.del(roomserver_id); + self.db.serverroomids.del(serverroom_id); }) .await; // Now only new servers are in joined_servers anymore - for server in joined_servers { - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xFF); - roomserver_id.extend_from_slice(server.as_bytes()); - - let mut serverroom_id = server.as_bytes().to_vec(); - serverroom_id.push(0xFF); - serverroom_id.extend_from_slice(room_id.as_bytes()); + for server in &joined_servers { + let roomserver_id = (room_id, server); + let serverroom_id = (server, room_id); - self.db.roomserverids.insert(&roomserver_id, &[]); - self.db.serverroomids.insert(&serverroom_id, &[]); + self.db.roomserverids.put_raw(roomserver_id, []); + self.db.serverroomids.put_raw(serverroom_id, []); } self.db .appservice_in_room_cache .write() - .unwrap() + .expect("locked") .remove(room_id); } @@ -598,44 +585,44 @@ impl Service { &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, invite_via: Option>, ) { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); + let roomuser_id = (room_id, user_id); + let roomuser_id = serialize_to_vec(roomuser_id).expect("failed to serialize roomuser_id"); - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = 
(user_id, room_id); + let userroom_id = serialize_to_vec(userroom_id).expect("failed to serialize userroom_id"); + + self.db + .userroomid_invitestate + .raw_put(&userroom_id, Json(last_state.unwrap_or_default())); - self.db.userroomid_invitestate.insert( - &userroom_id, - &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), - ); self.db .roomuserid_invitecount - .insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes()); + .raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); + self.db.userroomid_joined.remove(&userroom_id); self.db.roomuserid_joined.remove(&roomuser_id); + self.db.userroomid_leftstate.remove(&userroom_id); self.db.roomuserid_leftcount.remove(&roomuser_id); - if let Some(servers) = invite_via.as_deref() { + if let Some(servers) = invite_via.filter(is_not_empty!()) { self.add_servers_invite_via(room_id, servers).await; } } #[tracing::instrument(skip(self, servers), level = "debug")] - pub async fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) { - let mut prev_servers: Vec<_> = self + pub async fn add_servers_invite_via(&self, room_id: &RoomId, servers: Vec) { + let mut servers: Vec<_> = self .servers_invite_via(room_id) .map(ToOwned::to_owned) + .chain(iter(servers.into_iter())) .collect() .await; - prev_servers.extend(servers.to_owned()); - prev_servers.sort_unstable(); - prev_servers.dedup(); + servers.sort_unstable(); + servers.dedup(); - let servers = prev_servers + let servers = servers .iter() .map(|server| server.as_bytes()) .collect_vec() diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index cb85cf19c..c51b78568 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -11,7 +11,7 @@ use conduit::{ utils::{stream::TryIgnore, u64_from_u8, ReadyExt}, Err, PduCount, PduEvent, Result, }; -use database::{Database, Deserialized, KeyVal, Map}; +use 
database::{Database, Deserialized, Json, KeyVal, Map}; use futures::{FutureExt, Stream, StreamExt}; use ruma::{CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; use tokio::sync::Mutex; @@ -168,10 +168,7 @@ impl Data { } pub(super) async fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), - ); + self.pduid_pdu.raw_put(pdu_id, Json(json)); self.lasttimelinecount_cache .lock() @@ -183,13 +180,10 @@ impl Data { } pub(super) fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), - ); + self.pduid_pdu.raw_put(pdu_id, Json(json)); - self.eventid_pduid.insert(event_id.as_bytes(), pdu_id); - self.eventid_outlierpdu.remove(event_id.as_bytes()); + self.eventid_pduid.insert(event_id, pdu_id); + self.eventid_outlierpdu.remove(event_id); } /// Removes a pdu and creates a new one with the same id. 
@@ -328,5 +322,5 @@ pub(super) fn pdu_count(pdu_id: &[u8]) -> PduCount { fn increment(db: &Arc, key: &[u8]) { let old = db.get_blocking(key); let new = utils::increment(old.ok().as_deref()); - db.insert(key, &new); + db.insert(key, new); } diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index d4d9874c2..96b009f85 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -38,20 +38,13 @@ impl Data { } pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); + let userroom_id = (user_id, room_id); + self.userroomid_highlightcount.put(userroom_id, 0_u64); + self.userroomid_notificationcount.put(userroom_id, 0_u64); - self.userroomid_notificationcount - .insert(&userroom_id, &0_u64.to_be_bytes()); - self.userroomid_highlightcount - .insert(&userroom_id, &0_u64.to_be_bytes()); - - self.roomuserid_lastnotificationread - .insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes()); + let roomuser_id = (room_id, user_id); + let count = self.services.globals.next_count().unwrap(); + self.roomuserid_lastnotificationread.put(roomuser_id, count); } pub(super) async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { @@ -89,11 +82,8 @@ impl Data { .await .expect("room exists"); - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); - - self.roomsynctoken_shortstatehash - .insert(&key, &shortstatehash.to_be_bytes()); + let key: &[u64] = &[shortroomid, token]; + self.roomsynctoken_shortstatehash.put(key, shortstatehash); } pub(super) async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result { diff --git 
a/src/service/sending/data.rs b/src/service/sending/data.rs index 96d4a6a91..f75a212c7 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -146,8 +146,7 @@ impl Data { } pub(super) fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) { - self.servername_educount - .insert(server_name.as_bytes(), &last_count.to_be_bytes()); + self.servername_educount.raw_put(server_name, last_count); } pub async fn get_latest_educount(&self, server_name: &ServerName) -> u64 { diff --git a/src/service/sending/dest.rs b/src/service/sending/dest.rs index 9968acd76..234a0b906 100644 --- a/src/service/sending/dest.rs +++ b/src/service/sending/dest.rs @@ -12,7 +12,7 @@ pub enum Destination { #[implement(Destination)] #[must_use] -pub fn get_prefix(&self) -> Vec { +pub(super) fn get_prefix(&self) -> Vec { match self { Self::Normal(server) => { let len = server.as_bytes().len().saturating_add(1); diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index f75f1bcd8..d2865d882 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -8,7 +8,7 @@ use conduit::{ utils::{hash, string::EMPTY}, Error, Result, }; -use database::{Deserialized, Map}; +use database::{Deserialized, Json, Map}; use ruma::{ api::client::{ error::ErrorKind, @@ -217,21 +217,14 @@ pub fn get_uiaa_request( #[implement(Service)] fn update_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>) { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(session.as_bytes()); + let key = (user_id, device_id, session); if let Some(uiaainfo) = uiaainfo { - self.db.userdevicesessionid_uiaainfo.insert( - &userdevicesessionid, - &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), - ); - } else { self.db 
.userdevicesessionid_uiaainfo - .remove(&userdevicesessionid); + .put(key, Json(uiaainfo)); + } else { + self.db.userdevicesessionid_uiaainfo.del(key); } } diff --git a/src/service/updates/mod.rs b/src/service/updates/mod.rs index fca637255..adc85fe60 100644 --- a/src/service/updates/mod.rs +++ b/src/service/updates/mod.rs @@ -121,10 +121,7 @@ impl Service { } #[inline] - pub fn update_check_for_updates_id(&self, id: u64) { - self.db - .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes()); - } + pub fn update_check_for_updates_id(&self, id: u64) { self.db.raw_put(LAST_CHECK_FOR_UPDATES_COUNT, id); } pub async fn last_check_for_updates_id(&self) -> u64 { self.db diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index a99a7df4b..589aee8a1 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -3,18 +3,19 @@ use std::{collections::BTreeMap, mem, mem::size_of, sync::Arc}; use conduit::{ debug_warn, err, utils, utils::{stream::TryIgnore, string::Unquoted, ReadyExt}, - warn, Err, Error, Result, Server, + Err, Error, Result, Server, }; -use database::{Deserialized, Ignore, Interfix, Map}; -use futures::{pin_mut, FutureExt, Stream, StreamExt, TryFutureExt}; +use database::{Deserialized, Ignore, Interfix, Json, Map}; +use futures::{FutureExt, Stream, StreamExt, TryFutureExt}; use ruma::{ api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::{ignored_user_list::IgnoredUserListEvent, AnyToDeviceEvent, GlobalAccountDataEventType, StateEventType}, + events::{ignored_user_list::IgnoredUserListEvent, AnyToDeviceEvent, GlobalAccountDataEventType}, serde::Raw, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, OwnedUserId, UInt, UserId, }; +use serde_json::json; use crate::{account_data, admin, globals, rooms, Dep}; @@ -194,22 +195,16 @@ impl Service { /// Hash and set the user's password to the 
Argon2 hash pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - if let Some(password) = password { - if let Ok(hash) = utils::hash::password(password) { - self.db - .userid_password - .insert(user_id.as_bytes(), hash.as_bytes()); - Ok(()) - } else { - Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Password does not meet the requirements.", - )) - } - } else { - self.db.userid_password.insert(user_id.as_bytes(), b""); - Ok(()) - } + password + .map(utils::hash::password) + .transpose() + .map_err(|e| err!(Request(InvalidParam("Password does not meet the requirements: {e}"))))? + .map_or_else( + || self.db.userid_password.insert(user_id, b""), + |hash| self.db.userid_password.insert(user_id, hash), + ); + + Ok(()) } /// Returns the displayname of a user on this homeserver. @@ -221,11 +216,9 @@ impl Service { /// need to nofify all rooms of this change. pub fn set_displayname(&self, user_id: &UserId, displayname: Option) { if let Some(displayname) = displayname { - self.db - .userid_displayname - .insert(user_id.as_bytes(), displayname.as_bytes()); + self.db.userid_displayname.insert(user_id, displayname); } else { - self.db.userid_displayname.remove(user_id.as_bytes()); + self.db.userid_displayname.remove(user_id); } } @@ -237,11 +230,9 @@ impl Service { /// Sets a new avatar_url or removes it if avatar_url is None. pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) { if let Some(avatar_url) = avatar_url { - self.db - .userid_avatarurl - .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes()); + self.db.userid_avatarurl.insert(user_id, &avatar_url); } else { - self.db.userid_avatarurl.remove(user_id.as_bytes()); + self.db.userid_avatarurl.remove(user_id); } } @@ -253,11 +244,9 @@ impl Service { /// Sets a new avatar_url or removes it if avatar_url is None. 
pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option) { if let Some(blurhash) = blurhash { - self.db - .userid_blurhash - .insert(user_id.as_bytes(), blurhash.as_bytes()); + self.db.userid_blurhash.insert(user_id, blurhash); } else { - self.db.userid_blurhash.remove(user_id.as_bytes()); + self.db.userid_blurhash.remove(user_id); } } @@ -269,41 +258,29 @@ impl Service { // This method should never be called for nonexistent users. We shouldn't assert // though... if !self.exists(user_id).await { - warn!("Called create_device for non-existent user {} in database", user_id); - return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist.")); + return Err!(Request(InvalidParam(error!("Called create_device for non-existent {user_id}")))); } - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); + let key = (user_id, device_id); + let val = Device { + device_id: device_id.into(), + display_name: initial_device_display_name, + last_seen_ip: client_ip, + last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), + }; increment(&self.db.userid_devicelistversion, user_id.as_bytes()); - - self.db.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(&Device { - device_id: device_id.into(), - display_name: initial_device_display_name, - last_seen_ip: client_ip, - last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), - }) - .expect("Device::to_string never fails."), - ); - - self.set_token(user_id, device_id, token).await?; - - Ok(()) + self.db.userdeviceid_metadata.put(key, Json(val)); + self.set_token(user_id, device_id, token).await } /// Removes a device from a user. 
pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); + let userdeviceid = (user_id, device_id); // Remove tokens - if let Ok(old_token) = self.db.userdeviceid_token.get(&userdeviceid).await { - self.db.userdeviceid_token.remove(&userdeviceid); + if let Ok(old_token) = self.db.userdeviceid_token.qry(&userdeviceid).await { + self.db.userdeviceid_token.del(userdeviceid); self.db.token_userdeviceid.remove(&old_token); } @@ -320,7 +297,7 @@ impl Service { increment(&self.db.userid_devicelistversion, user_id.as_bytes()); - self.db.userdeviceid_metadata.remove(&userdeviceid); + self.db.userdeviceid_metadata.del(userdeviceid); } /// Returns an iterator over all device ids of this user. @@ -333,6 +310,11 @@ impl Service { .map(|(_, device_id): (Ignore, &DeviceId)| device_id) } + pub async fn get_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result { + let key = (user_id, device_id); + self.db.userdeviceid_token.qry(&key).await.deserialized() + } + /// Replaces the access token of one device. 
pub async fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { let key = (user_id, device_id); @@ -352,15 +334,8 @@ impl Service { } // Assign token to user device combination - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - self.db - .userdeviceid_token - .insert(&userdeviceid, token.as_bytes()); - self.db - .token_userdeviceid - .insert(token.as_bytes(), &userdeviceid); + self.db.userdeviceid_token.put_raw(key, token); + self.db.token_userdeviceid.raw_put(token, key); Ok(()) } @@ -393,14 +368,12 @@ impl Service { .as_bytes(), ); - self.db.onetimekeyid_onetimekeys.insert( - &key, - &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), - ); - self.db - .userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes()); + .onetimekeyid_onetimekeys + .raw_put(key, Json(one_time_key_value)); + + let count = self.services.globals.next_count().unwrap(); + self.db.userid_lastonetimekeyupdate.raw_put(user_id, count); Ok(()) } @@ -417,9 +390,8 @@ impl Service { pub async fn take_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, ) -> Result<(OwnedDeviceKeyId, Raw)> { - self.db - .userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes()); + let count = self.services.globals.next_count()?.to_be_bytes(); + self.db.userid_lastonetimekeyupdate.insert(user_id, count); let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -488,15 +460,9 @@ impl Service { } pub async fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw) { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.db.keyid_key.insert( - &userdeviceid, - 
&serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), - ); + let key = (user_id, device_id); + self.db.keyid_key.put(key, Json(device_keys)); self.mark_device_key_update(user_id).await; } @@ -611,13 +577,8 @@ impl Service { .ok_or_else(|| err!(Database("signatures in keyid_key for a user is invalid.")))? .insert(signature.0, signature.1.into()); - let mut key = target_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(key_id.as_bytes()); - self.db.keyid_key.insert( - &key, - &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), - ); + let key = (target_id, key_id); + self.db.keyid_key.put(key, Json(cross_signing_key)); self.mark_device_key_update(target_id).await; @@ -640,34 +601,21 @@ impl Service { } pub async fn mark_device_key_update(&self, user_id: &UserId) { - let count = self.services.globals.next_count().unwrap().to_be_bytes(); + let count = self.services.globals.next_count().unwrap(); - let rooms_joined = self.services.state_cache.rooms_joined(user_id); - - pin_mut!(rooms_joined); - while let Some(room_id) = rooms_joined.next().await { + self.services + .state_cache + .rooms_joined(user_id) // Don't send key updates to unencrypted rooms - if self - .services - .state_accessor - .room_state_get(room_id, &StateEventType::RoomEncryption, "") - .await - .is_err() - { - continue; - } - - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(&count); - - self.db.keychangeid_userid.insert(&key, user_id.as_bytes()); - } + .filter(|room_id| self.services.state_accessor.is_encrypted_room(room_id)) + .ready_for_each(|room_id| { + let key = (room_id, count); + self.db.keychangeid_userid.put_raw(key, user_id); + }) + .await; - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(&count); - self.db.keychangeid_userid.insert(&key, user_id.as_bytes()); + let key = (user_id, count); + self.db.keychangeid_userid.put_raw(key, user_id); } pub 
async fn get_device_keys<'a>(&'a self, user_id: &'a UserId, device_id: &DeviceId) -> Result> { @@ -681,12 +629,7 @@ impl Service { where F: Fn(&UserId) -> bool + Send + Sync, { - let key = self - .db - .keyid_key - .get(key_id) - .await - .deserialized::()?; + let key: serde_json::Value = self.db.keyid_key.get(key_id).await.deserialized()?; let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?; let raw_value = serde_json::value::to_raw_value(&cleaned)?; @@ -718,29 +661,29 @@ impl Service { } pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result> { - let key_id = self.db.userid_usersigningkeyid.get(user_id).await?; - - self.db.keyid_key.get(&*key_id).await.deserialized() + self.db + .userid_usersigningkeyid + .get(user_id) + .and_then(|key_id| self.db.keyid_key.get(&*key_id)) + .await + .deserialized() } pub async fn add_to_device_event( &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, content: serde_json::Value, ) { - let mut key = target_user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(target_device_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes()); - - let mut json = serde_json::Map::new(); - json.insert("type".to_owned(), event_type.to_owned().into()); - json.insert("sender".to_owned(), sender.to_string().into()); - json.insert("content".to_owned(), content); - - let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); - - self.db.todeviceid_events.insert(&key, &value); + let count = self.services.globals.next_count().unwrap(); + + let key = (target_user_id, target_device_id, count); + self.db.todeviceid_events.put( + key, + Json(json!({ + "type": event_type, + "sender": sender, + "content": content, + })), + ); } pub fn get_to_device_events<'a>( @@ -783,13 +726,8 @@ impl Service { pub async fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: 
&Device) -> Result<()> { increment(&self.db.userid_devicelistversion, user_id.as_bytes()); - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - self.db.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(device).expect("Device::to_string always works"), - ); + let key = (user_id, device_id); + self.db.userdeviceid_metadata.put(key, Json(device)); Ok(()) } @@ -824,23 +762,15 @@ impl Service { pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> String { let filter_id = utils::random_string(4); - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(filter_id.as_bytes()); - - self.db - .userfilterid_filter - .insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json")); + let key = (user_id, &filter_id); + self.db.userfilterid_filter.put(key, Json(filter)); filter_id } pub async fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result { - self.db - .userfilterid_filter - .qry(&(user_id, filter_id)) - .await - .deserialized() + let key = (user_id, filter_id); + self.db.userfilterid_filter.qry(&key).await.deserialized() } /// Creates an OpenID token, which can be used to prove that a user has @@ -913,17 +843,13 @@ impl Service { /// Sets a new profile key value, removes the key if value is None pub fn set_profile_key(&self, user_id: &UserId, profile_key: &str, profile_key_value: Option) { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(profile_key.as_bytes()); - // TODO: insert to the stable MSC4175 key when it's stable - if let Some(value) = profile_key_value { - let value = serde_json::to_vec(&value).unwrap(); + let key = (user_id, profile_key); - self.db.useridprofilekey_value.insert(&key, &value); + if let Some(value) = profile_key_value { + self.db.useridprofilekey_value.put(key, value); } else { - self.db.useridprofilekey_value.remove(&key); + 
self.db.useridprofilekey_value.del(key); } } @@ -945,17 +871,13 @@ impl Service { /// Sets a new timezone or removes it if timezone is None. pub fn set_timezone(&self, user_id: &UserId, timezone: Option) { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(b"us.cloke.msc4175.tz"); - // TODO: insert to the stable MSC4175 key when it's stable + let key = (user_id, "us.cloke.msc4175.tz"); + if let Some(timezone) = timezone { - self.db - .useridprofilekey_value - .insert(&key, timezone.as_bytes()); + self.db.useridprofilekey_value.put_raw(key, &timezone); } else { - self.db.useridprofilekey_value.remove(&key); + self.db.useridprofilekey_value.del(key); } } } @@ -1012,5 +934,5 @@ where fn increment(db: &Arc, key: &[u8]) { let old = db.get_blocking(key); let new = utils::increment(old.ok().as_deref()); - db.insert(key, &new); + db.insert(key, new); } From 89b5c4ee1c6cfd662a60d704b1ffec736d7a3600 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 11 Oct 2024 20:32:32 +0000 Subject: [PATCH 080/245] add timepoint_from_now to complement timepoint_ago in utils Signed-off-by: Jason Volk --- src/core/utils/mod.rs | 2 +- src/core/utils/time.rs | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index 4dbecf91a..132213412 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -34,7 +34,7 @@ pub use self::{ stream::{IterStream, ReadyExt, Tools as StreamTools, TryReadyExt}, string::{str_from_bytes, string_from_bytes}, sys::available_parallelism, - time::now_millis as millis_since_unix_epoch, + time::{now_millis as millis_since_unix_epoch, timepoint_ago, timepoint_from_now}, }; #[inline] diff --git a/src/core/utils/time.rs b/src/core/utils/time.rs index 04f47ac38..f96a27d00 100644 --- a/src/core/utils/time.rs +++ b/src/core/utils/time.rs @@ -22,6 +22,13 @@ pub fn timepoint_ago(duration: Duration) -> Result { .ok_or_else(|| err!(Arithmetic("Duration {duration:?} is too 
large"))) } +#[inline] +pub fn timepoint_from_now(duration: Duration) -> Result { + SystemTime::now() + .checked_add(duration) + .ok_or_else(|| err!(Arithmetic("Duration {duration:?} is too large"))) +} + #[inline] pub fn parse_duration(duration: &str) -> Result { cyborgtime::parse_duration(duration) From 1a09eb0f0235a1dfe7c51f525f656cffad62b60d Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 13 Oct 2024 00:57:08 +0000 Subject: [PATCH 081/245] use string::EMPTY; minor formatting and misc cleanups Signed-off-by: Jason Volk --- src/admin/debug/commands.rs | 12 ++++++++---- src/api/client/membership.rs | 11 +++++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 350e08c6a..fd8c39f77 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -6,7 +6,7 @@ use std::{ }; use api::client::validate_and_add_event_id; -use conduit::{debug, debug_error, err, info, trace, utils, warn, Error, PduEvent, Result}; +use conduit::{debug, debug_error, err, info, trace, utils, utils::string::EMPTY, warn, Error, PduEvent, Result}; use futures::StreamExt; use ruma::{ api::{client::error::ErrorKind, federation::event::get_room_state}, @@ -57,7 +57,9 @@ pub(super) async fn get_auth_chain(&self, event_id: Box) -> Result Result { - if self.body.len() < 2 || !self.body[0].trim().starts_with("```") || self.body.last().unwrap_or(&"").trim() != "```" + if self.body.len() < 2 + || !self.body[0].trim().starts_with("```") + || self.body.last().unwrap_or(&EMPTY).trim() != "```" { return Ok(RoomMessageEventContent::text_plain( "Expected code block in command body. 
Add --help for details.", @@ -134,7 +136,9 @@ pub(super) async fn get_remote_pdu_list( )); } - if self.body.len() < 2 || !self.body[0].trim().starts_with("```") || self.body.last().unwrap_or(&"").trim() != "```" + if self.body.len() < 2 + || !self.body[0].trim().starts_with("```") + || self.body.last().unwrap_or(&EMPTY).trim() != "```" { return Ok(RoomMessageEventContent::text_plain( "Expected code block in command body. Add --help for details.", @@ -843,7 +847,7 @@ pub(super) async fn database_stats( &self, property: Option, map: Option, ) -> Result { let property = property.unwrap_or_else(|| "rocksdb.stats".to_owned()); - let map_name = map.as_ref().map_or(utils::string::EMPTY, String::as_str); + let map_name = map.as_ref().map_or(EMPTY, String::as_str); let mut out = String::new(); for (name, map) in self.services.db.iter_maps() { diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 060355921..a7a5b1668 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -658,13 +658,16 @@ pub async fn join_room_by_id_helper( }); } - if services + let server_in_room = services .rooms .state_cache .server_in_room(services.globals.server_name(), room_id) - .await || servers.is_empty() - || (servers.len() == 1 && services.globals.server_is_ours(&servers[0])) - { + .await; + + let local_join = + server_in_room || servers.is_empty() || (servers.len() == 1 && services.globals.server_is_ours(&servers[0])); + + if local_join { join_room_by_id_helper_local(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) .boxed() .await From d82ea331cfdcd51c2c746618deb26e1fd220abc0 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 14 Oct 2024 05:16:18 +0000 Subject: [PATCH 082/245] add random shuffle util Signed-off-by: Jason Volk --- src/core/utils/mod.rs | 2 +- src/core/utils/rand.rs | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index 
132213412..3adecc6c1 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -30,7 +30,7 @@ pub use self::{ json::{deserialize_from_str, to_canonical_object}, math::clamp, mutex_map::{Guard as MutexMapGuard, MutexMap}, - rand::string as random_string, + rand::{shuffle, string as random_string}, stream::{IterStream, ReadyExt, Tools as StreamTools, TryReadyExt}, string::{str_from_bytes, string_from_bytes}, sys::available_parallelism, diff --git a/src/core/utils/rand.rs b/src/core/utils/rand.rs index d717c4bdc..9e6fc7a81 100644 --- a/src/core/utils/rand.rs +++ b/src/core/utils/rand.rs @@ -4,7 +4,12 @@ use std::{ }; use arrayvec::ArrayString; -use rand::{thread_rng, Rng}; +use rand::{seq::SliceRandom, thread_rng, Rng}; + +pub fn shuffle(vec: &mut [T]) { + let mut rng = thread_rng(); + vec.shuffle(&mut rng); +} pub fn string(length: usize) -> String { thread_rng() From c0939c3e9a9d7c193e8092333cd9289499540463 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 11 Oct 2024 18:57:59 +0000 Subject: [PATCH 083/245] Refactor server_keys service/interface and related callsites Signed-off-by: Jason Volk Signed-off-by: strawberry --- Cargo.lock | 26 +- Cargo.toml | 2 +- src/admin/debug/commands.rs | 173 +++---- src/admin/debug/mod.rs | 17 +- src/admin/query/globals.rs | 13 +- src/api/client/membership.rs | 252 +++------- src/api/client/mod.rs | 2 +- src/api/router/args.rs | 2 +- src/api/router/auth.rs | 214 ++++---- src/api/server/invite.rs | 11 +- src/api/server/key.rs | 70 +-- src/api/server/send.rs | 22 +- src/api/server/send_join.rs | 24 +- src/api/server/send_leave.rs | 13 +- src/core/config/mod.rs | 28 -- src/core/error/mod.rs | 4 + src/core/pdu/mod.rs | 25 +- src/service/globals/data.rs | 118 +---- src/service/globals/mod.rs | 50 +- src/service/rooms/event_handler/mod.rs | 105 ++-- src/service/rooms/timeline/mod.rs | 45 +- src/service/sending/mod.rs | 4 +- src/service/sending/send.rs | 24 +- src/service/server_keys/acquire.rs | 175 +++++++ 
src/service/server_keys/get.rs | 86 ++++ src/service/server_keys/keypair.rs | 64 +++ src/service/server_keys/mod.rs | 648 +++++-------------------- src/service/server_keys/request.rs | 97 ++++ src/service/server_keys/sign.rs | 18 + src/service/server_keys/verify.rs | 33 ++ 30 files changed, 1006 insertions(+), 1359 deletions(-) create mode 100644 src/service/server_keys/acquire.rs create mode 100644 src/service/server_keys/get.rs create mode 100644 src/service/server_keys/keypair.rs create mode 100644 src/service/server_keys/request.rs create mode 100644 src/service/server_keys/sign.rs create mode 100644 src/service/server_keys/verify.rs diff --git a/Cargo.lock b/Cargo.lock index db1394ce6..4ac7cc35f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2976,7 +2976,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" +source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" dependencies = [ "assign", "js_int", @@ -2998,7 +2998,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" +source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" dependencies = [ "js_int", "ruma-common", @@ -3010,7 +3010,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" +source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" dependencies = [ "as_variant", "assign", @@ -3033,7 
+3033,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" +source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" dependencies = [ "as_variant", "base64 0.22.1", @@ -3063,7 +3063,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" +source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" dependencies = [ "as_variant", "indexmap 2.6.0", @@ -3087,7 +3087,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" +source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" dependencies = [ "bytes", "http", @@ -3105,7 +3105,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" +source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" dependencies = [ "js_int", "thiserror", @@ -3114,7 +3114,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" +source = 
"git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" dependencies = [ "js_int", "ruma-common", @@ -3124,7 +3124,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" +source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" dependencies = [ "cfg-if", "once_cell", @@ -3140,7 +3140,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" +source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" dependencies = [ "js_int", "ruma-common", @@ -3152,7 +3152,7 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.3.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" +source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" dependencies = [ "headers", "http", @@ -3165,7 +3165,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" +source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -3181,7 +3181,7 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.11.0" -source = 
"git+https://github.com/girlbossceo/ruwuma?rev=3109496a1f91357c89cbb57cf86f179e2cb013e7#3109496a1f91357c89cbb57cf86f179e2cb013e7" +source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" dependencies = [ "futures-util", "itertools 0.13.0", diff --git a/Cargo.toml b/Cargo.toml index 0a98befd8..966c28183 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -315,7 +315,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = "https://github.com/girlbossceo/ruwuma" #branch = "conduwuit-changes" -rev = "3109496a1f91357c89cbb57cf86f179e2cb013e7" +rev = "d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" features = [ "compat", "rand", diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index fd8c39f77..7fe8addfa 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -1,19 +1,17 @@ use std::{ - collections::{BTreeMap, HashMap}, + collections::HashMap, fmt::Write, sync::Arc, time::{Instant, SystemTime}, }; -use api::client::validate_and_add_event_id; -use conduit::{debug, debug_error, err, info, trace, utils, utils::string::EMPTY, warn, Error, PduEvent, Result}; +use conduit::{debug_error, err, info, trace, utils, utils::string::EMPTY, warn, Error, PduEvent, Result}; use futures::StreamExt; use ruma::{ api::{client::error::ErrorKind, federation::event::get_room_state}, events::room::message::RoomMessageEventContent, CanonicalJsonObject, EventId, OwnedRoomOrAliasId, RoomId, RoomVersionId, ServerName, }; -use tokio::sync::RwLock; use tracing_subscriber::EnvFilter; use crate::admin_command; @@ -219,7 +217,7 @@ pub(super) async fn get_remote_pdu( })?; trace!("Attempting to parse PDU: {:?}", &response.pdu); - let parsed_pdu = { + let _parsed_pdu = { let parsed_result = self .services .rooms @@ -241,22 +239,11 @@ pub(super) async fn get_remote_pdu( vec![(event_id, value, room_id)] }; - let pub_key_map = RwLock::new(BTreeMap::new()); - - debug!("Attempting to fetch 
homeserver signing keys for {server}"); - self.services - .server_keys - .fetch_required_signing_keys(parsed_pdu.iter().map(|(_event_id, event, _room_id)| event), &pub_key_map) - .await - .unwrap_or_else(|e| { - warn!("Could not fetch all signatures for PDUs from {server}: {e:?}"); - }); - info!("Attempting to handle event ID {event_id} as backfilled PDU"); self.services .rooms .timeline - .backfill_pdu(&server, response.pdu, &pub_key_map) + .backfill_pdu(&server, response.pdu) .await?; let json_text = serde_json::to_string_pretty(&json).expect("canonical json is valid json"); @@ -433,12 +420,10 @@ pub(super) async fn sign_json(&self) -> Result { let string = self.body[1..self.body.len().checked_sub(1).unwrap()].join("\n"); match serde_json::from_str(&string) { Ok(mut value) => { - ruma::signatures::sign_json( - self.services.globals.server_name().as_str(), - self.services.globals.keypair(), - &mut value, - ) - .expect("our request json is what ruma expects"); + self.services + .server_keys + .sign_json(&mut value) + .expect("our request json is what ruma expects"); let json_text = serde_json::to_string_pretty(&value).expect("canonical json is valid json"); Ok(RoomMessageEventContent::text_plain(json_text)) }, @@ -456,27 +441,31 @@ pub(super) async fn verify_json(&self) -> Result { } let string = self.body[1..self.body.len().checked_sub(1).unwrap()].join("\n"); - match serde_json::from_str(&string) { - Ok(value) => { - let pub_key_map = RwLock::new(BTreeMap::new()); - - self.services - .server_keys - .fetch_required_signing_keys([&value], &pub_key_map) - .await?; - - let pub_key_map = pub_key_map.read().await; - match ruma::signatures::verify_json(&pub_key_map, &value) { - Ok(()) => Ok(RoomMessageEventContent::text_plain("Signature correct")), - Err(e) => Ok(RoomMessageEventContent::text_plain(format!( - "Signature verification failed: {e}" - ))), - } + match serde_json::from_str::(&string) { + Ok(value) => match self.services.server_keys.verify_json(&value, 
None).await { + Ok(()) => Ok(RoomMessageEventContent::text_plain("Signature correct")), + Err(e) => Ok(RoomMessageEventContent::text_plain(format!( + "Signature verification failed: {e}" + ))), }, Err(e) => Ok(RoomMessageEventContent::text_plain(format!("Invalid json: {e}"))), } } +#[admin_command] +pub(super) async fn verify_pdu(&self, event_id: Box) -> Result { + let mut event = self.services.rooms.timeline.get_pdu_json(&event_id).await?; + + event.remove("event_id"); + let msg = match self.services.server_keys.verify_event(&event, None).await { + Ok(ruma::signatures::Verified::Signatures) => "signatures OK, but content hash failed (redaction).", + Ok(ruma::signatures::Verified::All) => "signatures and hashes OK.", + Err(e) => return Err(e), + }; + + Ok(RoomMessageEventContent::notice_plain(msg)) +} + #[admin_command] #[tracing::instrument(skip(self))] pub(super) async fn first_pdu_in_room(&self, room_id: Box) -> Result { @@ -557,7 +546,6 @@ pub(super) async fn force_set_room_state_from_server( let room_version = self.services.rooms.state.get_room_version(&room_id).await?; let mut state: HashMap> = HashMap::new(); - let pub_key_map = RwLock::new(BTreeMap::new()); let remote_state_response = self .services @@ -571,38 +559,28 @@ pub(super) async fn force_set_room_state_from_server( ) .await?; - let mut events = Vec::with_capacity(remote_state_response.pdus.len()); - for pdu in remote_state_response.pdus.clone() { - events.push( - match self - .services - .rooms - .event_handler - .parse_incoming_pdu(&pdu) - .await - { - Ok(t) => t, - Err(e) => { - warn!("Could not parse PDU, ignoring: {e}"); - continue; - }, + match self + .services + .rooms + .event_handler + .parse_incoming_pdu(&pdu) + .await + { + Ok(t) => t, + Err(e) => { + warn!("Could not parse PDU, ignoring: {e}"); + continue; }, - ); + }; } - info!("Fetching required signing keys for all the state events we got"); - self.services - .server_keys - .fetch_required_signing_keys(events.iter().map(|(_event_id, 
event, _room_id)| event), &pub_key_map) - .await?; - info!("Going through room_state response PDUs"); - for result in remote_state_response - .pdus - .iter() - .map(|pdu| validate_and_add_event_id(self.services, pdu, &room_version, &pub_key_map)) - { + for result in remote_state_response.pdus.iter().map(|pdu| { + self.services + .server_keys + .validate_and_add_event_id(pdu, &room_version) + }) { let Ok((event_id, value)) = result.await else { continue; }; @@ -630,11 +608,11 @@ pub(super) async fn force_set_room_state_from_server( } info!("Going through auth_chain response"); - for result in remote_state_response - .auth_chain - .iter() - .map(|pdu| validate_and_add_event_id(self.services, pdu, &room_version, &pub_key_map)) - { + for result in remote_state_response.auth_chain.iter().map(|pdu| { + self.services + .server_keys + .validate_and_add_event_id(pdu, &room_version) + }) { let Ok((event_id, value)) = result.await else { continue; }; @@ -686,10 +664,33 @@ pub(super) async fn force_set_room_state_from_server( #[admin_command] pub(super) async fn get_signing_keys( - &self, server_name: Option>, _cached: bool, + &self, server_name: Option>, notary: Option>, query: bool, ) -> Result { let server_name = server_name.unwrap_or_else(|| self.services.server.config.server_name.clone().into()); - let signing_keys = self.services.globals.signing_keys_for(&server_name).await?; + + if let Some(notary) = notary { + let signing_keys = self + .services + .server_keys + .notary_request(¬ary, &server_name) + .await?; + + return Ok(RoomMessageEventContent::notice_markdown(format!( + "```rs\n{signing_keys:#?}\n```" + ))); + } + + let signing_keys = if query { + self.services + .server_keys + .server_request(&server_name) + .await? + } else { + self.services + .server_keys + .signing_keys_for(&server_name) + .await? 
+ }; Ok(RoomMessageEventContent::notice_markdown(format!( "```rs\n{signing_keys:#?}\n```" @@ -697,34 +698,20 @@ pub(super) async fn get_signing_keys( } #[admin_command] -#[allow(dead_code)] -pub(super) async fn get_verify_keys( - &self, server_name: Option>, cached: bool, -) -> Result { +pub(super) async fn get_verify_keys(&self, server_name: Option>) -> Result { let server_name = server_name.unwrap_or_else(|| self.services.server.config.server_name.clone().into()); - let mut out = String::new(); - - if cached { - writeln!(out, "| Key ID | VerifyKey |")?; - writeln!(out, "| --- | --- |")?; - for (key_id, verify_key) in self.services.globals.verify_keys_for(&server_name).await? { - writeln!(out, "| {key_id} | {verify_key:?} |")?; - } - - return Ok(RoomMessageEventContent::notice_markdown(out)); - } - let signature_ids: Vec = Vec::new(); let keys = self .services .server_keys - .fetch_signing_keys_for_server(&server_name, signature_ids) - .await?; + .verify_keys_for(&server_name) + .await; + let mut out = String::new(); writeln!(out, "| Key ID | Public Key |")?; writeln!(out, "| --- | --- |")?; for (key_id, key) in keys { - writeln!(out, "| {key_id} | {key} |")?; + writeln!(out, "| {key_id} | {key:?} |")?; } Ok(RoomMessageEventContent::notice_markdown(out)) diff --git a/src/admin/debug/mod.rs b/src/admin/debug/mod.rs index 20ddbf2f6..b74e9c36c 100644 --- a/src/admin/debug/mod.rs +++ b/src/admin/debug/mod.rs @@ -80,8 +80,16 @@ pub(super) enum DebugCommand { GetSigningKeys { server_name: Option>, + #[arg(long)] + notary: Option>, + #[arg(short, long)] - cached: bool, + query: bool, + }, + + /// - Get and display signing keys from local cache or remote server. + GetVerifyKeys { + server_name: Option>, }, /// - Sends a federation request to the remote server's @@ -119,6 +127,13 @@ pub(super) enum DebugCommand { /// the command. VerifyJson, + /// - Verify PDU + /// + /// This re-verifies a PDU existing in the database found by ID. 
+ VerifyPdu { + event_id: Box, + }, + /// - Prints the very first PDU in the specified room (typically /// m.room.create) FirstPduInRoom { diff --git a/src/admin/query/globals.rs b/src/admin/query/globals.rs index 150a213cd..837d34e6e 100644 --- a/src/admin/query/globals.rs +++ b/src/admin/query/globals.rs @@ -13,8 +13,6 @@ pub(crate) enum GlobalsCommand { LastCheckForUpdatesId, - LoadKeypair, - /// - This returns an empty `Ok(BTreeMap<..>)` when there are no keys found /// for the server. SigningKeysFor { @@ -54,20 +52,11 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) - "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - GlobalsCommand::LoadKeypair => { - let timer = tokio::time::Instant::now(); - let results = services.globals.db.load_keypair(); - let query_time = timer.elapsed(); - - Ok(RoomMessageEventContent::notice_markdown(format!( - "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" - ))) - }, GlobalsCommand::SigningKeysFor { origin, } => { let timer = tokio::time::Instant::now(); - let results = services.globals.db.verify_keys_for(&origin).await; + let results = services.server_keys.verify_keys_for(&origin).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index a7a5b1668..2fa34ff7b 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -1,17 +1,16 @@ use std::{ - collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, net::IpAddr, sync::Arc, - time::Instant, }; use axum::extract::State; use axum_client_ip::InsecureClientIp; use conduit::{ - debug, debug_error, debug_warn, err, error, info, + debug, debug_info, debug_warn, err, error, info, pdu, pdu::{gen_event_id_canonical_json, PduBuilder}, trace, utils, - utils::{math::continue_exponential_backoff_secs, IterStream, ReadyExt}, + 
utils::{IterStream, ReadyExt}, warn, Err, Error, PduEvent, Result, }; use futures::{FutureExt, StreamExt}; @@ -36,13 +35,10 @@ use ruma::{ }, StateEventType, }, - serde::Base64, - state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedServerName, - OwnedUserId, RoomId, RoomVersionId, ServerName, UserId, + state_res, CanonicalJsonObject, CanonicalJsonValue, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, + RoomVersionId, ServerName, UserId, }; -use serde_json::value::RawValue as RawJsonValue; use service::{appservice::RegistrationInfo, rooms::state::RoomMutexGuard, Services}; -use tokio::sync::RwLock; use crate::{client::full_user_deactivate, Ruma}; @@ -670,20 +666,22 @@ pub async fn join_room_by_id_helper( if local_join { join_room_by_id_helper_local(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) .boxed() - .await + .await?; } else { // Ask a remote server if we are not participating in this room join_room_by_id_helper_remote(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) .boxed() - .await + .await?; } + + Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) } #[tracing::instrument(skip_all, fields(%sender_user, %room_id), name = "join_remote")] async fn join_room_by_id_helper_remote( services: &Services, sender_user: &UserId, room_id: &RoomId, reason: Option, servers: &[OwnedServerName], _third_party_signed: Option<&ThirdPartySigned>, state_lock: RoomMutexGuard, -) -> Result { +) -> Result { info!("Joining {room_id} over federation."); let (make_join_response, remote_server) = make_join_request(services, sender_user, room_id, servers).await?; @@ -751,43 +749,33 @@ async fn join_room_by_id_helper_remote( // In order to create a compatible ref hash (EventID) the `hashes` field needs // to be present - ruma::signatures::hash_and_sign_event( - services.globals.server_name().as_str(), - services.globals.keypair(), - &mut join_event_stub, - 
&room_version_id, - ) - .expect("event is valid, we just created it"); + services + .server_keys + .hash_and_sign_event(&mut join_event_stub, &room_version_id)?; // Generate event id - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&join_event_stub, &room_version_id) - .expect("ruma can calculate reference hashes") - ); - let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids"); + let event_id = pdu::gen_event_id(&join_event_stub, &room_version_id)?; // Add event_id back - join_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); + join_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into())); // It has enough fields to be called a proper event now let mut join_event = join_event_stub; info!("Asking {remote_server} for send_join in room {room_id}"); + let send_join_request = federation::membership::create_join_event::v2::Request { + room_id: room_id.to_owned(), + event_id: event_id.clone(), + omit_members: false, + pdu: services + .sending + .convert_to_outgoing_federation_event(join_event.clone()) + .await, + }; + let send_join_response = services .sending - .send_federation_request( - &remote_server, - federation::membership::create_join_event::v2::Request { - room_id: room_id.to_owned(), - event_id: event_id.to_owned(), - omit_members: false, - pdu: services - .sending - .convert_to_outgoing_federation_event(join_event.clone()) - .await, - }, - ) + .send_federation_request(&remote_server, send_join_request) .await?; info!("send_join finished"); @@ -805,7 +793,7 @@ async fn join_room_by_id_helper_remote( // validate and send signatures _ => { if let Some(signed_raw) = &send_join_response.room_state.event { - info!( + debug_info!( "There is a signed event. This room is probably using restricted joins. 
Adding signature to \ our event" ); @@ -862,25 +850,25 @@ async fn join_room_by_id_helper_remote( .await; info!("Parsing join event"); - let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) + let parsed_join_pdu = PduEvent::from_id_val(&event_id, join_event.clone()) .map_err(|e| err!(BadServerResponse("Invalid join event PDU: {e:?}")))?; - let mut state = HashMap::new(); - let pub_key_map = RwLock::new(BTreeMap::new()); - - info!("Fetching join signing keys"); + info!("Acquiring server signing keys for response events"); + let resp_events = &send_join_response.room_state; + let resp_state = &resp_events.state; + let resp_auth = &resp_events.auth_chain; services .server_keys - .fetch_join_signing_keys(&send_join_response, &room_version_id, &pub_key_map) - .await?; + .acquire_events_pubkeys(resp_auth.iter().chain(resp_state.iter())) + .await; info!("Going through send_join response room_state"); - for result in send_join_response - .room_state - .state - .iter() - .map(|pdu| validate_and_add_event_id(services, pdu, &room_version_id, &pub_key_map)) - { + let mut state = HashMap::new(); + for result in send_join_response.room_state.state.iter().map(|pdu| { + services + .server_keys + .validate_and_add_event_id(pdu, &room_version_id) + }) { let Ok((event_id, value)) = result.await else { continue; }; @@ -902,12 +890,11 @@ async fn join_room_by_id_helper_remote( } info!("Going through send_join response auth_chain"); - for result in send_join_response - .room_state - .auth_chain - .iter() - .map(|pdu| validate_and_add_event_id(services, pdu, &room_version_id, &pub_key_map)) - { + for result in send_join_response.room_state.auth_chain.iter().map(|pdu| { + services + .server_keys + .validate_and_add_event_id(pdu, &room_version_id) + }) { let Ok((event_id, value)) = result.await else { continue; }; @@ -937,29 +924,22 @@ async fn join_room_by_id_helper_remote( return Err!(Request(Forbidden("Auth check failed"))); } - info!("Saving state from send_join"); 
+ info!("Compressing state from send_join"); + let compressed = state + .iter() + .stream() + .then(|(&k, id)| services.rooms.state_compressor.compress_state_event(k, id)) + .collect() + .await; + + debug!("Saving compressed state"); let (statehash_before_join, new, removed) = services .rooms .state_compressor - .save_state( - room_id, - Arc::new( - state - .into_iter() - .stream() - .then(|(k, id)| async move { - services - .rooms - .state_compressor - .compress_state_event(k, &id) - .await - }) - .collect() - .await, - ), - ) + .save_state(room_id, Arc::new(compressed)) .await?; + debug!("Forcing state for new room"); services .rooms .state @@ -1002,14 +982,14 @@ async fn join_room_by_id_helper_remote( .state .set_room_state(room_id, statehash_after_join, &state_lock); - Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) + Ok(()) } #[tracing::instrument(skip_all, fields(%sender_user, %room_id), name = "join_local")] async fn join_room_by_id_helper_local( services: &Services, sender_user: &UserId, room_id: &RoomId, reason: Option, servers: &[OwnedServerName], _third_party_signed: Option<&ThirdPartySigned>, state_lock: RoomMutexGuard, -) -> Result { +) -> Result { debug!("We can join locally"); let join_rules_event_content = services @@ -1089,7 +1069,7 @@ async fn join_room_by_id_helper_local( ) .await { - Ok(_event_id) => return Ok(join_room_by_id::v3::Response::new(room_id.to_owned())), + Ok(_) => return Ok(()), Err(e) => e, }; @@ -1159,24 +1139,15 @@ async fn join_room_by_id_helper_local( // In order to create a compatible ref hash (EventID) the `hashes` field needs // to be present - ruma::signatures::hash_and_sign_event( - services.globals.server_name().as_str(), - services.globals.keypair(), - &mut join_event_stub, - &room_version_id, - ) - .expect("event is valid, we just created it"); + services + .server_keys + .hash_and_sign_event(&mut join_event_stub, &room_version_id)?; // Generate event id - let event_id = format!( - "${}", - 
ruma::signatures::reference_hash(&join_event_stub, &room_version_id) - .expect("ruma can calculate reference hashes") - ); - let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids"); + let event_id = pdu::gen_event_id(&join_event_stub, &room_version_id)?; // Add event_id back - join_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); + join_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into())); // It has enough fields to be called a proper event now let join_event = join_event_stub; @@ -1187,7 +1158,7 @@ async fn join_room_by_id_helper_local( &remote_server, federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), - event_id: event_id.to_owned(), + event_id: event_id.clone(), omit_members: false, pdu: services .sending @@ -1214,15 +1185,10 @@ async fn join_room_by_id_helper_local( } drop(state_lock); - let pub_key_map = RwLock::new(BTreeMap::new()); - services - .server_keys - .fetch_required_signing_keys([&signed_value], &pub_key_map) - .await?; services .rooms .event_handler - .handle_incoming_pdu(&remote_server, room_id, &signed_event_id, signed_value, true, &pub_key_map) + .handle_incoming_pdu(&remote_server, room_id, &signed_event_id, signed_value, true) .await?; } else { return Err(error); @@ -1231,7 +1197,7 @@ async fn join_room_by_id_helper_local( return Err(error); } - Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) + Ok(()) } async fn make_join_request( @@ -1301,62 +1267,6 @@ async fn make_join_request( make_join_response_and_server } -pub async fn validate_and_add_event_id( - services: &Services, pdu: &RawJsonValue, room_version: &RoomVersionId, - pub_key_map: &RwLock>>, -) -> Result<(OwnedEventId, CanonicalJsonObject)> { - let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()) - .map_err(|e| err!(BadServerResponse(debug_error!("Invalid PDU in server response: 
{e:?}"))))?; - let event_id = EventId::parse(format!( - "${}", - ruma::signatures::reference_hash(&value, room_version).expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); - - let back_off = |id| async { - match services - .globals - .bad_event_ratelimiter - .write() - .expect("locked") - .entry(id) - { - Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - }, - Entry::Occupied(mut e) => { - *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)); - }, - } - }; - - if let Some((time, tries)) = services - .globals - .bad_event_ratelimiter - .read() - .expect("locked") - .get(&event_id) - { - // Exponential backoff - const MIN: u64 = 60 * 5; - const MAX: u64 = 60 * 60 * 24; - if continue_exponential_backoff_secs(MIN, MAX, time.elapsed(), *tries) { - return Err!(BadServerResponse("bad event {event_id:?}, still backing off")); - } - } - - if let Err(e) = ruma::signatures::verify_event(&*pub_key_map.read().await, &value, room_version) { - debug_error!("Event {event_id} failed verification {pdu:#?}"); - let e = Err!(BadServerResponse(debug_error!("Event {event_id} failed verification: {e:?}"))); - back_off(event_id).await; - return e; - } - - value.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); - - Ok((event_id, value)) -} - pub(crate) async fn invite_helper( services: &Services, sender_user: &UserId, user_id: &UserId, room_id: &RoomId, reason: Option, is_direct: bool, @@ -1423,8 +1333,6 @@ pub(crate) async fn invite_helper( ) .await?; - let pub_key_map = RwLock::new(BTreeMap::new()); - // We do not add the event_id field to the pdu here because of signature and // hashes checks let Ok((event_id, value)) = gen_event_id_canonical_json(&response.event, &room_version_id) else { @@ -1452,15 +1360,10 @@ pub(crate) async fn invite_helper( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - services - .server_keys - 
.fetch_required_signing_keys([&value], &pub_key_map) - .await?; - let pdu_id: Vec = services .rooms .event_handler - .handle_incoming_pdu(&origin, room_id, &event_id, value, true, &pub_key_map) + .handle_incoming_pdu(&origin, room_id, &event_id, value, true) .await? .ok_or(Error::BadRequest( ErrorKind::InvalidParam, @@ -1714,24 +1617,15 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room // In order to create a compatible ref hash (EventID) the `hashes` field needs // to be present - ruma::signatures::hash_and_sign_event( - services.globals.server_name().as_str(), - services.globals.keypair(), - &mut leave_event_stub, - &room_version_id, - ) - .expect("event is valid, we just created it"); + services + .server_keys + .hash_and_sign_event(&mut leave_event_stub, &room_version_id)?; // Generate event id - let event_id = EventId::parse(format!( - "${}", - ruma::signatures::reference_hash(&leave_event_stub, &room_version_id) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); + let event_id = pdu::gen_event_id(&leave_event_stub, &room_version_id)?; // Add event_id back - leave_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); + leave_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into())); // It has enough fields to be called a proper event now let leave_event = leave_event_stub; diff --git a/src/api/client/mod.rs b/src/api/client/mod.rs index 4b7b64b91..2928be87b 100644 --- a/src/api/client/mod.rs +++ b/src/api/client/mod.rs @@ -52,7 +52,7 @@ pub(super) use keys::*; pub(super) use media::*; pub(super) use media_legacy::*; pub(super) use membership::*; -pub use membership::{join_room_by_id_helper, leave_all_rooms, leave_room, validate_and_add_event_id}; +pub use membership::{join_room_by_id_helper, leave_all_rooms, leave_room}; pub(super) use message::*; pub(super) use openid::*; pub(super) 
use presence::*; diff --git a/src/api/router/args.rs b/src/api/router/args.rs index 7381a55f5..746e1cfc6 100644 --- a/src/api/router/args.rs +++ b/src/api/router/args.rs @@ -48,7 +48,7 @@ where async fn from_request(request: hyper::Request, services: &State) -> Result { let mut request = request::from(services, request).await?; let mut json_body = serde_json::from_slice::(&request.body).ok(); - let auth = auth::auth(services, &mut request, &json_body, &T::METADATA).await?; + let auth = auth::auth(services, &mut request, json_body.as_ref(), &T::METADATA).await?; Ok(Self { body: make_body::(services, &mut request, &mut json_body, &auth)?, origin: auth.origin, diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 8d76b4be8..6b90c5ff9 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -1,19 +1,20 @@ -use std::collections::BTreeMap; - use axum::RequestPartsExt; use axum_extra::{ headers::{authorization::Bearer, Authorization}, typed_header::TypedHeaderRejectionReason, TypedHeader, }; -use conduit::{debug_info, warn, Err, Error, Result}; +use conduit::{debug_error, err, warn, Err, Error, Result}; use http::uri::PathAndQuery; use ruma::{ api::{client::error::ErrorKind, AuthScheme, Metadata}, server_util::authorization::XMatrix, - CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, + CanonicalJsonObject, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, +}; +use service::{ + server_keys::{PubKeyMap, PubKeys}, + Services, }; -use service::Services; use super::request::Request; use crate::service::appservice::RegistrationInfo; @@ -33,7 +34,7 @@ pub(super) struct Auth { } pub(super) async fn auth( - services: &Services, request: &mut Request, json_body: &Option, metadata: &Metadata, + services: &Services, request: &mut Request, json_body: Option<&CanonicalJsonValue>, metadata: &Metadata, ) -> Result { let bearer: Option>> = request.parts.extract().await?; let token = match &bearer { @@ -151,27 +152,24 
@@ pub(super) async fn auth( } async fn auth_appservice(services: &Services, request: &Request, info: Box) -> Result { - let user_id = request + let user_id_default = + || UserId::parse_with_server_name(info.registration.sender_localpart.as_str(), services.globals.server_name()); + + let Ok(user_id) = request .query .user_id .clone() - .map_or_else( - || { - UserId::parse_with_server_name( - info.registration.sender_localpart.as_str(), - services.globals.server_name(), - ) - }, - UserId::parse, - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; + .map_or_else(user_id_default, UserId::parse) + else { + return Err!(Request(InvalidUsername("Username is invalid."))); + }; if !info.is_user_match(&user_id) { - return Err(Error::BadRequest(ErrorKind::Exclusive, "User is not in namespace.")); + return Err!(Request(Exclusive("User is not in namespace."))); } if !services.users.exists(&user_id).await { - return Err(Error::BadRequest(ErrorKind::forbidden(), "User does not exist.")); + return Err!(Request(Forbidden("User does not exist."))); } Ok(Auth { @@ -182,118 +180,104 @@ async fn auth_appservice(services: &Services, request: &Request, info: Box, -) -> Result { - if !services.server.config.allow_federation { - return Err!(Config("allow_federation", "Federation is disabled.")); - } +async fn auth_server(services: &Services, request: &mut Request, body: Option<&CanonicalJsonValue>) -> Result { + type Member = (String, CanonicalJsonValue); + type Object = CanonicalJsonObject; + type Value = CanonicalJsonValue; - let TypedHeader(Authorization(x_matrix)) = request + let x_matrix = parse_x_matrix(request).await?; + auth_server_checks(services, &x_matrix)?; + + let destination = services.globals.server_name(); + let origin = &x_matrix.origin; + #[allow(clippy::or_fun_call)] + let signature_uri = request .parts - .extract::>>() + .uri + .path_and_query() + .unwrap_or(&PathAndQuery::from_static("/")) + .to_string(); + + let signature: 
[Member; 1] = [(x_matrix.key.to_string(), Value::String(x_matrix.sig.to_string()))]; + let signatures: [Member; 1] = [(origin.to_string(), Value::Object(signature.into()))]; + let authorization: [Member; 5] = [ + ("destination".into(), Value::String(destination.into())), + ("method".into(), Value::String(request.parts.method.to_string())), + ("origin".into(), Value::String(origin.to_string())), + ("signatures".into(), Value::Object(signatures.into())), + ("uri".into(), Value::String(signature_uri)), + ]; + + let mut authorization: Object = authorization.into(); + if let Some(body) = body { + authorization.insert("content".to_owned(), body.clone()); + } + + let key = services + .server_keys + .get_verify_key(origin, &x_matrix.key) .await - .map_err(|e| { - warn!("Missing or invalid Authorization header: {e}"); + .map_err(|e| err!(Request(Forbidden(warn!("Failed to fetch signing keys: {e}")))))?; - let msg = match e.reason() { - TypedHeaderRejectionReason::Missing => "Missing Authorization header.", - TypedHeaderRejectionReason::Error(_) => "Invalid X-Matrix signatures.", - _ => "Unknown header-related error", - }; + let keys: PubKeys = [(x_matrix.key.to_string(), key.key)].into(); + let keys: PubKeyMap = [(origin.to_string(), keys)].into(); + if let Err(e) = ruma::signatures::verify_json(&keys, authorization) { + debug_error!("Failed to verify federation request from {origin}: {e}"); + if request.parts.uri.to_string().contains('@') { + warn!( + "Request uri contained '@' character. 
Make sure your reverse proxy gives Conduit the raw uri (apache: \ + use nocanon)" + ); + } - Error::BadRequest(ErrorKind::forbidden(), msg) - })?; + return Err!(Request(Forbidden("Failed to verify X-Matrix signatures."))); + } - let origin = &x_matrix.origin; + Ok(Auth { + origin: origin.to_owned().into(), + sender_user: None, + sender_device: None, + appservice_info: None, + }) +} + +fn auth_server_checks(services: &Services, x_matrix: &XMatrix) -> Result<()> { + if !services.server.config.allow_federation { + return Err!(Config("allow_federation", "Federation is disabled.")); + } + + let destination = services.globals.server_name(); + if x_matrix.destination.as_deref() != Some(destination) { + return Err!(Request(Forbidden("Invalid destination."))); + } + let origin = &x_matrix.origin; if services .server .config .forbidden_remote_server_names .contains(origin) { - debug_info!("Refusing to accept inbound federation request to {origin}"); - return Err!(Request(Forbidden("Federation with this homeserver is not allowed."))); + return Err!(Request(Forbidden(debug_warn!("Federation requests from {origin} denied.")))); } - let signatures = - BTreeMap::from_iter([(x_matrix.key.clone(), CanonicalJsonValue::String(x_matrix.sig.to_string()))]); - let signatures = BTreeMap::from_iter([( - origin.as_str().to_owned(), - CanonicalJsonValue::Object( - signatures - .into_iter() - .map(|(k, v)| (k.to_string(), v)) - .collect(), - ), - )]); - - let server_destination = services.globals.server_name().as_str().to_owned(); - if let Some(destination) = x_matrix.destination.as_ref() { - if destination != &server_destination { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Invalid authorization.")); - } - } - - #[allow(clippy::or_fun_call)] - let signature_uri = CanonicalJsonValue::String( - request - .parts - .uri - .path_and_query() - .unwrap_or(&PathAndQuery::from_static("/")) - .to_string(), - ); - - let mut request_map = BTreeMap::from_iter([ - ( - "method".to_owned(), - 
CanonicalJsonValue::String(request.parts.method.to_string()), - ), - ("uri".to_owned(), signature_uri), - ("origin".to_owned(), CanonicalJsonValue::String(origin.as_str().to_owned())), - ("destination".to_owned(), CanonicalJsonValue::String(server_destination)), - ("signatures".to_owned(), CanonicalJsonValue::Object(signatures)), - ]); - - if let Some(json_body) = json_body { - request_map.insert("content".to_owned(), json_body.clone()); - }; - - let keys_result = services - .server_keys - .fetch_signing_keys_for_server(origin, vec![x_matrix.key.to_string()]) - .await; - - let keys = keys_result.map_err(|e| { - warn!("Failed to fetch signing keys: {e}"); - Error::BadRequest(ErrorKind::forbidden(), "Failed to fetch signing keys.") - })?; - - let pub_key_map = BTreeMap::from_iter([(origin.as_str().to_owned(), keys)]); + Ok(()) +} - match ruma::signatures::verify_json(&pub_key_map, &request_map) { - Ok(()) => Ok(Auth { - origin: Some(origin.clone()), - sender_user: None, - sender_device: None, - appservice_info: None, - }), - Err(e) => { - warn!("Failed to verify json request from {origin}: {e}\n{request_map:?}"); +async fn parse_x_matrix(request: &mut Request) -> Result { + let TypedHeader(Authorization(x_matrix)) = request + .parts + .extract::>>() + .await + .map_err(|e| { + let msg = match e.reason() { + TypedHeaderRejectionReason::Missing => "Missing Authorization header.", + TypedHeaderRejectionReason::Error(_) => "Invalid X-Matrix signatures.", + _ => "Unknown header-related error", + }; - if request.parts.uri.to_string().contains('@') { - warn!( - "Request uri contained '@' character. 
Make sure your reverse proxy gives Conduit the raw uri \ - (apache: use nocanon)" - ); - } + err!(Request(Forbidden(warn!("{msg}: {e}")))) + })?; - Err(Error::BadRequest( - ErrorKind::forbidden(), - "Failed to verify X-Matrix signatures.", - )) - }, - } + Ok(x_matrix) } diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index f02655e65..a9e404c52 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -85,13 +85,10 @@ pub(crate) async fn create_invite_route( .acl_check(invited_user.server_name(), &body.room_id) .await?; - ruma::signatures::hash_and_sign_event( - services.globals.server_name().as_str(), - services.globals.keypair(), - &mut signed_event, - &body.room_version, - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?; + services + .server_keys + .hash_and_sign_event(&mut signed_event, &body.room_version) + .map_err(|e| err!(Request(InvalidParam("Failed to sign event: {e}"))))?; // Generate event id let event_id = EventId::parse(format!( diff --git a/src/api/server/key.rs b/src/api/server/key.rs index 686e44242..3913ce43f 100644 --- a/src/api/server/key.rs +++ b/src/api/server/key.rs @@ -1,20 +1,16 @@ -use std::{ - collections::BTreeMap, - time::{Duration, SystemTime}, -}; +use std::{collections::BTreeMap, time::Duration}; use axum::{extract::State, response::IntoResponse, Json}; +use conduit::{utils::timepoint_from_now, Result}; use ruma::{ api::{ - federation::discovery::{get_server_keys, ServerSigningKeys, VerifyKey}, + federation::discovery::{get_server_keys, ServerSigningKeys}, OutgoingResponse, }, - serde::{Base64, Raw}, - MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, + serde::Raw, + MilliSecondsSinceUnixEpoch, }; -use crate::Result; - /// # `GET /_matrix/key/v2/server` /// /// Gets the public signing keys of this server. 
@@ -24,47 +20,33 @@ use crate::Result; // Response type for this endpoint is Json because we need to calculate a // signature for the response pub(crate) async fn get_server_keys_route(State(services): State) -> Result { - let verify_keys: BTreeMap = BTreeMap::from([( - format!("ed25519:{}", services.globals.keypair().version()) - .try_into() - .expect("found invalid server signing keys in DB"), - VerifyKey { - key: Base64::new(services.globals.keypair().public_key().to_vec()), - }, - )]); + let server_name = services.globals.server_name(); + let verify_keys = services.server_keys.verify_keys_for(server_name).await; + let server_key = ServerSigningKeys { + verify_keys, + server_name: server_name.to_owned(), + valid_until_ts: valid_until_ts(), + old_verify_keys: BTreeMap::new(), + signatures: BTreeMap::new(), + }; - let mut response = serde_json::from_slice( - get_server_keys::v2::Response { - server_key: Raw::new(&ServerSigningKeys { - server_name: services.globals.server_name().to_owned(), - verify_keys, - old_verify_keys: BTreeMap::new(), - signatures: BTreeMap::new(), - valid_until_ts: MilliSecondsSinceUnixEpoch::from_system_time( - SystemTime::now() - .checked_add(Duration::from_secs(86400 * 7)) - .expect("valid_until_ts should not get this high"), - ) - .expect("time is valid"), - }) - .expect("static conversion, no errors"), - } - .try_into_http_response::>() - .unwrap() - .body(), - ) - .unwrap(); + let response = get_server_keys::v2::Response { + server_key: Raw::new(&server_key)?, + } + .try_into_http_response::>()?; - ruma::signatures::sign_json( - services.globals.server_name().as_str(), - services.globals.keypair(), - &mut response, - ) - .unwrap(); + let mut response = serde_json::from_slice(response.body())?; + services.server_keys.sign_json(&mut response)?; Ok(Json(response)) } +fn valid_until_ts() -> MilliSecondsSinceUnixEpoch { + let dur = Duration::from_secs(86400 * 7); + let timepoint = timepoint_from_now(dur).expect("SystemTime should not 
overflow"); + MilliSecondsSinceUnixEpoch::from_system_time(timepoint).expect("UInt should not overflow") +} + /// # `GET /_matrix/key/v2/server/{keyId}` /// /// Gets the public signing keys of this server. diff --git a/src/api/server/send.rs b/src/api/server/send.rs index f6916ccfa..40f9403b2 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -21,7 +21,6 @@ use ruma::{ OwnedEventId, ServerName, }; use serde_json::value::RawValue as RawJsonValue; -use tokio::sync::RwLock; use crate::{ services::Services, @@ -109,22 +108,6 @@ async fn handle_pdus( // and hashes checks } - // We go through all the signatures we see on the PDUs and fetch the - // corresponding signing keys - let pub_key_map = RwLock::new(BTreeMap::new()); - if !parsed_pdus.is_empty() { - services - .server_keys - .fetch_required_signing_keys(parsed_pdus.iter().map(|(_event_id, event, _room_id)| event), &pub_key_map) - .await - .unwrap_or_else(|e| warn!("Could not fetch all signatures for PDUs from {origin}: {e:?}")); - - debug!( - elapsed = ?txn_start_time.elapsed(), - "Fetched signing keys" - ); - } - let mut resolved_map = BTreeMap::new(); for (event_id, value, room_id) in parsed_pdus { let pdu_start_time = Instant::now(); @@ -134,17 +117,18 @@ async fn handle_pdus( .mutex_federation .lock(&room_id) .await; + resolved_map.insert( event_id.clone(), services .rooms .event_handler - .handle_incoming_pdu(origin, &room_id, &event_id, value, true, &pub_key_map) + .handle_incoming_pdu(origin, &room_id, &event_id, value, true) .await .map(|_| ()), ); - drop(mutex_lock); + drop(mutex_lock); debug!( pdu_elapsed = ?pdu_start_time.elapsed(), txn_elapsed = ?txn_start_time.elapsed(), diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index f92576904..d888d75e8 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -1,6 +1,6 @@ #![allow(deprecated)] -use std::{borrow::Borrow, collections::BTreeMap}; +use std::borrow::Borrow; use axum::extract::State; use 
conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result}; @@ -15,7 +15,6 @@ use ruma::{ }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use service::Services; -use tokio::sync::RwLock; use crate::Ruma; @@ -43,9 +42,6 @@ async fn create_join_event( .await .map_err(|_| err!(Request(NotFound("Event state not found."))))?; - let pub_key_map = RwLock::new(BTreeMap::new()); - // let mut auth_cache = EventMap::new(); - // We do not add the event_id field to the pdu here because of signature and // hashes checks let room_version_id = services.rooms.state.get_room_version(room_id).await?; @@ -137,20 +133,12 @@ async fn create_join_event( .await .unwrap_or_default() { - ruma::signatures::hash_and_sign_event( - services.globals.server_name().as_str(), - services.globals.keypair(), - &mut value, - &room_version_id, - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?; + services + .server_keys + .hash_and_sign_event(&mut value, &room_version_id) + .map_err(|e| err!(Request(InvalidParam("Failed to sign event: {e}"))))?; } - services - .server_keys - .fetch_required_signing_keys([&value], &pub_key_map) - .await?; - let origin: OwnedServerName = serde_json::from_value( serde_json::to_value( value @@ -171,7 +159,7 @@ async fn create_join_event( let pdu_id: Vec = services .rooms .event_handler - .handle_incoming_pdu(&origin, room_id, &event_id, value.clone(), true, &pub_key_map) + .handle_incoming_pdu(&origin, room_id, &event_id, value.clone(), true) .await? 
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?; diff --git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index 81f41af07..0530f9dd5 100644 --- a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -1,7 +1,5 @@ #![allow(deprecated)] -use std::collections::BTreeMap; - use axum::extract::State; use conduit::{utils::ReadyExt, Error, Result}; use ruma::{ @@ -13,7 +11,6 @@ use ruma::{ OwnedServerName, OwnedUserId, RoomId, ServerName, }; use serde_json::value::RawValue as RawJsonValue; -use tokio::sync::RwLock; use crate::{ service::{pdu::gen_event_id_canonical_json, Services}, @@ -60,8 +57,6 @@ async fn create_leave_event( .acl_check(origin, room_id) .await?; - let pub_key_map = RwLock::new(BTreeMap::new()); - // We do not add the event_id field to the pdu here because of signature and // hashes checks let room_version_id = services.rooms.state.get_room_version(room_id).await?; @@ -154,21 +149,17 @@ async fn create_leave_event( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "origin is not a server name."))?; - services - .server_keys - .fetch_required_signing_keys([&value], &pub_key_map) - .await?; - let mutex_lock = services .rooms .event_handler .mutex_federation .lock(room_id) .await; + let pdu_id: Vec = services .rooms .event_handler - .handle_incoming_pdu(&origin, room_id, &event_id, value, true, &pub_key_map) + .handle_incoming_pdu(&origin, room_id, &event_id, value, true) .await? 
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?; diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index b5e07da23..114c6e766 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -490,30 +490,6 @@ pub struct Config { #[serde(default = "default_trusted_servers")] pub trusted_servers: Vec, - /// Option to control whether conduwuit will query your list of trusted - /// notary key servers (`trusted_servers`) for remote homeserver signing - /// keys it doesn't know *first*, or query the individual servers first - /// before falling back to the trusted key servers. - /// - /// The former/default behaviour makes federated/remote rooms joins - /// generally faster because we're querying a single (or list of) server - /// that we know works, is reasonably fast, and is reliable for just about - /// all the homeserver signing keys in the room. Querying individual - /// servers may take longer depending on the general infrastructure of - /// everyone in there, how many dead servers there are, etc. - /// - /// However, this does create an increased reliance on one single or - /// multiple large entities as `trusted_servers` should generally - /// contain long-term and large servers who know a very large number of - /// homeservers. - /// - /// If you don't know what any of this means, leave this and - /// `trusted_servers` alone to their defaults. - /// - /// Defaults to true as this is the fastest option for federation. - #[serde(default = "true_fn")] - pub query_trusted_key_servers_first: bool, - /// max log level for conduwuit. 
allows debug, info, warn, or error /// see also: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives /// **Caveat**: @@ -1518,10 +1494,6 @@ impl fmt::Display for Config { .map(|server| server.host()) .join(", "), ); - line( - "Query Trusted Key Servers First", - &self.query_trusted_key_servers_first.to_string(), - ); line("OpenID Token TTL", &self.openid_token_ttl.to_string()); line( "TURN username", diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs index 39fa43404..42250a0c6 100644 --- a/src/core/error/mod.rs +++ b/src/core/error/mod.rs @@ -85,6 +85,8 @@ pub enum Error { BadRequest(ruma::api::client::error::ErrorKind, &'static str), //TODO: remove #[error("{0}")] BadServerResponse(Cow<'static, str>), + #[error(transparent)] + CanonicalJson(#[from] ruma::CanonicalJsonError), #[error("There was a problem with the '{0}' directive in your configuration: {1}")] Config(&'static str, Cow<'static, str>), #[error("{0}")] @@ -110,6 +112,8 @@ pub enum Error { #[error(transparent)] Ruma(#[from] ruma::api::client::error::Error), #[error(transparent)] + Signatures(#[from] ruma::signatures::Error), + #[error(transparent)] StateRes(#[from] ruma::state_res::Error), #[error("uiaa")] Uiaa(ruma::api::client::uiaa::UiaaInfo), diff --git a/src/core/pdu/mod.rs b/src/core/pdu/mod.rs index 5f50fe5b1..274b96bd2 100644 --- a/src/core/pdu/mod.rs +++ b/src/core/pdu/mod.rs @@ -408,10 +408,13 @@ impl PduEvent { serde_json::from_value(json).expect("Raw::from_value always works") } - pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result { - json.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); + pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result { + json.insert("event_id".into(), CanonicalJsonValue::String(event_id.into())); - serde_json::from_value(serde_json::to_value(json).expect("valid JSON")) + let value = serde_json::to_value(json)?; + let 
pdu = serde_json::from_value(value)?; + + Ok(pdu) } } @@ -462,13 +465,15 @@ pub fn gen_event_id_canonical_json( let value: CanonicalJsonObject = serde_json::from_str(pdu.get()) .map_err(|e| err!(BadServerResponse(warn!("Error parsing incoming event: {e:?}"))))?; - let event_id = format!( - "${}", - // Anything higher than version3 behaves the same - ruma::signatures::reference_hash(&value, room_version_id).expect("ruma can calculate reference hashes") - ) - .try_into() - .expect("ruma's reference hashes are valid event ids"); + let event_id = gen_event_id(&value, room_version_id)?; Ok((event_id, value)) } + +/// Generates a correct eventId for the incoming pdu. +pub fn gen_event_id(value: &CanonicalJsonObject, room_version_id: &RoomVersionId) -> Result { + let reference_hash = ruma::signatures::reference_hash(value, room_version_id)?; + let event_id: OwnedEventId = format!("${reference_hash}").try_into()?; + + Ok(event_id) +} diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 3638cb56c..eea7597a0 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -1,16 +1,9 @@ -use std::{ - collections::BTreeMap, - sync::{Arc, RwLock}, -}; +use std::sync::{Arc, RwLock}; -use conduit::{trace, utils, utils::rand, Error, Result, Server}; -use database::{Database, Deserialized, Json, Map}; +use conduit::{trace, utils, Result, Server}; +use database::{Database, Deserialized, Map}; use futures::{pin_mut, stream::FuturesUnordered, FutureExt, StreamExt}; -use ruma::{ - api::federation::discovery::{ServerSigningKeys, VerifyKey}, - signatures::Ed25519KeyPair, - DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId, -}; +use ruma::{DeviceId, UserId}; use crate::{rooms, Dep}; @@ -25,7 +18,6 @@ pub struct Data { pduid_pdu: Arc, keychangeid_userid: Arc, roomusertype_roomuserdataid: Arc, - server_signingkeys: Arc, readreceiptid_readreceipt: Arc, userid_lastonetimekeyupdate: Arc, counter: RwLock, @@ -56,7 +48,6 @@ 
impl Data { pduid_pdu: db["pduid_pdu"].clone(), keychangeid_userid: db["keychangeid_userid"].clone(), roomusertype_roomuserdataid: db["roomusertype_roomuserdataid"].clone(), - server_signingkeys: db["server_signingkeys"].clone(), readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(), userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(), counter: RwLock::new(Self::stored_count(&db["global"]).expect("initialized global counter")), @@ -205,107 +196,6 @@ impl Data { Ok(()) } - pub fn load_keypair(&self) -> Result { - let generate = |_| { - let keypair = Ed25519KeyPair::generate().expect("Ed25519KeyPair generation always works (?)"); - - let mut value = rand::string(8).as_bytes().to_vec(); - value.push(0xFF); - value.extend_from_slice(&keypair); - - self.global.insert(b"keypair", &value); - value - }; - - let keypair_bytes: Vec = self - .global - .get_blocking(b"keypair") - .map_or_else(generate, Into::into); - - let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF); - utils::string_from_bytes( - // 1. version - parts - .next() - .expect("splitn always returns at least one element"), - ) - .map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) - .and_then(|version| { - // 2. key - parts - .next() - .ok_or_else(|| Error::bad_database("Invalid keypair format in database.")) - .map(|key| (version, key)) - }) - .and_then(|(version, key)| { - Ed25519KeyPair::from_der(key, version) - .map_err(|_| Error::bad_database("Private or public keys are invalid.")) - }) - } - - #[inline] - pub fn remove_keypair(&self) -> Result<()> { - self.global.remove(b"keypair"); - Ok(()) - } - - /// TODO: the key valid until timestamp (`valid_until_ts`) is only honored - /// in room version > 4 - /// - /// Remove the outdated keys and insert the new ones. - /// - /// This doesn't actually check that the keys provided are newer than the - /// old set. 
- pub async fn add_signing_key( - &self, origin: &ServerName, new_keys: ServerSigningKeys, - ) -> BTreeMap { - // (timo) Not atomic, but this is not critical - let mut keys: ServerSigningKeys = self - .server_signingkeys - .get(origin) - .await - .deserialized() - .unwrap_or_else(|_| { - // Just insert "now", it doesn't matter - ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) - }); - - keys.verify_keys.extend(new_keys.verify_keys); - keys.old_verify_keys.extend(new_keys.old_verify_keys); - - self.server_signingkeys.raw_put(origin, Json(&keys)); - - let mut tree = keys.verify_keys; - tree.extend( - keys.old_verify_keys - .into_iter() - .map(|old| (old.0, VerifyKey::new(old.1.key))), - ); - - tree - } - - /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found - /// for the server. - pub async fn verify_keys_for(&self, origin: &ServerName) -> Result> { - self.signing_keys_for(origin).await.map_or_else( - |_| Ok(BTreeMap::new()), - |keys: ServerSigningKeys| { - let mut tree = keys.verify_keys; - tree.extend( - keys.old_verify_keys - .into_iter() - .map(|old| (old.0, VerifyKey::new(old.1.key))), - ); - Ok(tree) - }, - ) - } - - pub async fn signing_keys_for(&self, origin: &ServerName) -> Result { - self.server_signingkeys.get(origin).await.deserialized() - } - pub async fn database_version(&self) -> u64 { self.global .get(b"version") diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index fb970f078..7680007d4 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -2,7 +2,7 @@ mod data; pub(super) mod migrations; use std::{ - collections::{BTreeMap, HashMap}, + collections::HashMap, fmt::Write, sync::{Arc, RwLock}, time::Instant, @@ -13,13 +13,8 @@ use data::Data; use ipaddress::IPAddress; use regex::RegexSet; use ruma::{ - api::{ - client::discovery::discover_support::ContactRole, - federation::discovery::{ServerSigningKeys, VerifyKey}, - }, - serde::Base64, - DeviceId, OwnedEventId, 
OwnedRoomAliasId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomAliasId, - RoomVersionId, ServerName, UserId, + api::client::discovery::discover_support::ContactRole, DeviceId, OwnedEventId, OwnedRoomAliasId, OwnedServerName, + OwnedUserId, RoomAliasId, RoomVersionId, ServerName, UserId, }; use tokio::sync::Mutex; use url::Url; @@ -31,7 +26,6 @@ pub struct Service { pub config: Config, pub cidr_range_denylist: Vec, - keypair: Arc, jwt_decoding_key: Option, pub stable_room_versions: Vec, pub unstable_room_versions: Vec, @@ -50,16 +44,6 @@ impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { let db = Data::new(&args); let config = &args.server.config; - let keypair = db.load_keypair(); - - let keypair = match keypair { - Ok(k) => k, - Err(e) => { - error!("Keypair invalid. Deleting..."); - db.remove_keypair()?; - return Err(e); - }, - }; let jwt_decoding_key = config .jwt_secret @@ -115,7 +99,6 @@ impl crate::Service for Service { db, config: config.clone(), cidr_range_denylist, - keypair: Arc::new(keypair), jwt_decoding_key, stable_room_versions, unstable_room_versions, @@ -175,9 +158,6 @@ impl crate::Service for Service { } impl Service { - /// Returns this server's keypair. - pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair } - #[inline] pub fn next_count(&self) -> Result { self.db.next_count() } @@ -224,8 +204,6 @@ impl Service { pub fn trusted_servers(&self) -> &[OwnedServerName] { &self.config.trusted_servers } - pub fn query_trusted_key_servers_first(&self) -> bool { self.config.query_trusted_key_servers_first } - pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() } pub fn turn_password(&self) -> &String { &self.config.turn_password } @@ -302,28 +280,6 @@ impl Service { } } - /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found - /// for the server. 
- pub async fn verify_keys_for(&self, origin: &ServerName) -> Result> { - let mut keys = self.db.verify_keys_for(origin).await?; - if origin == self.server_name() { - keys.insert( - format!("ed25519:{}", self.keypair().version()) - .try_into() - .expect("found invalid server signing keys in DB"), - VerifyKey { - key: Base64::new(self.keypair.public_key().to_vec()), - }, - ); - } - - Ok(keys) - } - - pub async fn signing_keys_for(&self, origin: &ServerName) -> Result { - self.db.signing_keys_for(origin).await - } - pub fn well_known_client(&self) -> &Option { &self.config.well_known.client } pub fn well_known_server(&self) -> &Option { &self.config.well_known.server } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index f8042b67b..8448404ba 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -28,12 +28,10 @@ use ruma::{ StateEventType, TimelineEventType, }, int, - serde::Base64, state_res::{self, EventTypeExt, RoomVersion, StateMap}, - uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, - ServerName, UserId, + uint, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, + RoomId, RoomVersionId, ServerName, UserId, }; -use tokio::sync::RwLock; use super::state_compressor::CompressedStateEvent; use crate::{globals, rooms, sending, server_keys, Dep}; @@ -129,11 +127,10 @@ impl Service { /// 13. Use state resolution to find new room state /// 14. 
Check if the event passes auth based on the "current state" of the /// room, if not soft fail it - #[tracing::instrument(skip(self, origin, value, is_timeline_event, pub_key_map), name = "pdu")] + #[tracing::instrument(skip(self, origin, value, is_timeline_event), name = "pdu")] pub async fn handle_incoming_pdu<'a>( &self, origin: &'a ServerName, room_id: &'a RoomId, event_id: &'a EventId, value: BTreeMap, is_timeline_event: bool, - pub_key_map: &'a RwLock>>, ) -> Result>> { // 1. Skip the PDU if we already have it as a timeline event if let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await { @@ -177,7 +174,7 @@ impl Service { let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?; let (incoming_pdu, val) = self - .handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false, pub_key_map) + .handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false) .boxed() .await?; @@ -200,7 +197,6 @@ impl Service { &create_event, room_id, &room_version_id, - pub_key_map, incoming_pdu.prev_events.clone(), ) .await?; @@ -212,7 +208,6 @@ impl Service { origin, event_id, room_id, - pub_key_map, &mut eventid_info, &create_event, &first_pdu_in_room, @@ -250,7 +245,7 @@ impl Service { .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); let r = self - .upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id, pub_key_map) + .upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id) .await; self.federation_handletime @@ -264,12 +259,11 @@ impl Service { #[allow(clippy::type_complexity)] #[allow(clippy::too_many_arguments)] #[tracing::instrument( - skip(self, origin, event_id, room_id, pub_key_map, eventid_info, create_event, first_pdu_in_room), + skip(self, origin, event_id, room_id, eventid_info, create_event, first_pdu_in_room), name = "prev" )] pub async fn handle_prev_pdu<'a>( &self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId, - 
pub_key_map: &'a RwLock>>, eventid_info: &mut HashMap, (Arc, BTreeMap)>, create_event: &Arc, first_pdu_in_room: &Arc, prev_id: &EventId, ) -> Result<()> { @@ -318,7 +312,7 @@ impl Service { .expect("locked") .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); - self.upgrade_outlier_to_timeline_pdu(pdu, json, create_event, origin, room_id, pub_key_map) + self.upgrade_outlier_to_timeline_pdu(pdu, json, create_event, origin, room_id) .await?; self.federation_handletime @@ -338,8 +332,7 @@ impl Service { #[allow(clippy::too_many_arguments)] async fn handle_outlier_pdu<'a>( &self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId, - mut value: BTreeMap, auth_events_known: bool, - pub_key_map: &'a RwLock>>, + mut value: CanonicalJsonObject, auth_events_known: bool, ) -> Result<(Arc, BTreeMap)> { // 1. Remove unsigned field value.remove("unsigned"); @@ -349,14 +342,13 @@ impl Service { // 2. Check signatures, otherwise drop // 3. check content hash, redact if doesn't match let room_version_id = Self::get_room_version_id(create_event)?; - - let guard = pub_key_map.read().await; - let mut val = match ruma::signatures::verify_event(&guard, &value, &room_version_id) { - Err(e) => { - // Drop - warn!("Dropping bad event {event_id}: {e}"); - return Err!(Request(InvalidParam("Signature verification failed"))); - }, + let mut val = match self + .services + .server_keys + .verify_event(&value, Some(&room_version_id)) + .await + { + Ok(ruma::signatures::Verified::All) => value, Ok(ruma::signatures::Verified::Signatures) => { // Redact debug_info!("Calculated hash does not match (redaction): {event_id}"); @@ -371,11 +363,13 @@ impl Service { obj }, - Ok(ruma::signatures::Verified::All) => value, + Err(e) => { + return Err!(Request(InvalidParam(debug_error!( + "Signature verification failed for {event_id}: {e}" + )))) + }, }; - drop(guard); - // Now that we have checked the signature and hashes we can add the eventID and // 
convert to our PduEvent type val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); @@ -404,7 +398,6 @@ impl Service { create_event, room_id, &room_version_id, - pub_key_map, ), ) .await; @@ -487,7 +480,7 @@ impl Service { pub async fn upgrade_outlier_to_timeline_pdu( &self, incoming_pdu: Arc, val: BTreeMap, create_event: &PduEvent, - origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock>>, + origin: &ServerName, room_id: &RoomId, ) -> Result>> { // Skip the PDU if we already have it as a timeline event if let Ok(pduid) = self @@ -526,14 +519,7 @@ impl Service { if state_at_incoming_event.is_none() { state_at_incoming_event = self - .fetch_state( - origin, - create_event, - room_id, - &room_version_id, - pub_key_map, - &incoming_pdu.event_id, - ) + .fetch_state(origin, create_event, room_id, &room_version_id, &incoming_pdu.event_id) .await?; } @@ -1021,10 +1007,10 @@ impl Service { /// Call /state_ids to find out what the state at this pdu is. We trust the /// server's response to some extend (sic), but we still do a lot of checks /// on the events - #[tracing::instrument(skip(self, pub_key_map, create_event, room_version_id))] + #[tracing::instrument(skip(self, create_event, room_version_id))] async fn fetch_state( &self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId, - pub_key_map: &RwLock>>, event_id: &EventId, + event_id: &EventId, ) -> Result>>> { debug!("Fetching state ids"); let res = self @@ -1048,7 +1034,7 @@ impl Service { .collect::>(); let state_vec = self - .fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id, pub_key_map) + .fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id) .boxed() .await; @@ -1102,7 +1088,7 @@ impl Service { /// d. TODO: Ask other servers over federation? 
pub async fn fetch_and_handle_outliers<'a>( &self, origin: &'a ServerName, events: &'a [Arc], create_event: &'a PduEvent, room_id: &'a RoomId, - room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock>>, + room_version_id: &'a RoomVersionId, ) -> Vec<(Arc, Option>)> { let back_off = |id| match self .services @@ -1222,22 +1208,6 @@ impl Service { events_with_auth_events.push((id, None, events_in_reverse_order)); } - // We go through all the signatures we see on the PDUs and their unresolved - // dependencies and fetch the corresponding signing keys - self.services - .server_keys - .fetch_required_signing_keys( - events_with_auth_events - .iter() - .flat_map(|(_id, _local_pdu, events)| events) - .map(|(_event_id, event)| event), - pub_key_map, - ) - .await - .unwrap_or_else(|e| { - warn!("Could not fetch all signatures for PDUs from {origin}: {e:?}"); - }); - let mut pdus = Vec::with_capacity(events_with_auth_events.len()); for (id, local_pdu, events_in_reverse_order) in events_with_auth_events { // a. 
Look in the main timeline (pduid_pdu tree) @@ -1266,16 +1236,8 @@ impl Service { } } - match Box::pin(self.handle_outlier_pdu( - origin, - create_event, - &next_id, - room_id, - value.clone(), - true, - pub_key_map, - )) - .await + match Box::pin(self.handle_outlier_pdu(origin, create_event, &next_id, room_id, value.clone(), true)) + .await { Ok((pdu, json)) => { if next_id == *id { @@ -1296,7 +1258,7 @@ impl Service { #[tracing::instrument(skip_all)] async fn fetch_prev( &self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId, - pub_key_map: &RwLock>>, initial_set: Vec>, + initial_set: Vec>, ) -> Result<( Vec>, HashMap, (Arc, BTreeMap)>, @@ -1311,14 +1273,7 @@ impl Service { while let Some(prev_event_id) = todo_outlier_stack.pop() { if let Some((pdu, mut json_opt)) = self - .fetch_and_handle_outliers( - origin, - &[prev_event_id.clone()], - create_event, - room_id, - room_version_id, - pub_key_map, - ) + .fetch_and_handle_outliers(origin, &[prev_event_id.clone()], create_event, room_id, room_version_id) .boxed() .await .pop() diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 21e5395da..902e50fff 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -16,7 +16,7 @@ use conduit::{ }; use futures::{future, future::ready, Future, FutureExt, Stream, StreamExt, TryStreamExt}; use ruma::{ - api::{client::error::ErrorKind, federation}, + api::federation, canonical_json::to_canonical_value, events::{ push_rules::PushRulesEvent, @@ -30,14 +30,12 @@ use ruma::{ GlobalAccountDataEventType, StateEventType, TimelineEventType, }, push::{Action, Ruleset, Tweak}, - serde::Base64, state_res::{self, Event, RoomVersion}, uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedServerName, RoomId, RoomVersionId, ServerName, UserId, }; use serde::Deserialize; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use 
tokio::sync::RwLock; use self::data::Data; pub use self::data::PdusIterItem; @@ -784,21 +782,15 @@ impl Service { to_canonical_value(self.services.globals.server_name()).expect("server name is a valid CanonicalJsonValue"), ); - match ruma::signatures::hash_and_sign_event( - self.services.globals.server_name().as_str(), - self.services.globals.keypair(), - &mut pdu_json, - &room_version_id, - ) { - Ok(()) => {}, - Err(e) => { - return match e { - ruma::signatures::Error::PduSize => { - Err(Error::BadRequest(ErrorKind::TooLarge, "Message is too long")) - }, - _ => Err(Error::BadRequest(ErrorKind::Unknown, "Signing event failed")), - } - }, + if let Err(e) = self + .services + .server_keys + .hash_and_sign_event(&mut pdu_json, &room_version_id) + { + return match e { + Error::Signatures(ruma::signatures::Error::PduSize) => Err!(Request(TooLarge("Message is too long"))), + _ => Err!(Request(Unknown("Signing event failed"))), + }; } // Generate event id @@ -1106,9 +1098,8 @@ impl Service { .await; match response { Ok(response) => { - let pub_key_map = RwLock::new(BTreeMap::new()); for pdu in response.pdus { - if let Err(e) = self.backfill_pdu(backfill_server, pdu, &pub_key_map).await { + if let Err(e) = self.backfill_pdu(backfill_server, pdu).await { warn!("Failed to add backfilled pdu in room {room_id}: {e}"); } } @@ -1124,11 +1115,8 @@ impl Service { Ok(()) } - #[tracing::instrument(skip(self, pdu, pub_key_map))] - pub async fn backfill_pdu( - &self, origin: &ServerName, pdu: Box, - pub_key_map: &RwLock>>, - ) -> Result<()> { + #[tracing::instrument(skip(self, pdu))] + pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box) -> Result<()> { let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu).await?; // Lock so we cannot backfill the same pdu twice at the same time @@ -1146,14 +1134,9 @@ impl Service { return Ok(()); } - self.services - .server_keys - .fetch_required_signing_keys([&value], pub_key_map) - .await?; - self.services 
.event_handler - .handle_incoming_pdu(origin, &room_id, &event_id, value, false, pub_key_map) + .handle_incoming_pdu(origin, &room_id, &event_id, value, false) .await?; let value = self diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index e3582f2ea..5970c3836 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -17,7 +17,7 @@ use tokio::sync::Mutex; use self::data::Data; pub use self::dest::Destination; -use crate::{account_data, client, globals, presence, pusher, resolver, rooms, users, Dep}; +use crate::{account_data, client, globals, presence, pusher, resolver, rooms, server_keys, users, Dep}; pub struct Service { server: Arc, @@ -41,6 +41,7 @@ struct Services { account_data: Dep, appservice: Dep, pusher: Dep, + server_keys: Dep, } #[derive(Clone, Debug, PartialEq, Eq)] @@ -78,6 +79,7 @@ impl crate::Service for Service { account_data: args.depend::("account_data"), appservice: args.depend::("appservice"), pusher: args.depend::("pusher"), + server_keys: args.depend::("server_keys"), }, db: Data::new(&args), sender, diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 9a8f408b5..73b6a468f 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -1,8 +1,8 @@ use std::{fmt::Debug, mem}; use conduit::{ - debug, debug_error, debug_info, debug_warn, err, error::inspect_debug_log, trace, utils::string::EMPTY, Err, Error, - Result, + debug, debug_error, debug_info, debug_warn, err, error::inspect_debug_log, implement, trace, utils::string::EMPTY, + Err, Error, Result, }; use http::{header::AUTHORIZATION, HeaderValue}; use ipaddress::IPAddress; @@ -18,7 +18,7 @@ use ruma::{ }; use crate::{ - globals, resolver, + resolver, resolver::{actual::ActualDest, cache::CachedDest}, }; @@ -75,7 +75,7 @@ impl super::Service { .try_into_http_request::>(&actual.string, SATIR, &VERSIONS) .map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))?; - sign_request::(&self.services.globals, 
dest, &mut http_request); + self.sign_request::(dest, &mut http_request); let request = Request::try_from(http_request)?; self.validate_url(request.url())?; @@ -178,7 +178,8 @@ where Err(e.into()) } -fn sign_request(globals: &globals::Service, dest: &ServerName, http_request: &mut http::Request>) +#[implement(super::Service)] +fn sign_request(&self, dest: &ServerName, http_request: &mut http::Request>) where T: OutgoingRequest + Debug + Send, { @@ -200,11 +201,13 @@ where .to_string() .into(), ); - req_map.insert("origin".to_owned(), globals.server_name().as_str().into()); + req_map.insert("origin".to_owned(), self.services.globals.server_name().to_string().into()); req_map.insert("destination".to_owned(), dest.as_str().into()); let mut req_json = serde_json::from_value(req_map.into()).expect("valid JSON is valid BTreeMap"); - ruma::signatures::sign_json(globals.server_name().as_str(), globals.keypair(), &mut req_json) + self.services + .server_keys + .sign_json(&mut req_json) .expect("our request json is what ruma expects"); let req_json: serde_json::Map = @@ -231,7 +234,12 @@ where http_request.headers_mut().insert( AUTHORIZATION, - HeaderValue::from(&XMatrix::new(globals.config.server_name.clone(), dest.to_owned(), key, sig)), + HeaderValue::from(&XMatrix::new( + self.services.globals.server_name().to_owned(), + dest.to_owned(), + key, + sig, + )), ); } } diff --git a/src/service/server_keys/acquire.rs b/src/service/server_keys/acquire.rs new file mode 100644 index 000000000..2b1700400 --- /dev/null +++ b/src/service/server_keys/acquire.rs @@ -0,0 +1,175 @@ +use std::{ + borrow::Borrow, + collections::{BTreeMap, BTreeSet}, +}; + +use conduit::{debug, debug_warn, error, implement, result::FlatOk, warn}; +use futures::{stream::FuturesUnordered, StreamExt}; +use ruma::{ + api::federation::discovery::ServerSigningKeys, serde::Raw, CanonicalJsonObject, OwnedServerName, + OwnedServerSigningKeyId, ServerName, ServerSigningKeyId, +}; +use serde_json::value::RawValue as 
RawJsonValue; + +use super::key_exists; + +type Batch = BTreeMap>; + +#[implement(super::Service)] +pub async fn acquire_events_pubkeys<'a, I>(&self, events: I) +where + I: Iterator> + Send, +{ + type Batch = BTreeMap>; + type Signatures = BTreeMap>; + + let mut batch = Batch::new(); + events + .cloned() + .map(Raw::::from_json) + .map(|event| event.get_field::("signatures")) + .filter_map(FlatOk::flat_ok) + .flat_map(IntoIterator::into_iter) + .for_each(|(server, sigs)| { + batch.entry(server).or_default().extend(sigs.into_keys()); + }); + + let batch = batch + .iter() + .map(|(server, keys)| (server.borrow(), keys.iter().map(Borrow::borrow))); + + self.acquire_pubkeys(batch).await; +} + +#[implement(super::Service)] +pub async fn acquire_pubkeys<'a, S, K>(&self, batch: S) +where + S: Iterator + Send + Clone, + K: Iterator + Send + Clone, +{ + let requested_servers = batch.clone().count(); + let requested_keys = batch.clone().flat_map(|(_, key_ids)| key_ids).count(); + + debug!("acquire {requested_keys} keys from {requested_servers}"); + + let missing = self.acquire_locals(batch).await; + let missing_keys = keys_count(&missing); + let missing_servers = missing.len(); + if missing_servers == 0 { + return; + } + + debug!("missing {missing_keys} keys for {missing_servers} servers locally"); + + let missing = self.acquire_origins(missing.into_iter()).await; + let missing_keys = keys_count(&missing); + let missing_servers = missing.len(); + if missing_servers == 0 { + return; + } + + debug_warn!("missing {missing_keys} keys for {missing_servers} servers unreachable"); + + let missing = self.acquire_notary(missing.into_iter()).await; + let missing_keys = keys_count(&missing); + let missing_servers = missing.len(); + if missing_keys > 0 { + debug_warn!("still missing {missing_keys} keys for {missing_servers} servers from all notaries"); + warn!("did not obtain {missing_keys} of {requested_keys} keys; some events may not be accepted"); + } +} + 
+#[implement(super::Service)] +async fn acquire_locals<'a, S, K>(&self, batch: S) -> Batch +where + S: Iterator + Send, + K: Iterator + Send, +{ + let mut missing = Batch::new(); + for (server, key_ids) in batch { + for key_id in key_ids { + if !self.verify_key_exists(server, key_id).await { + missing + .entry(server.into()) + .or_default() + .push(key_id.into()); + } + } + } + + missing +} + +#[implement(super::Service)] +async fn acquire_origins(&self, batch: I) -> Batch +where + I: Iterator)> + Send, +{ + let mut requests: FuturesUnordered<_> = batch + .map(|(origin, key_ids)| self.acquire_origin(origin, key_ids)) + .collect(); + + let mut missing = Batch::new(); + while let Some((origin, key_ids)) = requests.next().await { + if !key_ids.is_empty() { + missing.insert(origin, key_ids); + } + } + + missing +} + +#[implement(super::Service)] +async fn acquire_origin( + &self, origin: OwnedServerName, mut key_ids: Vec, +) -> (OwnedServerName, Vec) { + if let Ok(server_keys) = self.server_request(&origin).await { + self.add_signing_keys(server_keys.clone()).await; + key_ids.retain(|key_id| !key_exists(&server_keys, key_id)); + } + + (origin, key_ids) +} + +#[implement(super::Service)] +async fn acquire_notary(&self, batch: I) -> Batch +where + I: Iterator)> + Send, +{ + let mut missing: Batch = batch.collect(); + for notary in self.services.globals.trusted_servers() { + let missing_keys = keys_count(&missing); + let missing_servers = missing.len(); + debug!("Asking notary {notary} for {missing_keys} missing keys from {missing_servers} servers"); + + let batch = missing + .iter() + .map(|(server, keys)| (server.borrow(), keys.iter().map(Borrow::borrow))); + + match self.batch_notary_request(notary, batch).await { + Err(e) => error!("Failed to contact notary {notary:?}: {e}"), + Ok(results) => { + for server_keys in results { + self.acquire_notary_result(&mut missing, server_keys).await; + } + }, + } + } + + missing +} + +#[implement(super::Service)] +async fn 
acquire_notary_result(&self, missing: &mut Batch, server_keys: ServerSigningKeys) { + let server = &server_keys.server_name; + self.add_signing_keys(server_keys.clone()).await; + + if let Some(key_ids) = missing.get_mut(server) { + key_ids.retain(|key_id| key_exists(&server_keys, key_id)); + if key_ids.is_empty() { + missing.remove(server); + } + } +} + +fn keys_count(batch: &Batch) -> usize { batch.iter().flat_map(|(_, key_ids)| key_ids.iter()).count() } diff --git a/src/service/server_keys/get.rs b/src/service/server_keys/get.rs new file mode 100644 index 000000000..0f449b46b --- /dev/null +++ b/src/service/server_keys/get.rs @@ -0,0 +1,86 @@ +use std::borrow::Borrow; + +use conduit::{implement, Err, Result}; +use ruma::{api::federation::discovery::VerifyKey, CanonicalJsonObject, RoomVersionId, ServerName, ServerSigningKeyId}; + +use super::{extract_key, PubKeyMap, PubKeys}; + +#[implement(super::Service)] +pub async fn get_event_keys(&self, object: &CanonicalJsonObject, version: &RoomVersionId) -> Result { + use ruma::signatures::required_keys; + + let required = match required_keys(object, version) { + Ok(required) => required, + Err(e) => return Err!(BadServerResponse("Failed to determine keys required to verify: {e}")), + }; + + let batch = required + .iter() + .map(|(s, ids)| (s.borrow(), ids.iter().map(Borrow::borrow))); + + Ok(self.get_pubkeys(batch).await) +} + +#[implement(super::Service)] +pub async fn get_pubkeys<'a, S, K>(&self, batch: S) -> PubKeyMap +where + S: Iterator + Send, + K: Iterator + Send, +{ + let mut keys = PubKeyMap::new(); + for (server, key_ids) in batch { + let pubkeys = self.get_pubkeys_for(server, key_ids).await; + keys.insert(server.into(), pubkeys); + } + + keys +} + +#[implement(super::Service)] +pub async fn get_pubkeys_for<'a, I>(&self, origin: &ServerName, key_ids: I) -> PubKeys +where + I: Iterator + Send, +{ + let mut keys = PubKeys::new(); + for key_id in key_ids { + if let Ok(verify_key) = self.get_verify_key(origin, 
key_id).await { + keys.insert(key_id.into(), verify_key.key); + } + } + + keys +} + +#[implement(super::Service)] +pub async fn get_verify_key(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result { + if let Some(result) = self.verify_keys_for(origin).await.remove(key_id) { + return Ok(result); + } + + if let Ok(server_key) = self.server_request(origin).await { + self.add_signing_keys(server_key.clone()).await; + if let Some(result) = extract_key(server_key, key_id) { + return Ok(result); + } + } + + for notary in self.services.globals.trusted_servers() { + if let Ok(server_keys) = self.notary_request(notary, origin).await { + for server_key in &server_keys { + self.add_signing_keys(server_key.clone()).await; + } + + for server_key in server_keys { + if let Some(result) = extract_key(server_key, key_id) { + return Ok(result); + } + } + } + } + + Err!(BadServerResponse(debug_error!( + ?key_id, + ?origin, + "Failed to fetch federation signing-key" + ))) +} diff --git a/src/service/server_keys/keypair.rs b/src/service/server_keys/keypair.rs new file mode 100644 index 000000000..31a24cdf3 --- /dev/null +++ b/src/service/server_keys/keypair.rs @@ -0,0 +1,64 @@ +use std::sync::Arc; + +use conduit::{debug, debug_info, err, error, utils, utils::string_from_bytes, Result}; +use database::Database; +use ruma::{api::federation::discovery::VerifyKey, serde::Base64, signatures::Ed25519KeyPair}; + +use super::VerifyKeys; + +pub(super) fn init(db: &Arc) -> Result<(Box, VerifyKeys)> { + let keypair = load(db).inspect_err(|_e| { + error!("Keypair invalid. 
Deleting..."); + remove(db); + })?; + + let verify_key = VerifyKey { + key: Base64::new(keypair.public_key().to_vec()), + }; + + let id = format!("ed25519:{}", keypair.version()); + let verify_keys: VerifyKeys = [(id.try_into()?, verify_key)].into(); + + Ok((keypair, verify_keys)) +} + +fn load(db: &Arc) -> Result> { + let (version, key) = db["global"] + .get_blocking(b"keypair") + .map(|ref val| { + // database deserializer is having trouble with this so it's manual for now + let mut elems = val.split(|&b| b == b'\xFF'); + let vlen = elems.next().expect("invalid keypair entry").len(); + let ver = string_from_bytes(&val[..vlen]).expect("invalid keypair version"); + let der = val[vlen.saturating_add(1)..].to_vec(); + debug!("Found existing Ed25519 keypair: {ver:?}"); + (ver, der) + }) + .or_else(|e| { + assert!(e.is_not_found(), "unexpected error fetching keypair"); + create(db) + })?; + + let key = + Ed25519KeyPair::from_der(&key, version).map_err(|e| err!("Failed to load ed25519 keypair from der: {e:?}"))?; + + Ok(Box::new(key)) +} + +fn create(db: &Arc) -> Result<(String, Vec)> { + let keypair = Ed25519KeyPair::generate().map_err(|e| err!("Failed to generate new ed25519 keypair: {e:?}"))?; + + let id = utils::rand::string(8); + debug_info!("Generated new Ed25519 keypair: {id:?}"); + + let value: (String, Vec) = (id, keypair.to_vec()); + db["global"].raw_put(b"keypair", &value); + + Ok(value) +} + +#[inline] +fn remove(db: &Arc) { + let global = &db["global"]; + global.remove(b"keypair"); +} diff --git a/src/service/server_keys/mod.rs b/src/service/server_keys/mod.rs index ae2b8c3cb..c3b84cb33 100644 --- a/src/service/server_keys/mod.rs +++ b/src/service/server_keys/mod.rs @@ -1,31 +1,30 @@ -use std::{ - collections::{BTreeMap, HashMap, HashSet}, - sync::Arc, - time::{Duration, SystemTime}, -}; +mod acquire; +mod get; +mod keypair; +mod request; +mod sign; +mod verify; + +use std::{collections::BTreeMap, sync::Arc, time::Duration}; -use conduit::{debug, 
debug_error, debug_warn, err, error, info, trace, warn, Err, Result}; -use futures::{stream::FuturesUnordered, StreamExt}; +use conduit::{implement, utils::time::timepoint_from_now, Result}; +use database::{Deserialized, Json, Map}; use ruma::{ - api::federation::{ - discovery::{ - get_remote_server_keys, - get_remote_server_keys_batch::{self, v2::QueryCriteria}, - get_server_keys, - }, - membership::create_join_event, - }, - serde::Base64, - CanonicalJsonObject, CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedServerSigningKeyId, - RoomVersionId, ServerName, + api::federation::discovery::{ServerSigningKeys, VerifyKey}, + serde::Raw, + signatures::{Ed25519KeyPair, PublicKeyMap, PublicKeySet}, + MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, ServerSigningKeyId, }; use serde_json::value::RawValue as RawJsonValue; -use tokio::sync::{RwLock, RwLockWriteGuard}; use crate::{globals, sending, Dep}; pub struct Service { + keypair: Box, + verify_keys: VerifyKeys, + minimum_valid: Duration, services: Services, + db: Data, } struct Services { @@ -33,546 +32,135 @@ struct Services { sending: Dep, } +struct Data { + server_signingkeys: Arc, +} + +pub type VerifyKeys = BTreeMap; +pub type PubKeyMap = PublicKeyMap; +pub type PubKeys = PublicKeySet; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { + let minimum_valid = Duration::from_secs(3600); + let (keypair, verify_keys) = keypair::init(args.db)?; + Ok(Arc::new(Self { + keypair, + verify_keys, + minimum_valid, services: Services { globals: args.depend::("globals"), sending: args.depend::("sending"), }, + db: Data { + server_signingkeys: args.db["server_signingkeys"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - pub async fn fetch_required_signing_keys<'a, E>( - &'a self, events: E, pub_key_map: &RwLock>>, - ) -> Result<()> - where - E: IntoIterator> + Send, - { - let mut server_key_ids = 
HashMap::new(); - for event in events { - for (signature_server, signature) in event - .get("signatures") - .ok_or(err!(BadServerResponse("No signatures in server response pdu.")))? - .as_object() - .ok_or(err!(BadServerResponse("Invalid signatures object in server response pdu.")))? - { - let signature_object = signature.as_object().ok_or(err!(BadServerResponse( - "Invalid signatures content object in server response pdu.", - )))?; - - for signature_id in signature_object.keys() { - server_key_ids - .entry(signature_server.clone()) - .or_insert_with(HashSet::new) - .insert(signature_id.clone()); - } - } - } - - if server_key_ids.is_empty() { - // Nothing to do, can exit early - trace!("server_key_ids is empty, not fetching any keys"); - return Ok(()); - } - - trace!( - "Fetch keys for {}", - server_key_ids - .keys() - .cloned() - .collect::>() - .join(", ") - ); - - let mut server_keys: FuturesUnordered<_> = server_key_ids - .into_iter() - .map(|(signature_server, signature_ids)| async { - let fetch_res = self - .fetch_signing_keys_for_server( - signature_server.as_str().try_into().map_err(|e| { - ( - signature_server.clone(), - err!(BadServerResponse( - "Invalid servername in signatures of server response pdu: {e:?}" - )), - ) - })?, - signature_ids.into_iter().collect(), // HashSet to Vec - ) - .await; - - match fetch_res { - Ok(keys) => Ok((signature_server, keys)), - Err(e) => { - debug_error!( - "Signature verification failed: Could not fetch signing key for {signature_server}: {e}", - ); - Err((signature_server, e)) - }, - } - }) - .collect(); - - while let Some(fetch_res) = server_keys.next().await { - match fetch_res { - Ok((signature_server, keys)) => { - pub_key_map - .write() - .await - .insert(signature_server.clone(), keys); - }, - Err((signature_server, e)) => { - debug_warn!("Failed to fetch keys for {signature_server}: {e:?}"); - }, - } - } - - Ok(()) - } - - // Gets a list of servers for which we don't have the signing key yet. 
We go - // over the PDUs and either cache the key or add it to the list that needs to be - // retrieved. - async fn get_server_keys_from_cache( - &self, pdu: &RawJsonValue, - servers: &mut BTreeMap>, - _room_version: &RoomVersionId, - pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap>>, - ) -> Result<()> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - debug_error!("Invalid PDU in server response: {pdu:#?}"); - err!(BadServerResponse(error!("Invalid PDU in server response: {e:?}"))) - })?; - - let signatures = value - .get("signatures") - .ok_or(err!(BadServerResponse("No signatures in server response pdu.")))? - .as_object() - .ok_or(err!(BadServerResponse("Invalid signatures object in server response pdu.")))?; - - for (signature_server, signature) in signatures { - let signature_object = signature.as_object().ok_or(err!(BadServerResponse( - "Invalid signatures content object in server response pdu.", - )))?; - - let signature_ids = signature_object.keys().cloned().collect::>(); - - let contains_all_ids = - |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); - - let origin = <&ServerName>::try_from(signature_server.as_str()).map_err(|e| { - err!(BadServerResponse( - "Invalid servername in signatures of server response pdu: {e:?}" - )) - })?; - - if servers.contains_key(origin) || pub_key_map.contains_key(origin.as_str()) { - continue; - } - - debug!("Loading signing keys for {origin}"); - let result: BTreeMap<_, _> = self - .services - .globals - .verify_keys_for(origin) - .await? 
- .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - - if !contains_all_ids(&result) { - debug_warn!("Signing key not loaded for {origin}"); - servers.insert(origin.to_owned(), BTreeMap::new()); - } +#[implement(Service)] +#[inline] +pub fn keypair(&self) -> &Ed25519KeyPair { &self.keypair } + +#[implement(Service)] +async fn add_signing_keys(&self, new_keys: ServerSigningKeys) { + let origin = &new_keys.server_name; + + // (timo) Not atomic, but this is not critical + let mut keys: ServerSigningKeys = self + .db + .server_signingkeys + .get(origin) + .await + .deserialized() + .unwrap_or_else(|_| { + // Just insert "now", it doesn't matter + ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) + }); + + keys.verify_keys.extend(new_keys.verify_keys); + keys.old_verify_keys.extend(new_keys.old_verify_keys); + self.db.server_signingkeys.raw_put(origin, Json(&keys)); +} - pub_key_map.insert(origin.to_string(), result); +#[implement(Service)] +async fn verify_key_exists(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> bool { + type KeysMap<'a> = BTreeMap<&'a ServerSigningKeyId, &'a RawJsonValue>; + + let Ok(keys) = self + .db + .server_signingkeys + .get(origin) + .await + .deserialized::>() + else { + return false; + }; + + if let Ok(Some(verify_keys)) = keys.get_field::>("verify_keys") { + if verify_keys.contains_key(key_id) { + return true; } - - Ok(()) } - /// Batch requests homeserver signing keys from trusted notary key servers - /// (`trusted_servers` config option) - async fn batch_request_signing_keys( - &self, mut servers: BTreeMap>, - pub_key_map: &RwLock>>, - ) -> Result<()> { - for server in self.services.globals.trusted_servers() { - debug!("Asking batch signing keys from trusted server {server}"); - match self - .services - .sending - .send_federation_request( - server, - get_remote_server_keys_batch::v2::Request { - server_keys: servers.clone(), - }, - ) - .await - { - Ok(keys) => { - debug!("Got 
signing keys: {keys:?}"); - let mut pkm = pub_key_map.write().await; - for k in keys.server_keys { - let k = match k.deserialize() { - Ok(key) => key, - Err(e) => { - warn!( - "Received error {e} while fetching keys from trusted server {server}: {:#?}", - k.into_json() - ); - continue; - }, - }; - - // TODO: Check signature from trusted server? - servers.remove(&k.server_name); - - let result = self - .services - .globals - .db - .add_signing_key(&k.server_name, k.clone()) - .await - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect::>(); - - pkm.insert(k.server_name.to_string(), result); - } - }, - Err(e) => error!( - "Failed sending batched key request to trusted key server {server} for the remote servers \ - {servers:?}: {e}" - ), - } + if let Ok(Some(old_verify_keys)) = keys.get_field::>("old_verify_keys") { + if old_verify_keys.contains_key(key_id) { + return true; } - - Ok(()) } - /// Requests multiple homeserver signing keys from individual servers (not - /// trused notary servers) - async fn request_signing_keys( - &self, servers: BTreeMap>, - pub_key_map: &RwLock>>, - ) -> Result<()> { - debug!("Asking individual servers for signing keys: {servers:?}"); - let mut futures: FuturesUnordered<_> = servers - .into_keys() - .map(|server| async move { - ( - self.services - .sending - .send_federation_request(&server, get_server_keys::v2::Request::new()) - .await, - server, - ) - }) - .collect(); + false +} - while let Some(result) = futures.next().await { - debug!("Received new Future result"); - if let (Ok(get_keys_response), origin) = result { - debug!("Result is from {origin}"); - if let Ok(key) = get_keys_response.server_key.deserialize() { - let result: BTreeMap<_, _> = self - .services - .globals - .db - .add_signing_key(&origin, key) - .await - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - pub_key_map.write().await.insert(origin.to_string(), result); - } - } - debug!("Done handling Future result"); - } 
+#[implement(Service)] +pub async fn verify_keys_for(&self, origin: &ServerName) -> VerifyKeys { + let mut keys = self + .signing_keys_for(origin) + .await + .map(|keys| merge_old_keys(keys).verify_keys) + .unwrap_or(BTreeMap::new()); - Ok(()) + if self.services.globals.server_is_ours(origin) { + keys.extend(self.verify_keys.clone().into_iter()); } - pub async fn fetch_join_signing_keys( - &self, event: &create_join_event::v2::Response, room_version: &RoomVersionId, - pub_key_map: &RwLock>>, - ) -> Result<()> { - let mut servers: BTreeMap> = BTreeMap::new(); - - { - let mut pkm = pub_key_map.write().await; - - // Try to fetch keys, failure is okay. Servers we couldn't find in the cache - // will be added to `servers` - for pdu in event - .room_state - .state - .iter() - .chain(&event.room_state.auth_chain) - { - if let Err(error) = self - .get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm) - .await - { - debug!(%error, "failed to get server keys from cache"); - }; - } - - drop(pkm); - }; - - if servers.is_empty() { - trace!("We had all keys cached locally, not fetching any keys from remote servers"); - return Ok(()); - } - - if self.services.globals.query_trusted_key_servers_first() { - info!( - "query_trusted_key_servers_first is set to true, querying notary trusted key servers first for \ - homeserver signing keys." 
- ); - - self.batch_request_signing_keys(servers.clone(), pub_key_map) - .await?; - - if servers.is_empty() { - debug!("Trusted server supplied all signing keys, no more keys to fetch"); - return Ok(()); - } - - debug!("Remaining servers left that the notary/trusted servers did not provide: {servers:?}"); - - self.request_signing_keys(servers.clone(), pub_key_map) - .await?; - } else { - debug!("query_trusted_key_servers_first is set to false, querying individual homeservers first"); - - self.request_signing_keys(servers.clone(), pub_key_map) - .await?; - - if servers.is_empty() { - debug!("Individual homeservers supplied all signing keys, no more keys to fetch"); - return Ok(()); - } - - debug!("Remaining servers left the individual homeservers did not provide: {servers:?}"); - - self.batch_request_signing_keys(servers.clone(), pub_key_map) - .await?; - } - - debug!("Search for signing keys done"); - - /*if servers.is_empty() { - warn!("Failed to find homeserver signing keys for the remaining servers: {servers:?}"); - }*/ + keys +} - Ok(()) - } +#[implement(Service)] +pub async fn signing_keys_for(&self, origin: &ServerName) -> Result { + self.db.server_signingkeys.get(origin).await.deserialized() +} - /// Search the DB for the signing keys of the given server, if we don't have - /// them fetch them from the server and save to our DB. - #[tracing::instrument(skip_all)] - pub async fn fetch_signing_keys_for_server( - &self, origin: &ServerName, signature_ids: Vec, - ) -> Result> { - let contains_all_ids = |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); +#[implement(Service)] +fn minimum_valid_ts(&self) -> MilliSecondsSinceUnixEpoch { + let timepoint = timepoint_from_now(self.minimum_valid).expect("SystemTime should not overflow"); + MilliSecondsSinceUnixEpoch::from_system_time(timepoint).expect("UInt should not overflow") +} - let mut result: BTreeMap<_, _> = self - .services - .globals - .verify_keys_for(origin) - .await? 
+fn merge_old_keys(mut keys: ServerSigningKeys) -> ServerSigningKeys { + keys.verify_keys.extend( + keys.old_verify_keys + .clone() .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - - if contains_all_ids(&result) { - trace!("We have all homeserver signing keys locally for {origin}, not fetching any remotely"); - return Ok(result); - } - - // i didnt split this out into their own functions because it's relatively small - if self.services.globals.query_trusted_key_servers_first() { - info!( - "query_trusted_key_servers_first is set to true, querying notary trusted servers first for {origin} \ - keys" - ); - - for server in self.services.globals.trusted_servers() { - debug!("Asking notary server {server} for {origin}'s signing key"); - if let Some(server_keys) = self - .services - .sending - .send_federation_request( - server, - get_remote_server_keys::v2::Request::new( - origin.to_owned(), - MilliSecondsSinceUnixEpoch::from_system_time( - SystemTime::now() - .checked_add(Duration::from_secs(3600)) - .expect("SystemTime too large"), - ) - .expect("time is valid"), - ), - ) - .await - .ok() - .map(|resp| { - resp.server_keys - .into_iter() - .filter_map(|e| e.deserialize().ok()) - .collect::>() - }) { - debug!("Got signing keys: {:?}", server_keys); - for k in server_keys { - self.services - .globals - .db - .add_signing_key(origin, k.clone()) - .await; - result.extend( - k.verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - k.old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - } - - if contains_all_ids(&result) { - return Ok(result); - } - } - } - - debug!("Asking {origin} for their signing keys over federation"); - if let Some(server_key) = self - .services - .sending - .send_federation_request(origin, get_server_keys::v2::Request::new()) - .await - .ok() - .and_then(|resp| resp.server_key.deserialize().ok()) - { - self.services - .globals - .db - .add_signing_key(origin, 
server_key.clone()) - .await; + .map(|(key_id, old)| (key_id, VerifyKey::new(old.key))), + ); - result.extend( - server_key - .verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - server_key - .old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - - if contains_all_ids(&result) { - return Ok(result); - } - } - } else { - info!("query_trusted_key_servers_first is set to false, querying {origin} first"); - debug!("Asking {origin} for their signing keys over federation"); - if let Some(server_key) = self - .services - .sending - .send_federation_request(origin, get_server_keys::v2::Request::new()) - .await - .ok() - .and_then(|resp| resp.server_key.deserialize().ok()) - { - self.services - .globals - .db - .add_signing_key(origin, server_key.clone()) - .await; - - result.extend( - server_key - .verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - server_key - .old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - - if contains_all_ids(&result) { - return Ok(result); - } - } - - for server in self.services.globals.trusted_servers() { - debug!("Asking notary server {server} for {origin}'s signing key"); - if let Some(server_keys) = self - .services - .sending - .send_federation_request( - server, - get_remote_server_keys::v2::Request::new( - origin.to_owned(), - MilliSecondsSinceUnixEpoch::from_system_time( - SystemTime::now() - .checked_add(Duration::from_secs(3600)) - .expect("SystemTime too large"), - ) - .expect("time is valid"), - ), - ) - .await - .ok() - .map(|resp| { - resp.server_keys - .into_iter() - .filter_map(|e| e.deserialize().ok()) - .collect::>() - }) { - debug!("Got signing keys: {server_keys:?}"); - for k in server_keys { - self.services - .globals - .db - .add_signing_key(origin, k.clone()) - .await; - result.extend( - k.verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - 
k.old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - } + keys +} - if contains_all_ids(&result) { - return Ok(result); - } - } - } - } +fn extract_key(mut keys: ServerSigningKeys, key_id: &ServerSigningKeyId) -> Option { + keys.verify_keys.remove(key_id).or_else(|| { + keys.old_verify_keys + .remove(key_id) + .map(|old| VerifyKey::new(old.key)) + }) +} - Err!(BadServerResponse(warn!("Failed to find public key for server {origin:?}"))) - } +fn key_exists(keys: &ServerSigningKeys, key_id: &ServerSigningKeyId) -> bool { + keys.verify_keys.contains_key(key_id) || keys.old_verify_keys.contains_key(key_id) } diff --git a/src/service/server_keys/request.rs b/src/service/server_keys/request.rs new file mode 100644 index 000000000..84dd28717 --- /dev/null +++ b/src/service/server_keys/request.rs @@ -0,0 +1,97 @@ +use std::collections::BTreeMap; + +use conduit::{implement, Err, Result}; +use ruma::{ + api::federation::discovery::{ + get_remote_server_keys, + get_remote_server_keys_batch::{self, v2::QueryCriteria}, + get_server_keys, ServerSigningKeys, + }, + OwnedServerName, OwnedServerSigningKeyId, ServerName, ServerSigningKeyId, +}; + +#[implement(super::Service)] +pub(super) async fn batch_notary_request<'a, S, K>( + &self, notary: &ServerName, batch: S, +) -> Result> +where + S: Iterator + Send, + K: Iterator + Send, +{ + use get_remote_server_keys_batch::v2::Request; + type RumaBatch = BTreeMap>; + + let criteria = QueryCriteria { + minimum_valid_until_ts: Some(self.minimum_valid_ts()), + }; + + let mut server_keys = RumaBatch::new(); + for (server, key_ids) in batch { + let entry = server_keys.entry(server.into()).or_default(); + for key_id in key_ids { + entry.insert(key_id.into(), criteria.clone()); + } + } + + debug_assert!(!server_keys.is_empty(), "empty batch request to notary"); + let request = Request { + server_keys, + }; + + self.services + .sending + .send_federation_request(notary, request) + .await + .map(|response| 
response.server_keys) + .map(|keys| { + keys.into_iter() + .map(|key| key.deserialize()) + .filter_map(Result::ok) + .collect() + }) +} + +#[implement(super::Service)] +pub async fn notary_request(&self, notary: &ServerName, target: &ServerName) -> Result> { + use get_remote_server_keys::v2::Request; + + let request = Request { + server_name: target.into(), + minimum_valid_until_ts: self.minimum_valid_ts(), + }; + + self.services + .sending + .send_federation_request(notary, request) + .await + .map(|response| response.server_keys) + .map(|keys| { + keys.into_iter() + .map(|key| key.deserialize()) + .filter_map(Result::ok) + .collect() + }) +} + +#[implement(super::Service)] +pub async fn server_request(&self, target: &ServerName) -> Result { + use get_server_keys::v2::Request; + + let server_signing_key = self + .services + .sending + .send_federation_request(target, Request::new()) + .await + .map(|response| response.server_key) + .and_then(|key| key.deserialize().map_err(Into::into))?; + + if server_signing_key.server_name != target { + return Err!(BadServerResponse(debug_warn!( + requested = ?target, + response = ?server_signing_key.server_name, + "Server responded with bogus server_name" + ))); + } + + Ok(server_signing_key) +} diff --git a/src/service/server_keys/sign.rs b/src/service/server_keys/sign.rs new file mode 100644 index 000000000..28fd7e803 --- /dev/null +++ b/src/service/server_keys/sign.rs @@ -0,0 +1,18 @@ +use conduit::{implement, Result}; +use ruma::{CanonicalJsonObject, RoomVersionId}; + +#[implement(super::Service)] +pub fn sign_json(&self, object: &mut CanonicalJsonObject) -> Result { + use ruma::signatures::sign_json; + + let server_name = self.services.globals.server_name().as_str(); + sign_json(server_name, self.keypair(), object).map_err(Into::into) +} + +#[implement(super::Service)] +pub fn hash_and_sign_event(&self, object: &mut CanonicalJsonObject, room_version: &RoomVersionId) -> Result { + use ruma::signatures::hash_and_sign_event; 
+ + let server_name = self.services.globals.server_name().as_str(); + hash_and_sign_event(server_name, self.keypair(), object, room_version).map_err(Into::into) +} diff --git a/src/service/server_keys/verify.rs b/src/service/server_keys/verify.rs new file mode 100644 index 000000000..ad20fec7f --- /dev/null +++ b/src/service/server_keys/verify.rs @@ -0,0 +1,33 @@ +use conduit::{implement, pdu::gen_event_id_canonical_json, Err, Result}; +use ruma::{signatures::Verified, CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, RoomVersionId}; +use serde_json::value::RawValue as RawJsonValue; + +#[implement(super::Service)] +pub async fn validate_and_add_event_id( + &self, pdu: &RawJsonValue, room_version: &RoomVersionId, +) -> Result<(OwnedEventId, CanonicalJsonObject)> { + let (event_id, mut value) = gen_event_id_canonical_json(pdu, room_version)?; + if let Err(e) = self.verify_event(&value, Some(room_version)).await { + return Err!(BadServerResponse(debug_error!("Event {event_id} failed verification: {e:?}"))); + } + + value.insert("event_id".into(), CanonicalJsonValue::String(event_id.as_str().into())); + + Ok((event_id, value)) +} + +#[implement(super::Service)] +pub async fn verify_event( + &self, event: &CanonicalJsonObject, room_version: Option<&RoomVersionId>, +) -> Result { + let room_version = room_version.unwrap_or(&RoomVersionId::V11); + let keys = self.get_event_keys(event, room_version).await?; + ruma::signatures::verify_event(&keys, event, room_version).map_err(Into::into) +} + +#[implement(super::Service)] +pub async fn verify_json(&self, event: &CanonicalJsonObject, room_version: Option<&RoomVersionId>) -> Result { + let room_version = room_version.unwrap_or(&RoomVersionId::V11); + let keys = self.get_event_keys(event, room_version).await?; + ruma::signatures::verify_json(&keys, event.clone()).map_err(Into::into) +} From b4ec1e9d3cbc58f68c3733061c11c55700ff3018 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 14 Oct 2024 03:58:25 +0000 Subject: 
[PATCH 084/245] add federation client for select high-timeout requests Signed-off-by: Jason Volk --- src/api/client/membership.rs | 4 ++-- src/core/config/mod.rs | 2 +- src/service/client/mod.rs | 9 ++++++++- src/service/rooms/event_handler/mod.rs | 2 +- src/service/sending/mod.rs | 11 +++++++++++ 5 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 2fa34ff7b..31fd90766 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -775,7 +775,7 @@ async fn join_room_by_id_helper_remote( let send_join_response = services .sending - .send_federation_request(&remote_server, send_join_request) + .send_synapse_request(&remote_server, send_join_request) .await?; info!("send_join finished"); @@ -1154,7 +1154,7 @@ async fn join_room_by_id_helper_local( let send_join_response = services .sending - .send_federation_request( + .send_synapse_request( &remote_server, federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 114c6e766..02b277d0b 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -1775,7 +1775,7 @@ fn default_well_known_conn_timeout() -> u64 { 6 } fn default_well_known_timeout() -> u64 { 10 } -fn default_federation_timeout() -> u64 { 300 } +fn default_federation_timeout() -> u64 { 25 } fn default_federation_idle_timeout() -> u64 { 25 } diff --git a/src/service/client/mod.rs b/src/service/client/mod.rs index b21f9dab5..f9a89e99d 100644 --- a/src/service/client/mod.rs +++ b/src/service/client/mod.rs @@ -11,6 +11,7 @@ pub struct Service { pub extern_media: reqwest::Client, pub well_known: reqwest::Client, pub federation: reqwest::Client, + pub synapse: reqwest::Client, pub sender: reqwest::Client, pub appservice: reqwest::Client, pub pusher: reqwest::Client, @@ -48,12 +49,18 @@ impl crate::Service for Service { federation: base(config)? 
.dns_resolver(resolver.resolver.hooked.clone()) .read_timeout(Duration::from_secs(config.federation_timeout)) - .timeout(Duration::from_secs(config.federation_timeout)) .pool_max_idle_per_host(config.federation_idle_per_host.into()) .pool_idle_timeout(Duration::from_secs(config.federation_idle_timeout)) .redirect(redirect::Policy::limited(3)) .build()?, + synapse: base(config)? + .dns_resolver(resolver.resolver.hooked.clone()) + .read_timeout(Duration::from_secs(305)) + .pool_max_idle_per_host(0) + .redirect(redirect::Policy::limited(3)) + .build()?, + sender: base(config)? .dns_resolver(resolver.resolver.hooked.clone()) .read_timeout(Duration::from_secs(config.sender_timeout)) diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 8448404ba..0ffd9659b 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1016,7 +1016,7 @@ impl Service { let res = self .services .sending - .send_federation_request( + .send_synapse_request( origin, get_room_state_ids::v1::Request { room_id: room_id.to_owned(), diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 5970c3836..63c5e655a 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -245,6 +245,7 @@ impl Service { .await } + /// Sends a request to a federation server #[tracing::instrument(skip_all, name = "request")] pub async fn send_federation_request(&self, dest: &ServerName, request: T) -> Result where @@ -254,6 +255,16 @@ impl Service { self.send(client, dest, request).await } + /// Like send_federation_request() but with a very large timeout + #[tracing::instrument(skip_all, name = "synapse")] + pub async fn send_synapse_request(&self, dest: &ServerName, request: T) -> Result + where + T: OutgoingRequest + Debug + Send, + { + let client = &self.services.client.synapse; + self.send(client, dest, request).await + } + /// Sends a request to an appservice /// /// Only returns None if there is no 
url specified in the appservice From d0ee4b6d253079ceef7ee9094db9e3f70f1ed048 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 14 Oct 2024 01:01:12 +0000 Subject: [PATCH 085/245] add resolve_with_servers() to alias service; simplify api Signed-off-by: Jason Volk --- src/admin/user/commands.rs | 9 ++- src/api/client/alias.rs | 16 ++--- src/api/client/membership.rs | 66 +++++++++-------- src/service/rooms/alias/mod.rs | 47 +++++++----- src/service/rooms/alias/remote.rs | 116 ++++++++++++++---------------- 5 files changed, 131 insertions(+), 123 deletions(-) diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index df3938331..fb6d2bf1b 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -381,13 +381,18 @@ pub(super) async fn force_join_room( &self, user_id: String, room_id: OwnedRoomOrAliasId, ) -> Result { let user_id = parse_local_user_id(self.services, &user_id)?; - let room_id = self.services.rooms.alias.resolve(&room_id).await?; + let (room_id, servers) = self + .services + .rooms + .alias + .resolve_with_servers(&room_id, None) + .await?; assert!( self.services.globals.user_is_local(&user_id), "Parsed user_id must be a local user" ); - join_room_by_id_helper(self.services, &user_id, &room_id, None, &[], None, &None).await?; + join_room_by_id_helper(self.services, &user_id, &room_id, None, &servers, None, &None).await?; Ok(RoomMessageEventContent::notice_markdown(format!( "{user_id} has been joined to {room_id}.", diff --git a/src/api/client/alias.rs b/src/api/client/alias.rs index 2399a3551..83f3291d4 100644 --- a/src/api/client/alias.rs +++ b/src/api/client/alias.rs @@ -86,25 +86,19 @@ pub(crate) async fn get_alias_route( State(services): State, body: Ruma, ) -> Result { let room_alias = body.body.room_alias; - let servers = None; - let Ok((room_id, pre_servers)) = services - .rooms - .alias - .resolve_alias(&room_alias, servers.as_ref()) - .await - else { + let Ok((room_id, servers)) = 
services.rooms.alias.resolve_alias(&room_alias, None).await else { return Err!(Request(NotFound("Room with alias not found."))); }; - let servers = room_available_servers(&services, &room_id, &room_alias, &pre_servers).await; + let servers = room_available_servers(&services, &room_id, &room_alias, servers).await; debug!(?room_alias, ?room_id, "available servers: {servers:?}"); Ok(get_alias::v3::Response::new(room_id, servers)) } async fn room_available_servers( - services: &Services, room_id: &RoomId, room_alias: &RoomAliasId, pre_servers: &Option>, + services: &Services, room_id: &RoomId, room_alias: &RoomAliasId, pre_servers: Vec, ) -> Vec { // find active servers in room state cache to suggest let mut servers: Vec = services @@ -117,9 +111,7 @@ async fn room_available_servers( // push any servers we want in the list already (e.g. responded remote alias // servers, room alias server itself) - if let Some(pre_servers) = pre_servers { - servers.extend(pre_servers.clone()); - }; + servers.extend(pre_servers); servers.sort_unstable(); servers.dedup(); diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 31fd90766..27de60c68 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -9,8 +9,9 @@ use axum_client_ip::InsecureClientIp; use conduit::{ debug, debug_info, debug_warn, err, error, info, pdu, pdu::{gen_event_id_canonical_json, PduBuilder}, + result::FlatOk, trace, utils, - utils::{IterStream, ReadyExt}, + utils::{shuffle, IterStream, ReadyExt}, warn, Err, Error, PduEvent, Result, }; use futures::{FutureExt, StreamExt}; @@ -188,6 +189,10 @@ pub(crate) async fn join_room_by_id_route( servers.push(server.into()); } + servers.sort_unstable(); + servers.dedup(); + shuffle(&mut servers); + join_room_by_id_helper( &services, sender_user, @@ -251,45 +256,48 @@ pub(crate) async fn join_room_by_id_or_alias_route( servers.push(server.to_owned()); } + servers.sort_unstable(); + servers.dedup(); + shuffle(&mut servers); + 
(servers, room_id) }, Err(room_alias) => { - let response = services + let (room_id, mut servers) = services .rooms .alias - .resolve_alias(&room_alias, Some(&body.via.clone())) + .resolve_alias(&room_alias, Some(body.via.clone())) .await?; - let (room_id, mut pre_servers) = response; banned_room_check(&services, sender_user, Some(&room_id), Some(room_alias.server_name()), client).await?; - let mut servers = body.via; - if let Some(pre_servers) = &mut pre_servers { - servers.append(pre_servers); - } + let addl_via_servers = services + .rooms + .state_cache + .servers_invite_via(&room_id) + .map(ToOwned::to_owned); - servers.extend( - services - .rooms - .state_cache - .servers_invite_via(&room_id) - .map(ToOwned::to_owned) - .collect::>() - .await, - ); + let addl_state_servers = services + .rooms + .state_cache + .invite_state(sender_user, &room_id) + .await + .unwrap_or_default(); + + let mut addl_servers: Vec<_> = addl_state_servers + .iter() + .map(|event| event.get_field("sender")) + .filter_map(FlatOk::flat_ok) + .map(|user: &UserId| user.server_name().to_owned()) + .stream() + .chain(addl_via_servers) + .collect() + .await; - servers.extend( - services - .rooms - .state_cache - .invite_state(sender_user, &room_id) - .await - .unwrap_or_default() - .iter() - .filter_map(|event| event.get_field("sender").ok().flatten()) - .filter_map(|sender: &str| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); + addl_servers.sort_unstable(); + addl_servers.dedup(); + shuffle(&mut addl_servers); + servers.append(&mut addl_servers); (servers, room_id) }, diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 3f944729e..0cdec8eeb 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -112,40 +112,51 @@ impl Service { Ok(()) } + #[inline] pub async fn resolve(&self, room: &RoomOrAliasId) -> Result { + self.resolve_with_servers(room, None) + .await + .map(|(room_id, _)| room_id) + } + + pub 
async fn resolve_with_servers( + &self, room: &RoomOrAliasId, servers: Option>, + ) -> Result<(OwnedRoomId, Vec)> { if room.is_room_id() { - let room_id: &RoomId = &RoomId::parse(room).expect("valid RoomId"); - Ok(room_id.to_owned()) + let room_id = RoomId::parse(room).expect("valid RoomId"); + Ok((room_id, servers.unwrap_or_default())) } else { - let alias: &RoomAliasId = &RoomAliasId::parse(room).expect("valid RoomAliasId"); - Ok(self.resolve_alias(alias, None).await?.0) + let alias = &RoomAliasId::parse(room).expect("valid RoomAliasId"); + self.resolve_alias(alias, servers).await } } #[tracing::instrument(skip(self), name = "resolve")] pub async fn resolve_alias( - &self, room_alias: &RoomAliasId, servers: Option<&Vec>, - ) -> Result<(OwnedRoomId, Option>)> { - if !self - .services - .globals - .server_is_ours(room_alias.server_name()) - && (!servers + &self, room_alias: &RoomAliasId, servers: Option>, + ) -> Result<(OwnedRoomId, Vec)> { + let server_name = room_alias.server_name(); + let server_is_ours = self.services.globals.server_is_ours(server_name); + let servers_contains_ours = || { + servers .as_ref() - .is_some_and(|servers| servers.contains(&self.services.globals.server_name().to_owned())) - || servers.as_ref().is_none()) - { - return self.remote_resolve(room_alias, servers).await; + .is_some_and(|servers| servers.contains(&self.services.globals.config.server_name)) + }; + + if !server_is_ours && !servers_contains_ours() { + return self + .remote_resolve(room_alias, servers.unwrap_or_default()) + .await; } - let room_id: Option = match self.resolve_local_alias(room_alias).await { + let room_id = match self.resolve_local_alias(room_alias).await { Ok(r) => Some(r), Err(_) => self.resolve_appservice_alias(room_alias).await?, }; room_id.map_or_else( - || Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found.")), - |room_id| Ok((room_id, None)), + || Err!(Request(NotFound("Room with alias not found."))), + |room_id| Ok((room_id, 
Vec::new())), ) } diff --git a/src/service/rooms/alias/remote.rs b/src/service/rooms/alias/remote.rs index 5d835240b..d9acccc9c 100644 --- a/src/service/rooms/alias/remote.rs +++ b/src/service/rooms/alias/remote.rs @@ -1,75 +1,67 @@ -use conduit::{debug, debug_warn, Error, Result}; -use ruma::{ - api::{client::error::ErrorKind, federation}, - OwnedRoomId, OwnedServerName, RoomAliasId, -}; +use std::iter::once; -impl super::Service { - pub(super) async fn remote_resolve( - &self, room_alias: &RoomAliasId, servers: Option<&Vec>, - ) -> Result<(OwnedRoomId, Option>)> { - debug!(?room_alias, ?servers, "resolve"); +use conduit::{debug, debug_error, err, implement, Result}; +use federation::query::get_room_information::v1::Response; +use ruma::{api::federation, OwnedRoomId, OwnedServerName, RoomAliasId, ServerName}; - let mut response = self - .services - .sending - .send_federation_request( - room_alias.server_name(), - federation::query::get_room_information::v1::Request { - room_alias: room_alias.to_owned(), - }, - ) - .await; +#[implement(super::Service)] +pub(super) async fn remote_resolve( + &self, room_alias: &RoomAliasId, servers: Vec, +) -> Result<(OwnedRoomId, Vec)> { + debug!(?room_alias, servers = ?servers, "resolve"); + let servers = once(room_alias.server_name()) + .map(ToOwned::to_owned) + .chain(servers.into_iter()); - debug!("room alias server_name get_alias_helper response: {response:?}"); + let mut resolved_servers = Vec::new(); + let mut resolved_room_id: Option = None; + for server in servers { + match self.remote_request(room_alias, &server).await { + Err(e) => debug_error!("Failed to query for {room_alias:?} from {server}: {e}"), + Ok(Response { + room_id, + servers, + }) => { + debug!("Server {server} answered with {room_id:?} for {room_alias:?} servers: {servers:?}"); - if let Err(ref e) = response { - debug_warn!( - "Server {} of the original room alias failed to assist in resolving room alias: {e}", - room_alias.server_name(), - ); - } - - if 
response.as_ref().is_ok_and(|resp| resp.servers.is_empty()) || response.as_ref().is_err() { - if let Some(servers) = servers { - for server in servers { - response = self - .services - .sending - .send_federation_request( - server, - federation::query::get_room_information::v1::Request { - room_alias: room_alias.to_owned(), - }, - ) - .await; - debug!("Got response from server {server} for room aliases: {response:?}"); + resolved_room_id.get_or_insert(room_id); + add_server(&mut resolved_servers, server); - if let Ok(ref response) = response { - if !response.servers.is_empty() { - break; - } - debug_warn!( - "Server {server} responded with room aliases, but was empty? Response: {response:?}" - ); - } + if !servers.is_empty() { + add_servers(&mut resolved_servers, servers); + break; } - } + }, } + } - if let Ok(response) = response { - let room_id = response.room_id; + resolved_room_id + .map(|room_id| (room_id, resolved_servers)) + .ok_or_else(|| err!(Request(NotFound("No servers could assist in resolving the room alias")))) +} - let mut pre_servers = response.servers; - // since the room alis server responded, insert it into the list - pre_servers.push(room_alias.server_name().into()); +#[implement(super::Service)] +async fn remote_request(&self, room_alias: &RoomAliasId, server: &ServerName) -> Result { + use federation::query::get_room_information::v1::Request; - return Ok((room_id, Some(pre_servers))); - } + let request = Request { + room_alias: room_alias.to_owned(), + }; + + self.services + .sending + .send_federation_request(server, request) + .await +} + +fn add_servers(servers: &mut Vec, new: Vec) { + for server in new { + add_server(servers, server); + } +} - Err(Error::BadRequest( - ErrorKind::NotFound, - "No servers could assist in resolving the room alias", - )) +fn add_server(servers: &mut Vec, server: OwnedServerName) { + if !servers.contains(&server) { + servers.push(server); } } From ed5b5d7877996f0ca4862ee3b08cfebd35744904 Mon Sep 17 00:00:00 2001 
From: Jason Volk Date: Tue, 15 Oct 2024 09:34:43 +0000 Subject: [PATCH 086/245] merge rooms state service and data Signed-off-by: Jason Volk --- src/service/rooms/state/data.rs | 65 ----------------------------- src/service/rooms/state/mod.rs | 73 +++++++++++++++++++++++---------- 2 files changed, 52 insertions(+), 86 deletions(-) delete mode 100644 src/service/rooms/state/data.rs diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs deleted file mode 100644 index 813f48aed..000000000 --- a/src/service/rooms/state/data.rs +++ /dev/null @@ -1,65 +0,0 @@ -use std::sync::Arc; - -use conduit::{ - utils::{stream::TryIgnore, ReadyExt}, - Result, -}; -use database::{Database, Deserialized, Interfix, Map}; -use ruma::{OwnedEventId, RoomId}; - -use super::RoomMutexGuard; - -pub(super) struct Data { - shorteventid_shortstatehash: Arc, - roomid_shortstatehash: Arc, - pub(super) roomid_pduleaves: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - shorteventid_shortstatehash: db["shorteventid_shortstatehash"].clone(), - roomid_shortstatehash: db["roomid_shortstatehash"].clone(), - roomid_pduleaves: db["roomid_pduleaves"].clone(), - } - } - - pub(super) async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result { - self.roomid_shortstatehash.get(room_id).await.deserialized() - } - - #[inline] - pub(super) fn set_room_state( - &self, - room_id: &RoomId, - new_shortstatehash: u64, - _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) { - self.roomid_shortstatehash - .raw_put(room_id, new_shortstatehash); - } - - pub(super) fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) { - self.shorteventid_shortstatehash - .put(shorteventid, shortstatehash); - } - - pub(super) async fn set_forward_extremities( - &self, - room_id: &RoomId, - event_ids: Vec, - _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) { - let prefix = 
(room_id, Interfix); - self.roomid_pduleaves - .keys_prefix_raw(&prefix) - .ignore_err() - .ready_for_each(|key| self.roomid_pduleaves.remove(key)) - .await; - - for event_id in &event_ids { - let key = (room_id, event_id); - self.roomid_pduleaves.put_raw(key, event_id); - } - } -} diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index cfcb2da6f..6abaa1980 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -1,5 +1,3 @@ -mod data; - use std::{ collections::{HashMap, HashSet}, fmt::Write, @@ -10,11 +8,10 @@ use std::{ use conduit::{ err, result::FlatOk, - utils::{calculate_hash, stream::TryIgnore, IterStream, MutexMap, MutexMapGuard}, + utils::{calculate_hash, stream::TryIgnore, IterStream, MutexMap, MutexMapGuard, ReadyExt}, warn, PduEvent, Result, }; -use data::Data; -use database::{Ignore, Interfix}; +use database::{Deserialized, Ignore, Interfix, Map}; use futures::{future::join_all, pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; use ruma::{ events::{ @@ -30,9 +27,9 @@ use super::state_compressor::CompressedStateEvent; use crate::{globals, rooms, Dep}; pub struct Service { + pub mutex: RoomMutexMap, services: Services, db: Data, - pub mutex: RoomMutexMap, } struct Services { @@ -45,12 +42,19 @@ struct Services { timeline: Dep, } +struct Data { + shorteventid_shortstatehash: Arc, + roomid_shortstatehash: Arc, + roomid_pduleaves: Arc, +} + type RoomMutexMap = MutexMap; pub type RoomMutexGuard = MutexMapGuard; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + mutex: RoomMutexMap::new(), services: Services { globals: args.depend::("globals"), short: args.depend::("rooms::short"), @@ -60,12 +64,15 @@ impl crate::Service for Service { state_compressor: args.depend::("rooms::state_compressor"), timeline: args.depend::("rooms::timeline"), }, - db: Data::new(args.db), - mutex: RoomMutexMap::new(), + db: Data { + shorteventid_shortstatehash: 
args.db["shorteventid_shortstatehash"].clone(), + roomid_shortstatehash: args.db["roomid_shortstatehash"].clone(), + roomid_pduleaves: args.db["roomid_pduleaves"].clone(), + }, })) } - fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + fn memory_usage(&self, out: &mut dyn Write) -> Result { let mutex = self.mutex.len(); writeln!(out, "state_mutex: {mutex}")?; @@ -84,7 +91,7 @@ impl Service { statediffnew: Arc>, _statediffremoved: Arc>, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { + ) -> Result { let event_ids = statediffnew.iter().stream().filter_map(|new| { self.services .state_compressor @@ -127,7 +134,7 @@ impl Service { self.services.state_cache.update_joined_count(room_id).await; - self.db.set_room_state(room_id, shortstatehash, state_lock); + self.set_room_state(room_id, shortstatehash, state_lock); Ok(()) } @@ -140,13 +147,15 @@ impl Service { pub async fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc>, ) -> Result { + const BUFSIZE: usize = size_of::(); + let shorteventid = self .services .short .get_or_create_shorteventid(event_id) .await; - let previous_shortstatehash = self.db.get_room_shortstatehash(room_id).await; + let previous_shortstatehash = self.get_room_shortstatehash(room_id).await; let state_hash = calculate_hash( &state_ids_compressed @@ -196,7 +205,9 @@ impl Service { )?; } - self.db.set_event_state(shorteventid, shortstatehash); + self.db + .shorteventid_shortstatehash + .aput::(shorteventid, shortstatehash); Ok(shortstatehash) } @@ -207,6 +218,8 @@ impl Service { /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. 
#[tracing::instrument(skip(self, new_pdu), level = "debug")] pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result { + const BUFSIZE: usize = size_of::(); + let shorteventid = self .services .short @@ -216,7 +229,9 @@ impl Service { let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id).await; if let Ok(p) = previous_shortstatehash { - self.db.set_event_state(shorteventid, p); + self.db + .shorteventid_shortstatehash + .aput::(shorteventid, p); } if let Some(state_key) = &new_pdu.state_key { @@ -306,14 +321,18 @@ impl Service { } /// Set the state hash to a new version, but does not update state_cache. - #[tracing::instrument(skip(self, mutex_lock), level = "debug")] + #[tracing::instrument(skip(self, _mutex_lock), level = "debug")] pub fn set_room_state( &self, room_id: &RoomId, shortstatehash: u64, - mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex + _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) { - self.db.set_room_state(room_id, shortstatehash, mutex_lock); + const BUFSIZE: usize = size_of::(); + + self.db + .roomid_shortstatehash + .raw_aput::(room_id, shortstatehash); } /// Returns the room's version. 
@@ -327,9 +346,12 @@ impl Service { .map_err(|e| err!(Request(NotFound("No create event found: {e:?}")))) } - #[inline] pub async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result { - self.db.get_room_shortstatehash(room_id).await + self.db + .roomid_shortstatehash + .get(room_id) + .await + .deserialized() } pub fn get_forward_extremities<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + '_ { @@ -346,11 +368,20 @@ impl Service { &self, room_id: &RoomId, event_ids: Vec, - state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex + _state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) { + let prefix = (room_id, Interfix); self.db - .set_forward_extremities(room_id, event_ids, state_lock) + .roomid_pduleaves + .keys_prefix_raw(&prefix) + .ignore_err() + .ready_for_each(|key| self.db.roomid_pduleaves.remove(key)) .await; + + for event_id in &event_ids { + let key = (room_id, event_id); + self.db.roomid_pduleaves.put_raw(key, event_id); + } } /// This fetches auth events from the current state. 
From 4576313a7c4a2f89fc2c2b1f04ad739dbb546b0a Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 15 Oct 2024 09:54:20 +0000 Subject: [PATCH 087/245] merge rooms user service and data Signed-off-by: Jason Volk --- src/service/rooms/user/data.rs | 108 ----------------------- src/service/rooms/user/mod.rs | 154 ++++++++++++++++++++++++--------- 2 files changed, 111 insertions(+), 151 deletions(-) delete mode 100644 src/service/rooms/user/data.rs diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs deleted file mode 100644 index 96b009f85..000000000 --- a/src/service/rooms/user/data.rs +++ /dev/null @@ -1,108 +0,0 @@ -use std::sync::Arc; - -use conduit::Result; -use database::{Deserialized, Map}; -use futures::{Stream, StreamExt}; -use ruma::{RoomId, UserId}; - -use crate::{globals, rooms, Dep}; - -pub(super) struct Data { - userroomid_notificationcount: Arc, - userroomid_highlightcount: Arc, - roomuserid_lastnotificationread: Arc, - roomsynctoken_shortstatehash: Arc, - services: Services, -} - -struct Services { - globals: Dep, - short: Dep, - state_cache: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - userroomid_notificationcount: db["userroomid_notificationcount"].clone(), - userroomid_highlightcount: db["userroomid_highlightcount"].clone(), - roomuserid_lastnotificationread: db["userroomid_highlightcount"].clone(), //< NOTE: known bug from conduit - roomsynctoken_shortstatehash: db["roomsynctoken_shortstatehash"].clone(), - services: Services { - globals: args.depend::("globals"), - short: args.depend::("rooms::short"), - state_cache: args.depend::("rooms::state_cache"), - }, - } - } - - pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { - let userroom_id = (user_id, room_id); - self.userroomid_highlightcount.put(userroom_id, 0_u64); - self.userroomid_notificationcount.put(userroom_id, 0_u64); - - let roomuser_id = (room_id, user_id); - let 
count = self.services.globals.next_count().unwrap(); - self.roomuserid_lastnotificationread.put(roomuser_id, count); - } - - pub(super) async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { - let key = (user_id, room_id); - self.userroomid_notificationcount - .qry(&key) - .await - .deserialized() - .unwrap_or(0) - } - - pub(super) async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { - let key = (user_id, room_id); - self.userroomid_highlightcount - .qry(&key) - .await - .deserialized() - .unwrap_or(0) - } - - pub(super) async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> u64 { - let key = (room_id, user_id); - self.roomuserid_lastnotificationread - .qry(&key) - .await - .deserialized() - .unwrap_or(0) - } - - pub(super) async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) { - let shortroomid = self - .services - .short - .get_shortroomid(room_id) - .await - .expect("room exists"); - - let key: &[u64] = &[shortroomid, token]; - self.roomsynctoken_shortstatehash.put(key, shortstatehash); - } - - pub(super) async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result { - let shortroomid = self.services.short.get_shortroomid(room_id).await?; - - let key: &[u64] = &[shortroomid, token]; - self.roomsynctoken_shortstatehash - .qry(key) - .await - .deserialized() - } - - //TODO: optimize; replace point-queries with dual iteration - pub(super) fn get_shared_rooms<'a>( - &'a self, user_a: &'a UserId, user_b: &'a UserId, - ) -> impl Stream + Send + 'a { - self.services - .state_cache - .rooms_joined(user_a) - .filter(|room_id| self.services.state_cache.is_joined(user_b, room_id)) - } -} diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index d9d90ecf9..e484203d5 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -1,71 +1,139 @@ -mod data; - use std::sync::Arc; -use conduit::Result; 
+use conduit::{implement, Result}; +use database::{Deserialized, Map}; use futures::{pin_mut, Stream, StreamExt}; use ruma::{RoomId, UserId}; -use self::data::Data; +use crate::{globals, rooms, Dep}; pub struct Service { db: Data, + services: Services, +} + +struct Data { + userroomid_notificationcount: Arc, + userroomid_highlightcount: Arc, + roomuserid_lastnotificationread: Arc, + roomsynctoken_shortstatehash: Arc, +} + +struct Services { + globals: Dep, + short: Dep, + state_cache: Dep, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + userroomid_notificationcount: args.db["userroomid_notificationcount"].clone(), + userroomid_highlightcount: args.db["userroomid_highlightcount"].clone(), + roomuserid_lastnotificationread: args.db["userroomid_highlightcount"].clone(), //< NOTE: known bug from conduit + roomsynctoken_shortstatehash: args.db["roomsynctoken_shortstatehash"].clone(), + }, + + services: Services { + globals: args.depend::("globals"), + short: args.depend::("rooms::short"), + state_cache: args.depend::("rooms::state_cache"), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - #[inline] - pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { - self.db.reset_notification_counts(user_id, room_id); - } +#[implement(Service)] +pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { + let userroom_id = (user_id, room_id); + self.db.userroomid_highlightcount.put(userroom_id, 0_u64); + self.db.userroomid_notificationcount.put(userroom_id, 0_u64); - #[inline] - pub async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { - self.db.notification_count(user_id, room_id).await - } + let roomuser_id = (room_id, user_id); + let count = self.services.globals.next_count().unwrap(); + self.db + .roomuserid_lastnotificationread + .put(roomuser_id, count); +} - 
#[inline] - pub async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { - self.db.highlight_count(user_id, room_id).await - } +#[implement(Service)] +pub async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (user_id, room_id); + self.db + .userroomid_notificationcount + .qry(&key) + .await + .deserialized() + .unwrap_or(0) +} - #[inline] - pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> u64 { - self.db.last_notification_read(user_id, room_id).await - } +#[implement(Service)] +pub async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (user_id, room_id); + self.db + .userroomid_highlightcount + .qry(&key) + .await + .deserialized() + .unwrap_or(0) +} - #[inline] - pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) { - self.db - .associate_token_shortstatehash(room_id, token, shortstatehash) - .await; - } +#[implement(Service)] +pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (room_id, user_id); + self.db + .roomuserid_lastnotificationread + .qry(&key) + .await + .deserialized() + .unwrap_or(0) +} - #[inline] - pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result { - self.db.get_token_shortstatehash(room_id, token).await - } +#[implement(Service)] +pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) { + let shortroomid = self + .services + .short + .get_shortroomid(room_id) + .await + .expect("room exists"); - #[inline] - pub fn get_shared_rooms<'a>( - &'a self, user_a: &'a UserId, user_b: &'a UserId, - ) -> impl Stream + Send + 'a { - self.db.get_shared_rooms(user_a, user_b) - } + let key: &[u64] = &[shortroomid, token]; + self.db + .roomsynctoken_shortstatehash + .put(key, shortstatehash); +} - pub async fn has_shared_rooms<'a>(&'a self, user_a: &'a UserId, 
user_b: &'a UserId) -> bool { - let get_shared_rooms = self.get_shared_rooms(user_a, user_b); +#[implement(Service)] +pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result { + let shortroomid = self.services.short.get_shortroomid(room_id).await?; - pin_mut!(get_shared_rooms); - get_shared_rooms.next().await.is_some() - } + let key: &[u64] = &[shortroomid, token]; + self.db + .roomsynctoken_shortstatehash + .qry(key) + .await + .deserialized() +} + +#[implement(Service)] +pub async fn has_shared_rooms<'a>(&'a self, user_a: &'a UserId, user_b: &'a UserId) -> bool { + let get_shared_rooms = self.get_shared_rooms(user_a, user_b); + + pin_mut!(get_shared_rooms); + get_shared_rooms.next().await.is_some() +} + +//TODO: optimize; replace point-queries with dual iteration +#[implement(Service)] +pub fn get_shared_rooms<'a>( + &'a self, user_a: &'a UserId, user_b: &'a UserId, +) -> impl Stream + Send + 'a { + self.services + .state_cache + .rooms_joined(user_a) + .filter(|room_id| self.services.state_cache.is_joined(user_b, room_id)) } From 0b085ea84fae3125d956e874cf082532e9f57ca8 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 15 Oct 2024 10:34:37 +0000 Subject: [PATCH 088/245] merge remaining rooms state_cache data and service Signed-off-by: Jason Volk --- src/service/rooms/state_cache/data.rs | 179 -------------------------- src/service/rooms/state_cache/mod.rs | 178 ++++++++++++++++++++----- 2 files changed, 144 insertions(+), 213 deletions(-) delete mode 100644 src/service/rooms/state_cache/data.rs diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs deleted file mode 100644 index c06c8107f..000000000 --- a/src/service/rooms/state_cache/data.rs +++ /dev/null @@ -1,179 +0,0 @@ -use std::{ - collections::HashMap, - sync::{Arc, RwLock}, -}; - -use conduit::{utils::stream::TryIgnore, Result}; -use database::{serialize_to_vec, Deserialized, Interfix, Json, Map}; -use futures::{Stream, StreamExt}; -use 
ruma::{ - events::{AnyStrippedStateEvent, AnySyncStateEvent}, - serde::Raw, - OwnedRoomId, RoomId, UserId, -}; - -use crate::{globals, Dep}; - -type AppServiceInRoomCache = RwLock>>; -type StrippedStateEventItem = (OwnedRoomId, Vec>); -type SyncStateEventItem = (OwnedRoomId, Vec>); - -pub(super) struct Data { - pub(super) appservice_in_room_cache: AppServiceInRoomCache, - pub(super) roomid_invitedcount: Arc, - pub(super) roomid_inviteviaservers: Arc, - pub(super) roomid_joinedcount: Arc, - pub(super) roomserverids: Arc, - pub(super) roomuserid_invitecount: Arc, - pub(super) roomuserid_joined: Arc, - pub(super) roomuserid_leftcount: Arc, - pub(super) roomuseroncejoinedids: Arc, - pub(super) serverroomids: Arc, - pub(super) userroomid_invitestate: Arc, - pub(super) userroomid_joined: Arc, - pub(super) userroomid_leftstate: Arc, - services: Services, -} - -struct Services { - globals: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - appservice_in_room_cache: RwLock::new(HashMap::new()), - roomid_invitedcount: db["roomid_invitedcount"].clone(), - roomid_inviteviaservers: db["roomid_inviteviaservers"].clone(), - roomid_joinedcount: db["roomid_joinedcount"].clone(), - roomserverids: db["roomserverids"].clone(), - roomuserid_invitecount: db["roomuserid_invitecount"].clone(), - roomuserid_joined: db["roomuserid_joined"].clone(), - roomuserid_leftcount: db["roomuserid_leftcount"].clone(), - roomuseroncejoinedids: db["roomuseroncejoinedids"].clone(), - serverroomids: db["serverroomids"].clone(), - userroomid_invitestate: db["userroomid_invitestate"].clone(), - userroomid_joined: db["userroomid_joined"].clone(), - userroomid_leftstate: db["userroomid_leftstate"].clone(), - services: Services { - globals: args.depend::("globals"), - }, - } - } - - pub(super) fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) { - let key = (user_id, room_id); - - self.roomuseroncejoinedids.put_raw(key, []); - } - - 
pub(super) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) { - let userroom_id = (user_id, room_id); - let userroom_id = serialize_to_vec(userroom_id).expect("failed to serialize userroom_id"); - - let roomuser_id = (room_id, user_id); - let roomuser_id = serialize_to_vec(roomuser_id).expect("failed to serialize roomuser_id"); - - self.userroomid_joined.insert(&userroom_id, []); - self.roomuserid_joined.insert(&roomuser_id, []); - - self.userroomid_invitestate.remove(&userroom_id); - self.roomuserid_invitecount.remove(&roomuser_id); - - self.userroomid_leftstate.remove(&userroom_id); - self.roomuserid_leftcount.remove(&roomuser_id); - - self.roomid_inviteviaservers.remove(room_id); - } - - pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { - let userroom_id = (user_id, room_id); - let userroom_id = serialize_to_vec(userroom_id).expect("failed to serialize userroom_id"); - - let roomuser_id = (room_id, user_id); - let roomuser_id = serialize_to_vec(roomuser_id).expect("failed to serialize roomuser_id"); - - // (timo) TODO - let leftstate = Vec::>::new(); - let count = self.services.globals.next_count().unwrap(); - - self.userroomid_leftstate - .raw_put(&userroom_id, Json(leftstate)); - self.roomuserid_leftcount.raw_put(&roomuser_id, count); - - self.userroomid_joined.remove(&userroom_id); - self.roomuserid_joined.remove(&roomuser_id); - - self.userroomid_invitestate.remove(&userroom_id); - self.roomuserid_invitecount.remove(&roomuser_id); - - self.roomid_inviteviaservers.remove(room_id); - } - - /// Makes a user forget a room. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) { - let userroom_id = (user_id, room_id); - let roomuser_id = (room_id, user_id); - - self.userroomid_leftstate.del(userroom_id); - self.roomuserid_leftcount.del(roomuser_id); - } - - /// Returns an iterator over all rooms a user was invited to. 
- #[inline] - pub(super) fn rooms_invited<'a>( - &'a self, user_id: &'a UserId, - ) -> impl Stream + Send + 'a { - type Key<'a> = (&'a UserId, &'a RoomId); - type KeyVal<'a> = (Key<'a>, Raw>); - - let prefix = (user_id, Interfix); - self.userroomid_invitestate - .stream_prefix(&prefix) - .ignore_err() - .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) - .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) - .ignore_err() - } - - /// Returns an iterator over all rooms a user left. - #[inline] - pub(super) fn rooms_left<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { - type Key<'a> = (&'a UserId, &'a RoomId); - type KeyVal<'a> = (Key<'a>, Raw>>); - - let prefix = (user_id, Interfix); - self.userroomid_leftstate - .stream_prefix(&prefix) - .ignore_err() - .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) - .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) - .ignore_err() - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) async fn invite_state( - &self, user_id: &UserId, room_id: &RoomId, - ) -> Result>> { - let key = (user_id, room_id); - self.userroomid_invitestate - .qry(&key) - .await - .deserialized() - .and_then(|val: Raw>| val.deserialize_as().map_err(Into::into)) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) async fn left_state( - &self, user_id: &UserId, room_id: &RoomId, - ) -> Result>> { - let key = (user_id, room_id); - self.userroomid_leftstate - .qry(&key) - .await - .deserialized() - .and_then(|val: Raw>| val.deserialize_as().map_err(Into::into)) - } -} diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 077eee104..4f4ff2646 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,14 +1,14 @@ -mod data; - -use std::{collections::HashSet, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::{Arc, RwLock}, +}; use conduit::{ err, 
is_not_empty, utils::{stream::TryIgnore, ReadyExt, StreamTools}, warn, Result, }; -use data::Data; -use database::{serialize_to_vec, Deserialized, Ignore, Interfix, Json}; +use database::{serialize_to_vec, Deserialized, Ignore, Interfix, Json, Map}; use futures::{stream::iter, Stream, StreamExt}; use itertools::Itertools; use ruma::{ @@ -29,6 +29,7 @@ use ruma::{ use crate::{account_data, appservice::RegistrationInfo, globals, rooms, users, Dep}; pub struct Service { + appservice_in_room_cache: AppServiceInRoomCache, services: Services, db: Data, } @@ -40,16 +41,49 @@ struct Services { users: Dep, } +struct Data { + roomid_invitedcount: Arc, + roomid_inviteviaservers: Arc, + roomid_joinedcount: Arc, + roomserverids: Arc, + roomuserid_invitecount: Arc, + roomuserid_joined: Arc, + roomuserid_leftcount: Arc, + roomuseroncejoinedids: Arc, + serverroomids: Arc, + userroomid_invitestate: Arc, + userroomid_joined: Arc, + userroomid_leftstate: Arc, +} + +type AppServiceInRoomCache = RwLock>>; +type StrippedStateEventItem = (OwnedRoomId, Vec>); +type SyncStateEventItem = (OwnedRoomId, Vec>); + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + appservice_in_room_cache: RwLock::new(HashMap::new()), services: Services { account_data: args.depend::("account_data"), globals: args.depend::("globals"), state_accessor: args.depend::("rooms::state_accessor"), users: args.depend::("users"), }, - db: Data::new(&args), + db: Data { + roomid_invitedcount: args.db["roomid_invitedcount"].clone(), + roomid_inviteviaservers: args.db["roomid_inviteviaservers"].clone(), + roomid_joinedcount: args.db["roomid_joinedcount"].clone(), + roomserverids: args.db["roomserverids"].clone(), + roomuserid_invitecount: args.db["roomuserid_invitecount"].clone(), + roomuserid_joined: args.db["roomuserid_joined"].clone(), + roomuserid_leftcount: args.db["roomuserid_leftcount"].clone(), + roomuseroncejoinedids: args.db["roomuseroncejoinedids"].clone(), + 
serverroomids: args.db["serverroomids"].clone(), + userroomid_invitestate: args.db["userroomid_invitestate"].clone(), + userroomid_joined: args.db["userroomid_joined"].clone(), + userroomid_leftstate: args.db["userroomid_leftstate"].clone(), + }, })) } @@ -107,7 +141,7 @@ impl Service { // Check if the user never joined this room if !self.once_joined(user_id, room_id).await { // Add the user ID to the join list then - self.db.mark_as_once_joined(user_id, room_id); + self.mark_as_once_joined(user_id, room_id); // Check if the room has a predecessor if let Ok(Some(predecessor)) = self @@ -186,7 +220,7 @@ impl Service { } } - self.db.mark_as_joined(user_id, room_id); + self.mark_as_joined(user_id, room_id); }, MembershipState::Invite => { // We want to know if the sender is ignored by the receiver @@ -198,7 +232,7 @@ impl Service { .await; }, MembershipState::Leave | MembershipState::Ban => { - self.db.mark_as_left(user_id, room_id); + self.mark_as_left(user_id, room_id); }, _ => {}, } @@ -213,10 +247,9 @@ impl Service { #[tracing::instrument(skip(self, room_id, appservice), level = "debug")] pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> bool { let maybe = self - .db .appservice_in_room_cache .read() - .unwrap() + .expect("locked") .get(room_id) .and_then(|map| map.get(&appservice.registration.id)) .copied(); @@ -242,10 +275,9 @@ impl Service { .ready_any(|userid| appservice.users.is_match(userid.as_str())) .await; - self.db - .appservice_in_room_cache + self.appservice_in_room_cache .write() - .unwrap() + .expect("locked") .entry(room_id.to_owned()) .or_default() .insert(appservice.registration.id.clone(), in_room); @@ -254,21 +286,67 @@ impl Service { } } - /// Direct DB function to directly mark a user as left. It is not + /// Direct DB function to directly mark a user as joined. It is not /// recommended to use this directly. 
You most likely should use /// `update_membership` instead #[tracing::instrument(skip(self), level = "debug")] - pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { self.db.mark_as_left(user_id, room_id); } + pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) { + let userroom_id = (user_id, room_id); + let userroom_id = serialize_to_vec(userroom_id).expect("failed to serialize userroom_id"); - /// Direct DB function to directly mark a user as joined. It is not + let roomuser_id = (room_id, user_id); + let roomuser_id = serialize_to_vec(roomuser_id).expect("failed to serialize roomuser_id"); + + self.db.userroomid_joined.insert(&userroom_id, []); + self.db.roomuserid_joined.insert(&roomuser_id, []); + + self.db.userroomid_invitestate.remove(&userroom_id); + self.db.roomuserid_invitecount.remove(&roomuser_id); + + self.db.userroomid_leftstate.remove(&userroom_id); + self.db.roomuserid_leftcount.remove(&roomuser_id); + + self.db.roomid_inviteviaservers.remove(room_id); + } + + /// Direct DB function to directly mark a user as left. It is not /// recommended to use this directly. 
You most likely should use /// `update_membership` instead #[tracing::instrument(skip(self), level = "debug")] - pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) { self.db.mark_as_joined(user_id, room_id); } + pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { + let userroom_id = (user_id, room_id); + let userroom_id = serialize_to_vec(userroom_id).expect("failed to serialize userroom_id"); + + let roomuser_id = (room_id, user_id); + let roomuser_id = serialize_to_vec(roomuser_id).expect("failed to serialize roomuser_id"); + + // (timo) TODO + let leftstate = Vec::>::new(); + let count = self.services.globals.next_count().unwrap(); + + self.db + .userroomid_leftstate + .raw_put(&userroom_id, Json(leftstate)); + self.db.roomuserid_leftcount.raw_put(&roomuser_id, count); + + self.db.userroomid_joined.remove(&userroom_id); + self.db.roomuserid_joined.remove(&roomuser_id); + + self.db.userroomid_invitestate.remove(&userroom_id); + self.db.roomuserid_invitecount.remove(&roomuser_id); + + self.db.roomid_inviteviaservers.remove(room_id); + } /// Makes a user forget a room. #[tracing::instrument(skip(self), level = "debug")] - pub fn forget(&self, room_id: &RoomId, user_id: &UserId) { self.db.forget(room_id, user_id); } + pub fn forget(&self, room_id: &RoomId, user_id: &UserId) { + let userroom_id = (user_id, room_id); + let roomuser_id = (room_id, user_id); + + self.db.userroomid_leftstate.del(userroom_id); + self.db.roomuserid_leftcount.del(roomuser_id); + } /// Returns an iterator of all servers participating in this room. #[tracing::instrument(skip(self), level = "debug")] @@ -415,28 +493,56 @@ impl Service { /// Returns an iterator over all rooms a user was invited to. 
#[tracing::instrument(skip(self), level = "debug")] - pub fn rooms_invited<'a>( - &'a self, user_id: &'a UserId, - ) -> impl Stream>)> + Send + 'a { - self.db.rooms_invited(user_id) + pub fn rooms_invited<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { + type KeyVal<'a> = (Key<'a>, Raw>); + type Key<'a> = (&'a UserId, &'a RoomId); + + let prefix = (user_id, Interfix); + self.db + .userroomid_invitestate + .stream_prefix(&prefix) + .ignore_err() + .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) + .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) + .ignore_err() } #[tracing::instrument(skip(self), level = "debug")] pub async fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>> { - self.db.invite_state(user_id, room_id).await + let key = (user_id, room_id); + self.db + .userroomid_invitestate + .qry(&key) + .await + .deserialized() + .and_then(|val: Raw>| val.deserialize_as().map_err(Into::into)) } #[tracing::instrument(skip(self), level = "debug")] pub async fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>> { - self.db.left_state(user_id, room_id).await + let key = (user_id, room_id); + self.db + .userroomid_leftstate + .qry(&key) + .await + .deserialized() + .and_then(|val: Raw>| val.deserialize_as().map_err(Into::into)) } /// Returns an iterator over all rooms a user left. 
#[tracing::instrument(skip(self), level = "debug")] - pub fn rooms_left<'a>( - &'a self, user_id: &'a UserId, - ) -> impl Stream>)> + Send + 'a { - self.db.rooms_left(user_id) + pub fn rooms_left<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { + type KeyVal<'a> = (Key<'a>, Raw>>); + type Key<'a> = (&'a UserId, &'a RoomId); + + let prefix = (user_id, Interfix); + self.db + .userroomid_leftstate + .stream_prefix(&prefix) + .ignore_err() + .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) + .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) + .ignore_err() } #[tracing::instrument(skip(self), level = "debug")] @@ -515,13 +621,13 @@ impl Service { } pub fn get_appservice_in_room_cache_usage(&self) -> (usize, usize) { - let cache = self.db.appservice_in_room_cache.read().expect("locked"); + let cache = self.appservice_in_room_cache.read().expect("locked"); + (cache.len(), cache.capacity()) } pub fn clear_appservice_in_room_cache(&self) { - self.db - .appservice_in_room_cache + self.appservice_in_room_cache .write() .expect("locked") .clear(); @@ -574,13 +680,17 @@ impl Service { self.db.serverroomids.put_raw(serverroom_id, []); } - self.db - .appservice_in_room_cache + self.appservice_in_room_cache .write() .expect("locked") .remove(room_id); } + fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) { + let key = (user_id, room_id); + self.db.roomuseroncejoinedids.put_raw(key, []); + } + pub async fn mark_as_invited( &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, invite_via: Option>, From 84191656fb1e8340afb1ace819e48537a77c053e Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 16 Oct 2024 02:31:36 +0000 Subject: [PATCH 089/245] slightly cleanup appservice_in_room Signed-off-by: Jason Volk --- src/service/rooms/state_cache/mod.rs | 55 +++++++++++++--------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/src/service/rooms/state_cache/mod.rs 
b/src/service/rooms/state_cache/mod.rs index 4f4ff2646..11684eab4 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -5,6 +5,7 @@ use std::{ use conduit::{ err, is_not_empty, + result::LogErr, utils::{stream::TryIgnore, ReadyExt, StreamTools}, warn, Result, }; @@ -246,44 +247,40 @@ impl Service { #[tracing::instrument(skip(self, room_id, appservice), level = "debug")] pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> bool { - let maybe = self + if let Some(cached) = self .appservice_in_room_cache .read() .expect("locked") .get(room_id) .and_then(|map| map.get(&appservice.registration.id)) - .copied(); + .copied() + { + return cached; + } - if let Some(b) = maybe { - b - } else { - let bridge_user_id = UserId::parse_with_server_name( - appservice.registration.sender_localpart.as_str(), - self.services.globals.server_name(), - ) - .ok(); - - let in_room = if let Some(id) = &bridge_user_id { - self.is_joined(id, room_id).await - } else { - false - }; + let bridge_user_id = UserId::parse_with_server_name( + appservice.registration.sender_localpart.as_str(), + self.services.globals.server_name(), + ); - let in_room = in_room - || self - .room_members(room_id) - .ready_any(|userid| appservice.users.is_match(userid.as_str())) - .await; + let Ok(bridge_user_id) = bridge_user_id.log_err() else { + return false; + }; - self.appservice_in_room_cache - .write() - .expect("locked") - .entry(room_id.to_owned()) - .or_default() - .insert(appservice.registration.id.clone(), in_room); + let in_room = self.is_joined(&bridge_user_id, room_id).await + || self + .room_members(room_id) + .ready_any(|user_id| appservice.users.is_match(user_id.as_str())) + .await; - in_room - } + self.appservice_in_room_cache + .write() + .expect("locked") + .entry(room_id.into()) + .or_default() + .insert(appservice.registration.id.clone(), in_room); + + in_room } /// Direct DB function to directly mark a user as 
joined. It is not From 55b8908894303f1e784a881e879d5f8d6773abaf Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 16 Oct 2024 03:12:30 +0000 Subject: [PATCH 090/245] merge rooms state_compressor service and data Signed-off-by: Jason Volk --- src/service/rooms/state_compressor/data.rs | 84 ------------ src/service/rooms/state_compressor/mod.rs | 142 ++++++++++++++------- 2 files changed, 97 insertions(+), 129 deletions(-) delete mode 100644 src/service/rooms/state_compressor/data.rs diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs deleted file mode 100644 index cb0204705..000000000 --- a/src/service/rooms/state_compressor/data.rs +++ /dev/null @@ -1,84 +0,0 @@ -use std::{collections::HashSet, mem::size_of, sync::Arc}; - -use conduit::{err, expected, utils, Result}; -use database::{Database, Map}; - -use super::CompressedStateEvent; - -pub(super) struct StateDiff { - pub(super) parent: Option, - pub(super) added: Arc>, - pub(super) removed: Arc>, -} - -pub(super) struct Data { - shortstatehash_statediff: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - shortstatehash_statediff: db["shortstatehash_statediff"].clone(), - } - } - - pub(super) async fn get_statediff(&self, shortstatehash: u64) -> Result { - const BUFSIZE: usize = size_of::(); - - let value = self - .shortstatehash_statediff - .aqry::(&shortstatehash) - .await - .map_err(|e| err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")))?; - - let parent = utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); - let parent = if parent != 0 { - Some(parent) - } else { - None - }; - - let mut add_mode = true; - let mut added = HashSet::new(); - let mut removed = HashSet::new(); - - let stride = size_of::(); - let mut i = stride; - while let Some(v) = value.get(i..expected!(i + 2 * stride)) { - if add_mode && v.starts_with(&0_u64.to_be_bytes()) { - add_mode = false; - i = expected!(i + 
stride); - continue; - } - if add_mode { - added.insert(v.try_into().expect("we checked the size above")); - } else { - removed.insert(v.try_into().expect("we checked the size above")); - } - i = expected!(i + 2 * stride); - } - - Ok(StateDiff { - parent, - added: Arc::new(added), - removed: Arc::new(removed), - }) - } - - pub(super) fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) { - let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); - for new in diff.added.iter() { - value.extend_from_slice(&new[..]); - } - - if !diff.removed.is_empty() { - value.extend_from_slice(&0_u64.to_be_bytes()); - for removed in diff.removed.iter() { - value.extend_from_slice(&removed[..]); - } - } - - self.shortstatehash_statediff - .insert(&shortstatehash.to_be_bytes(), &value); - } -} diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index cd3f2f738..be66c5970 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -1,53 +1,21 @@ -mod data; - use std::{ collections::HashSet, fmt::Write, mem::size_of, - sync::{Arc, Mutex as StdMutex, Mutex}, + sync::{Arc, Mutex}, }; -use conduit::{checked, utils, utils::math::usize_from_f64, Result}; -use data::Data; +use conduit::{checked, err, expected, utils, utils::math::usize_from_f64, Result}; +use database::Map; use lru_cache::LruCache; use ruma::{EventId, RoomId}; -use self::data::StateDiff; use crate::{rooms, Dep}; -type StateInfoLruCache = Mutex< - LruCache< - u64, - Vec<( - u64, // sstatehash - Arc>, // full state - Arc>, // added - Arc>, // removed - )>, - >, ->; - -type ShortStateInfoResult = Vec<( - u64, // sstatehash - Arc>, // full state - Arc>, // added - Arc>, // removed -)>; - -type ParentStatesVec = Vec<( - u64, // sstatehash - Arc>, // full state - Arc>, // added - Arc>, // removed -)>; - -type HashSetCompressStateEvent = (u64, Arc>, Arc>); -pub type CompressedStateEvent = [u8; 2 * size_of::()]; - pub 
struct Service { + pub stateinfo_cache: Mutex, db: Data, services: Services, - pub stateinfo_cache: StateInfoLruCache, } struct Services { @@ -55,17 +23,42 @@ struct Services { state: Dep, } +struct Data { + shortstatehash_statediff: Arc, +} + +struct StateDiff { + parent: Option, + added: Arc>, + removed: Arc>, +} + +type StateInfoLruCache = LruCache; +type ShortStateInfoVec = Vec; +type ParentStatesVec = Vec; +type ShortStateInfo = ( + u64, // sstatehash + Arc>, // full state + Arc>, // added + Arc>, // removed +); + +type HashSetCompressStateEvent = (u64, Arc>, Arc>); +pub type CompressedStateEvent = [u8; 2 * size_of::()]; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { let config = &args.server.config; let cache_capacity = f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier; Ok(Arc::new(Self { - db: Data::new(args.db), + stateinfo_cache: LruCache::new(usize_from_f64(cache_capacity)?).into(), + db: Data { + shortstatehash_statediff: args.db["shortstatehash_statediff"].clone(), + }, services: Services { short: args.depend::("rooms::short"), state: args.depend::("rooms::state"), }, - stateinfo_cache: StdMutex::new(LruCache::new(usize_from_f64(cache_capacity)?)), })) } @@ -84,7 +77,7 @@ impl crate::Service for Service { impl Service { /// Returns a stack with info on shortstatehash, full state, added diff and /// removed diff for the selected shortstatehash and each parent layer. 
- pub async fn load_shortstatehash_info(&self, shortstatehash: u64) -> Result { + pub async fn load_shortstatehash_info(&self, shortstatehash: u64) -> Result { if let Some(r) = self .stateinfo_cache .lock() @@ -98,7 +91,7 @@ impl Service { parent, added, removed, - } = self.db.get_statediff(shortstatehash).await?; + } = self.get_statediff(shortstatehash).await?; if let Some(parent) = parent { let mut response = Box::pin(self.load_shortstatehash_info(parent)).await?; @@ -177,12 +170,12 @@ impl Service { /// for this layer /// * `parent_states` - A stack with info on shortstatehash, full state, /// added diff and removed diff for each parent layer - #[tracing::instrument(skip(self, statediffnew, statediffremoved, diff_to_sibling, parent_states), level = "debug")] + #[tracing::instrument(skip_all, level = "debug")] pub fn save_state_from_diff( &self, shortstatehash: u64, statediffnew: Arc>, statediffremoved: Arc>, diff_to_sibling: usize, mut parent_states: ParentStatesVec, - ) -> Result<()> { + ) -> Result { let statediffnew_len = statediffnew.len(); let statediffremoved_len = statediffremoved.len(); let diffsum = checked!(statediffnew_len + statediffremoved_len)?; @@ -226,7 +219,7 @@ impl Service { if parent_states.is_empty() { // There is no parent layer, create a new state - self.db.save_statediff( + self.save_statediff( shortstatehash, &StateDiff { parent: None, @@ -279,7 +272,7 @@ impl Service { )?; } else { // Diff small enough, we add diff as layer on top of parent - self.db.save_statediff( + self.save_statediff( shortstatehash, &StateDiff { parent: Some(parent.0), @@ -324,7 +317,7 @@ impl Service { let states_parents = if let Some(p) = previous_shortstatehash { self.load_shortstatehash_info(p).await.unwrap_or_default() } else { - ShortStateInfoResult::new() + ShortStateInfoVec::new() }; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { @@ -356,4 +349,63 @@ impl Service { Ok((new_shortstatehash, statediffnew, 
statediffremoved)) } + + async fn get_statediff(&self, shortstatehash: u64) -> Result { + const BUFSIZE: usize = size_of::(); + const STRIDE: usize = size_of::(); + + let value = self + .db + .shortstatehash_statediff + .aqry::(&shortstatehash) + .await + .map_err(|e| err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")))?; + + let parent = utils::u64_from_bytes(&value[0..size_of::()]) + .ok() + .take_if(|parent| *parent != 0); + + let mut add_mode = true; + let mut added = HashSet::new(); + let mut removed = HashSet::new(); + + let mut i = STRIDE; + while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) { + if add_mode && v.starts_with(&0_u64.to_be_bytes()) { + add_mode = false; + i = expected!(i + STRIDE); + continue; + } + if add_mode { + added.insert(v.try_into()?); + } else { + removed.insert(v.try_into()?); + } + i = expected!(i + 2 * STRIDE); + } + + Ok(StateDiff { + parent, + added: Arc::new(added), + removed: Arc::new(removed), + }) + } + + fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) { + let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); + for new in diff.added.iter() { + value.extend_from_slice(&new[..]); + } + + if !diff.removed.is_empty() { + value.extend_from_slice(&0_u64.to_be_bytes()); + for removed in diff.removed.iter() { + value.extend_from_slice(&removed[..]); + } + } + + self.db + .shortstatehash_statediff + .insert(&shortstatehash.to_be_bytes(), &value); + } } From 828cb96ba9dd8e323a37c0e33d8a6e1c23b84ac5 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 16 Oct 2024 05:32:27 +0000 Subject: [PATCH 091/245] split client/sync Signed-off-by: Jason Volk --- src/api/client/sync/mod.rs | 67 ++ src/api/client/{sync.rs => sync/v3.rs} | 835 +------------------------ src/api/client/sync/v4.rs | 784 +++++++++++++++++++++++ src/core/utils/mod.rs | 10 + 4 files changed, 870 insertions(+), 826 deletions(-) create mode 100644 src/api/client/sync/mod.rs rename src/api/client/{sync.rs => 
sync/v3.rs} (56%) create mode 100644 src/api/client/sync/v4.rs diff --git a/src/api/client/sync/mod.rs b/src/api/client/sync/mod.rs new file mode 100644 index 000000000..0cfc7b8b3 --- /dev/null +++ b/src/api/client/sync/mod.rs @@ -0,0 +1,67 @@ +mod v3; +mod v4; + +use conduit::{ + utils::{math::usize_from_u64_truncated, ReadyExt}, + PduCount, +}; +use futures::StreamExt; +use ruma::{RoomId, UserId}; + +pub(crate) use self::{v3::sync_events_route, v4::sync_events_v4_route}; +use crate::{service::Services, Error, PduEvent, Result}; + +async fn load_timeline( + services: &Services, sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: u64, +) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { + let timeline_pdus; + let limited = if services + .rooms + .timeline + .last_timeline_count(sender_user, room_id) + .await? + > roomsincecount + { + let mut non_timeline_pdus = services + .rooms + .timeline + .pdus_until(sender_user, room_id, PduCount::max()) + .await? + .ready_take_while(|(pducount, _)| pducount > &roomsincecount); + + // Take the last events for the timeline + timeline_pdus = non_timeline_pdus + .by_ref() + .take(usize_from_u64_truncated(limit)) + .collect::>() + .await + .into_iter() + .rev() + .collect::>(); + + // They /sync response doesn't always return all messages, so we say the output + // is limited unless there are events in non_timeline_pdus + non_timeline_pdus.next().await.is_some() + } else { + timeline_pdus = Vec::new(); + false + }; + Ok((timeline_pdus, limited)) +} + +async fn share_encrypted_room( + services: &Services, sender_user: &UserId, user_id: &UserId, ignore_room: Option<&RoomId>, +) -> bool { + services + .rooms + .user + .get_shared_rooms(sender_user, user_id) + .ready_filter(|&room_id| Some(room_id) != ignore_room) + .any(|other_room_id| { + services + .rooms + .state_accessor + .is_encrypted_room(other_room_id) + }) + .await +} diff --git a/src/api/client/sync.rs b/src/api/client/sync/v3.rs similarity index 
56% rename from src/api/client/sync.rs rename to src/api/client/sync/v3.rs index 65af775d4..f29fe220e 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync/v3.rs @@ -1,23 +1,19 @@ use std::{ - cmp::{self, Ordering}, - collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}, + cmp::{self}, + collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, time::Duration, }; use axum::extract::State; use conduit::{ - debug, err, error, is_equal_to, + err, error, extract_variant, is_equal_to, result::FlatOk, - utils::{ - math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, - BoolExt, IterStream, ReadyExt, TryFutureExtExt, - }, - warn, PduCount, + utils::{math::ruma_from_u64, BoolExt, IterStream, ReadyExt, TryFutureExtExt}, + PduCount, }; -use futures::{future::OptionFuture, pin_mut, FutureExt, StreamExt, TryFutureExt}; +use futures::{future::OptionFuture, pin_mut, FutureExt, StreamExt}; use ruma::{ api::client::{ - error::ErrorKind, filter::{FilterDefinition, LazyLoadOptions}, sync::sync_events::{ self, @@ -25,43 +21,27 @@ use ruma::{ Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom, LeftRoom, Presence, RoomAccountData, RoomSummary, Rooms, State as RoomState, Timeline, ToDevice, }, - v4::{SlidingOp, SlidingSyncRoomHero}, DeviceLists, UnreadNotificationsCount, }, uiaa::UiaaResponse, }, - directory::RoomTypeFilter, events::{ presence::PresenceEvent, room::member::{MembershipState, RoomMemberEventContent}, AnyRawAccountDataEvent, StateEventType, - TimelineEventType::{self, *}, + TimelineEventType::*, }, serde::Raw, - state_res::Event, - uint, DeviceId, EventId, MilliSecondsSinceUnixEpoch, OwnedRoomId, OwnedUserId, RoomId, UInt, UserId, + uint, DeviceId, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId, }; -use service::rooms::read_receipt::pack_receipts; use tracing::{Instrument as _, Span}; +use super::{load_timeline, share_encrypted_room}; use crate::{ service::{pdu::EventHash, Services}, utils, 
Error, PduEvent, Result, Ruma, RumaResponse, }; -const SINGLE_CONNECTION_SYNC: &str = "single_connection_sync"; -const DEFAULT_BUMP_TYPES: &[TimelineEventType; 6] = - &[RoomMessage, RoomEncrypted, Sticker, CallInvite, PollStart, Beacon]; - -macro_rules! extract_variant { - ($e:expr, $variant:path) => { - match $e { - $variant(value) => Some(value), - _ => None, - } - }; -} - /// # `GET /_matrix/client/r0/sync` /// /// Synchronize the client's state with the latest state on the server. @@ -1085,800 +1065,3 @@ async fn load_joined_room( unread_thread_notifications: BTreeMap::new(), }) } - -async fn load_timeline( - services: &Services, sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: u64, -) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { - let timeline_pdus; - let limited = if services - .rooms - .timeline - .last_timeline_count(sender_user, room_id) - .await? - > roomsincecount - { - let mut non_timeline_pdus = services - .rooms - .timeline - .pdus_until(sender_user, room_id, PduCount::max()) - .await? 
- .ready_take_while(|(pducount, _)| pducount > &roomsincecount); - - // Take the last events for the timeline - timeline_pdus = non_timeline_pdus - .by_ref() - .take(usize_from_u64_truncated(limit)) - .collect::>() - .await - .into_iter() - .rev() - .collect::>(); - - // They /sync response doesn't always return all messages, so we say the output - // is limited unless there are events in non_timeline_pdus - non_timeline_pdus.next().await.is_some() - } else { - timeline_pdus = Vec::new(); - false - }; - Ok((timeline_pdus, limited)) -} - -async fn share_encrypted_room( - services: &Services, sender_user: &UserId, user_id: &UserId, ignore_room: Option<&RoomId>, -) -> bool { - services - .rooms - .user - .get_shared_rooms(sender_user, user_id) - .ready_filter(|&room_id| Some(room_id) != ignore_room) - .any(|other_room_id| { - services - .rooms - .state_accessor - .is_encrypted_room(other_room_id) - }) - .await -} - -/// POST `/_matrix/client/unstable/org.matrix.msc3575/sync` -/// -/// Sliding Sync endpoint (future endpoint: `/_matrix/client/v4/sync`) -pub(crate) async fn sync_events_v4_route( - State(services): State, body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.expect("user is authenticated"); - let mut body = body.body; - // Setup watchers, so if there's no response, we can wait for them - let watcher = services.globals.watch(sender_user, &sender_device); - - let next_batch = services.globals.next_count()?; - - let conn_id = body - .conn_id - .clone() - .unwrap_or_else(|| SINGLE_CONNECTION_SYNC.to_owned()); - - let globalsince = body - .pos - .as_ref() - .and_then(|string| string.parse().ok()) - .unwrap_or(0); - - if globalsince != 0 - && !services - .sync - .remembered(sender_user.clone(), sender_device.clone(), conn_id.clone()) - { - debug!("Restarting sync stream because it was gone from the database"); - return Err(Error::Request( - ErrorKind::UnknownPos, - 
"Connection data lost since last time".into(), - http::StatusCode::BAD_REQUEST, - )); - } - - if globalsince == 0 { - services - .sync - .forget_sync_request_connection(sender_user.clone(), sender_device.clone(), conn_id.clone()); - } - - // Get sticky parameters from cache - let known_rooms = - services - .sync - .update_sync_request_with_cache(sender_user.clone(), sender_device.clone(), &mut body); - - let all_joined_rooms: Vec<_> = services - .rooms - .state_cache - .rooms_joined(sender_user) - .map(ToOwned::to_owned) - .collect() - .await; - - let all_invited_rooms: Vec<_> = services - .rooms - .state_cache - .rooms_invited(sender_user) - .map(|r| r.0) - .collect() - .await; - - let all_rooms = all_joined_rooms - .iter() - .chain(all_invited_rooms.iter()) - .map(Clone::clone) - .collect(); - - if body.extensions.to_device.enabled.unwrap_or(false) { - services - .users - .remove_to_device_events(sender_user, &sender_device, globalsince) - .await; - } - - let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in - let mut device_list_changes = HashSet::new(); - let mut device_list_left = HashSet::new(); - - let mut receipts = sync_events::v4::Receipts { - rooms: BTreeMap::new(), - }; - - let mut account_data = sync_events::v4::AccountData { - global: Vec::new(), - rooms: BTreeMap::new(), - }; - if body.extensions.account_data.enabled.unwrap_or(false) { - account_data.global = services - .account_data - .changes_since(None, sender_user, globalsince) - .await? - .into_iter() - .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Global)) - .collect(); - - if let Some(rooms) = body.extensions.account_data.rooms { - for room in rooms { - account_data.rooms.insert( - room.clone(), - services - .account_data - .changes_since(Some(&room), sender_user, globalsince) - .await? 
- .into_iter() - .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) - .collect(), - ); - } - } - } - - if body.extensions.e2ee.enabled.unwrap_or(false) { - // Look for device list updates of this account - device_list_changes.extend( - services - .users - .keys_changed(sender_user.as_ref(), globalsince, None) - .map(ToOwned::to_owned) - .collect::>() - .await, - ); - - for room_id in &all_joined_rooms { - let Ok(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id).await else { - error!("Room {room_id} has no state"); - continue; - }; - - let since_shortstatehash = services - .rooms - .user - .get_token_shortstatehash(room_id, globalsince) - .await - .ok(); - - let encrypted_room = services - .rooms - .state_accessor - .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "") - .await - .is_ok(); - - if let Some(since_shortstatehash) = since_shortstatehash { - // Skip if there are only timeline changes - if since_shortstatehash == current_shortstatehash { - continue; - } - - let since_encryption = services - .rooms - .state_accessor - .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "") - .await; - - let since_sender_member: Option = services - .rooms - .state_accessor - .state_get_content(since_shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) - .ok() - .await; - - let joined_since_last_sync = - since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); - - let new_encrypted_room = encrypted_room && since_encryption.is_err(); - - if encrypted_room { - let current_state_ids = services - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; - - let since_state_ids = services - .rooms - .state_accessor - .state_full_ids(since_shortstatehash) - .await?; - - for (key, id) in current_state_ids { - if since_state_ids.get(&key) != Some(&id) { - let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { - error!("Pdu in state 
not found: {id}"); - continue; - }; - if pdu.kind == RoomMember { - if let Some(state_key) = &pdu.state_key { - let user_id = UserId::parse(state_key.clone()) - .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; - - if user_id == *sender_user { - continue; - } - - let content: RoomMemberEventContent = pdu.get_content()?; - match content.membership { - MembershipState::Join => { - // A new user joined an encrypted room - if !share_encrypted_room(&services, sender_user, &user_id, Some(room_id)) - .await - { - device_list_changes.insert(user_id); - } - }, - MembershipState::Leave => { - // Write down users that have left encrypted rooms we are in - left_encrypted_users.insert(user_id); - }, - _ => {}, - } - } - } - } - } - if joined_since_last_sync || new_encrypted_room { - // If the user is in a new encrypted room, give them all joined users - device_list_changes.extend( - services - .rooms - .state_cache - .room_members(room_id) - // Don't send key updates from the sender to the sender - .ready_filter(|user_id| sender_user != user_id) - // Only send keys if the sender doesn't share an encrypted room with the target - // already - .filter_map(|user_id| { - share_encrypted_room(&services, sender_user, user_id, Some(room_id)) - .map(|res| res.or_some(user_id.to_owned())) - }) - .collect::>() - .await, - ); - } - } - } - // Look for device list updates in this room - device_list_changes.extend( - services - .users - .keys_changed(room_id.as_ref(), globalsince, None) - .map(ToOwned::to_owned) - .collect::>() - .await, - ); - } - - for user_id in left_encrypted_users { - let dont_share_encrypted_room = !share_encrypted_room(&services, sender_user, &user_id, None).await; - - // If the user doesn't share an encrypted room with the target anymore, we need - // to tell them - if dont_share_encrypted_room { - device_list_left.insert(user_id); - } - } - } - - let mut lists = BTreeMap::new(); - let mut todo_rooms = BTreeMap::new(); // and required state - - for 
(list_id, list) in &body.lists { - let active_rooms = match list.filters.clone().and_then(|f| f.is_invite) { - Some(true) => &all_invited_rooms, - Some(false) => &all_joined_rooms, - None => &all_rooms, - }; - - let active_rooms = match list.filters.clone().map(|f| f.not_room_types) { - Some(filter) if filter.is_empty() => active_rooms.clone(), - Some(value) => filter_rooms(active_rooms, State(services), &value, true).await, - None => active_rooms.clone(), - }; - - let active_rooms = match list.filters.clone().map(|f| f.room_types) { - Some(filter) if filter.is_empty() => active_rooms.clone(), - Some(value) => filter_rooms(&active_rooms, State(services), &value, false).await, - None => active_rooms, - }; - - let mut new_known_rooms = BTreeSet::new(); - - let ranges = list.ranges.clone(); - lists.insert( - list_id.clone(), - sync_events::v4::SyncList { - ops: ranges - .into_iter() - .map(|mut r| { - r.0 = r.0.clamp( - uint!(0), - UInt::try_from(active_rooms.len().saturating_sub(1)).unwrap_or(UInt::MAX), - ); - r.1 = - r.1.clamp(r.0, UInt::try_from(active_rooms.len().saturating_sub(1)).unwrap_or(UInt::MAX)); - - let room_ids = if !active_rooms.is_empty() { - active_rooms[usize_from_ruma(r.0)..=usize_from_ruma(r.1)].to_vec() - } else { - Vec::new() - }; - - new_known_rooms.extend(room_ids.iter().cloned()); - for room_id in &room_ids { - let todo_room = todo_rooms - .entry(room_id.clone()) - .or_insert((BTreeSet::new(), 0, u64::MAX)); - - let limit = list - .room_details - .timeline_limit - .map_or(10, u64::from) - .min(100); - - todo_room - .0 - .extend(list.room_details.required_state.iter().cloned()); - - todo_room.1 = todo_room.1.max(limit); - // 0 means unknown because it got out of date - todo_room.2 = todo_room.2.min( - known_rooms - .get(list_id.as_str()) - .and_then(|k| k.get(room_id)) - .copied() - .unwrap_or(0), - ); - } - sync_events::v4::SyncOp { - op: SlidingOp::Sync, - range: Some(r), - index: None, - room_ids, - room_id: None, - } - }) - .collect(), - 
count: ruma_from_usize(active_rooms.len()), - }, - ); - - if let Some(conn_id) = &body.conn_id { - services.sync.update_sync_known_rooms( - sender_user.clone(), - sender_device.clone(), - conn_id.clone(), - list_id.clone(), - new_known_rooms, - globalsince, - ); - } - } - - let mut known_subscription_rooms = BTreeSet::new(); - for (room_id, room) in &body.room_subscriptions { - if !services.rooms.metadata.exists(room_id).await { - continue; - } - let todo_room = todo_rooms - .entry(room_id.clone()) - .or_insert((BTreeSet::new(), 0, u64::MAX)); - let limit = room.timeline_limit.map_or(10, u64::from).min(100); - todo_room.0.extend(room.required_state.iter().cloned()); - todo_room.1 = todo_room.1.max(limit); - // 0 means unknown because it got out of date - todo_room.2 = todo_room.2.min( - known_rooms - .get("subscriptions") - .and_then(|k| k.get(room_id)) - .copied() - .unwrap_or(0), - ); - known_subscription_rooms.insert(room_id.clone()); - } - - for r in body.unsubscribe_rooms { - known_subscription_rooms.remove(&r); - body.room_subscriptions.remove(&r); - } - - if let Some(conn_id) = &body.conn_id { - services.sync.update_sync_known_rooms( - sender_user.clone(), - sender_device.clone(), - conn_id.clone(), - "subscriptions".to_owned(), - known_subscription_rooms, - globalsince, - ); - } - - if let Some(conn_id) = &body.conn_id { - services.sync.update_sync_subscriptions( - sender_user.clone(), - sender_device.clone(), - conn_id.clone(), - body.room_subscriptions, - ); - } - - let mut rooms = BTreeMap::new(); - for (room_id, (required_state_request, timeline_limit, roomsince)) in &todo_rooms { - let roomsincecount = PduCount::Normal(*roomsince); - - let mut timestamp: Option<_> = None; - let mut invite_state = None; - let (timeline_pdus, limited); - if all_invited_rooms.contains(room_id) { - // TODO: figure out a timestamp we can use for remote invites - invite_state = services - .rooms - .state_cache - .invite_state(sender_user, room_id) - .await - .ok(); - - 
(timeline_pdus, limited) = (Vec::new(), true); - } else { - (timeline_pdus, limited) = - match load_timeline(&services, sender_user, room_id, roomsincecount, *timeline_limit).await { - Ok(value) => value, - Err(err) => { - warn!("Encountered missing timeline in {}, error {}", room_id, err); - continue; - }, - }; - } - - account_data.rooms.insert( - room_id.clone(), - services - .account_data - .changes_since(Some(room_id), sender_user, *roomsince) - .await? - .into_iter() - .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) - .collect(), - ); - - let vector: Vec<_> = services - .rooms - .read_receipt - .readreceipts_since(room_id, *roomsince) - .filter_map(|(read_user, ts, v)| async move { - (!services - .users - .user_is_ignored(&read_user, sender_user) - .await) - .then_some((read_user, ts, v)) - }) - .collect() - .await; - - let receipt_size = vector.len(); - receipts - .rooms - .insert(room_id.clone(), pack_receipts(Box::new(vector.into_iter()))); - - if roomsince != &0 - && timeline_pdus.is_empty() - && account_data.rooms.get(room_id).is_some_and(Vec::is_empty) - && receipt_size == 0 - { - continue; - } - - let prev_batch = timeline_pdus - .first() - .map_or(Ok::<_, Error>(None), |(pdu_count, _)| { - Ok(Some(match pdu_count { - PduCount::Backfilled(_) => { - error!("timeline in backfill state?!"); - "0".to_owned() - }, - PduCount::Normal(c) => c.to_string(), - })) - })? 
- .or_else(|| { - if roomsince != &0 { - Some(roomsince.to_string()) - } else { - None - } - }); - - let room_events: Vec<_> = timeline_pdus - .iter() - .stream() - .filter_map(|(_, pdu)| async move { - // list of safe and common non-state events to ignore - if matches!( - &pdu.kind, - RoomMessage - | Sticker | CallInvite - | CallNotify | RoomEncrypted - | Image | File | Audio - | Voice | Video | UnstablePollStart - | PollStart | KeyVerificationStart - | Reaction | Emote | Location - ) && services - .users - .user_is_ignored(&pdu.sender, sender_user) - .await - { - return None; - } - - Some(pdu.to_sync_room_event()) - }) - .collect() - .await; - - for (_, pdu) in timeline_pdus { - let ts = MilliSecondsSinceUnixEpoch(pdu.origin_server_ts); - if DEFAULT_BUMP_TYPES.contains(pdu.event_type()) && !timestamp.is_some_and(|time| time > ts) { - timestamp = Some(ts); - } - } - - let required_state = required_state_request - .iter() - .stream() - .filter_map(|state| async move { - services - .rooms - .state_accessor - .room_state_get(room_id, &state.0, &state.1) - .await - .map(|s| s.to_sync_state_event()) - .ok() - }) - .collect() - .await; - - // Heroes - let heroes: Vec<_> = services - .rooms - .state_cache - .room_members(room_id) - .ready_filter(|member| member != sender_user) - .filter_map(|user_id| { - services - .rooms - .state_accessor - .get_member(room_id, user_id) - .map_ok(|memberevent| SlidingSyncRoomHero { - user_id: user_id.into(), - name: memberevent.displayname, - avatar: memberevent.avatar_url, - }) - .ok() - }) - .take(5) - .collect() - .await; - - let name = match heroes.len().cmp(&(1_usize)) { - Ordering::Greater => { - let firsts = heroes[1..] 
- .iter() - .map(|h| h.name.clone().unwrap_or_else(|| h.user_id.to_string())) - .collect::>() - .join(", "); - - let last = heroes[0] - .name - .clone() - .unwrap_or_else(|| heroes[0].user_id.to_string()); - - Some(format!("{firsts} and {last}")) - }, - Ordering::Equal => Some( - heroes[0] - .name - .clone() - .unwrap_or_else(|| heroes[0].user_id.to_string()), - ), - Ordering::Less => None, - }; - - let heroes_avatar = if heroes.len() == 1 { - heroes[0].avatar.clone() - } else { - None - }; - - rooms.insert( - room_id.clone(), - sync_events::v4::SlidingSyncRoom { - name: services - .rooms - .state_accessor - .get_name(room_id) - .await - .ok() - .or(name), - avatar: if let Some(heroes_avatar) = heroes_avatar { - ruma::JsOption::Some(heroes_avatar) - } else { - match services.rooms.state_accessor.get_avatar(room_id).await { - ruma::JsOption::Some(avatar) => ruma::JsOption::from_option(avatar.url), - ruma::JsOption::Null => ruma::JsOption::Null, - ruma::JsOption::Undefined => ruma::JsOption::Undefined, - } - }, - initial: Some(roomsince == &0), - is_dm: None, - invite_state, - unread_notifications: UnreadNotificationsCount { - highlight_count: Some( - services - .rooms - .user - .highlight_count(sender_user, room_id) - .await - .try_into() - .expect("notification count can't go that high"), - ), - notification_count: Some( - services - .rooms - .user - .notification_count(sender_user, room_id) - .await - .try_into() - .expect("notification count can't go that high"), - ), - }, - timeline: room_events, - required_state, - prev_batch, - limited, - joined_count: Some( - services - .rooms - .state_cache - .room_joined_count(room_id) - .await - .unwrap_or(0) - .try_into() - .unwrap_or_else(|_| uint!(0)), - ), - invited_count: Some( - services - .rooms - .state_cache - .room_invited_count(room_id) - .await - .unwrap_or(0) - .try_into() - .unwrap_or_else(|_| uint!(0)), - ), - num_live: None, // Count events in timeline greater than global sync counter - timestamp, - heroes: 
Some(heroes), - }, - ); - } - - if rooms - .iter() - .all(|(_, r)| r.timeline.is_empty() && r.required_state.is_empty()) - { - // Hang a few seconds so requests are not spammed - // Stop hanging if new info arrives - let default = Duration::from_secs(30); - let duration = cmp::min(body.timeout.unwrap_or(default), default); - _ = tokio::time::timeout(duration, watcher).await; - } - - Ok(sync_events::v4::Response { - initial: globalsince == 0, - txn_id: body.txn_id.clone(), - pos: next_batch.to_string(), - lists, - rooms, - extensions: sync_events::v4::Extensions { - to_device: if body.extensions.to_device.enabled.unwrap_or(false) { - Some(sync_events::v4::ToDevice { - events: services - .users - .get_to_device_events(sender_user, &sender_device) - .collect() - .await, - next_batch: next_batch.to_string(), - }) - } else { - None - }, - e2ee: sync_events::v4::E2EE { - device_lists: DeviceLists { - changed: device_list_changes.into_iter().collect(), - left: device_list_left.into_iter().collect(), - }, - device_one_time_keys_count: services - .users - .count_one_time_keys(sender_user, &sender_device) - .await, - // Fallback keys are not yet supported - device_unused_fallback_key_types: None, - }, - account_data, - receipts, - typing: sync_events::v4::Typing { - rooms: BTreeMap::new(), - }, - }, - delta_token: None, - }) -} - -async fn filter_rooms( - rooms: &[OwnedRoomId], State(services): State, filter: &[RoomTypeFilter], negate: bool, -) -> Vec { - rooms - .iter() - .stream() - .filter_map(|r| async move { - let room_type = services.rooms.state_accessor.get_room_type(r).await; - - if room_type.as_ref().is_err_and(|e| !e.is_not_found()) { - return None; - } - - let room_type_filter = RoomTypeFilter::from(room_type.ok()); - - let include = if negate { - !filter.contains(&room_type_filter) - } else { - filter.is_empty() || filter.contains(&room_type_filter) - }; - - include.then_some(r.to_owned()) - }) - .collect() - .await -} diff --git a/src/api/client/sync/v4.rs 
b/src/api/client/sync/v4.rs new file mode 100644 index 000000000..2adb3b71a --- /dev/null +++ b/src/api/client/sync/v4.rs @@ -0,0 +1,784 @@ +use std::{ + cmp::{self, Ordering}, + collections::{BTreeMap, BTreeSet, HashSet}, + time::Duration, +}; + +use axum::extract::State; +use conduit::{ + debug, error, extract_variant, + utils::{ + math::{ruma_from_usize, usize_from_ruma}, + BoolExt, IterStream, ReadyExt, TryFutureExtExt, + }, + warn, Error, PduCount, Result, +}; +use futures::{FutureExt, StreamExt, TryFutureExt}; +use ruma::{ + api::client::{ + error::ErrorKind, + sync::sync_events::{ + self, + v4::{SlidingOp, SlidingSyncRoomHero}, + DeviceLists, UnreadNotificationsCount, + }, + }, + directory::RoomTypeFilter, + events::{ + room::member::{MembershipState, RoomMemberEventContent}, + AnyRawAccountDataEvent, StateEventType, + TimelineEventType::{self, *}, + }, + state_res::Event, + uint, MilliSecondsSinceUnixEpoch, OwnedRoomId, UInt, UserId, +}; +use service::{rooms::read_receipt::pack_receipts, Services}; + +use super::{load_timeline, share_encrypted_room}; +use crate::Ruma; + +const SINGLE_CONNECTION_SYNC: &str = "single_connection_sync"; +const DEFAULT_BUMP_TYPES: &[TimelineEventType; 6] = + &[RoomMessage, RoomEncrypted, Sticker, CallInvite, PollStart, Beacon]; + +/// POST `/_matrix/client/unstable/org.matrix.msc3575/sync` +/// +/// Sliding Sync endpoint (future endpoint: `/_matrix/client/v4/sync`) +pub(crate) async fn sync_events_v4_route( + State(services): State, body: Ruma, +) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.expect("user is authenticated"); + let mut body = body.body; + // Setup watchers, so if there's no response, we can wait for them + let watcher = services.globals.watch(sender_user, &sender_device); + + let next_batch = services.globals.next_count()?; + + let conn_id = body + .conn_id + .clone() + .unwrap_or_else(|| SINGLE_CONNECTION_SYNC.to_owned()); + + 
let globalsince = body + .pos + .as_ref() + .and_then(|string| string.parse().ok()) + .unwrap_or(0); + + if globalsince != 0 + && !services + .sync + .remembered(sender_user.clone(), sender_device.clone(), conn_id.clone()) + { + debug!("Restarting sync stream because it was gone from the database"); + return Err(Error::Request( + ErrorKind::UnknownPos, + "Connection data lost since last time".into(), + http::StatusCode::BAD_REQUEST, + )); + } + + if globalsince == 0 { + services + .sync + .forget_sync_request_connection(sender_user.clone(), sender_device.clone(), conn_id.clone()); + } + + // Get sticky parameters from cache + let known_rooms = + services + .sync + .update_sync_request_with_cache(sender_user.clone(), sender_device.clone(), &mut body); + + let all_joined_rooms: Vec<_> = services + .rooms + .state_cache + .rooms_joined(sender_user) + .map(ToOwned::to_owned) + .collect() + .await; + + let all_invited_rooms: Vec<_> = services + .rooms + .state_cache + .rooms_invited(sender_user) + .map(|r| r.0) + .collect() + .await; + + let all_rooms = all_joined_rooms + .iter() + .chain(all_invited_rooms.iter()) + .map(Clone::clone) + .collect(); + + if body.extensions.to_device.enabled.unwrap_or(false) { + services + .users + .remove_to_device_events(sender_user, &sender_device, globalsince) + .await; + } + + let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in + let mut device_list_changes = HashSet::new(); + let mut device_list_left = HashSet::new(); + + let mut receipts = sync_events::v4::Receipts { + rooms: BTreeMap::new(), + }; + + let mut account_data = sync_events::v4::AccountData { + global: Vec::new(), + rooms: BTreeMap::new(), + }; + if body.extensions.account_data.enabled.unwrap_or(false) { + account_data.global = services + .account_data + .changes_since(None, sender_user, globalsince) + .await? 
+ .into_iter() + .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Global)) + .collect(); + + if let Some(rooms) = body.extensions.account_data.rooms { + for room in rooms { + account_data.rooms.insert( + room.clone(), + services + .account_data + .changes_since(Some(&room), sender_user, globalsince) + .await? + .into_iter() + .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) + .collect(), + ); + } + } + } + + if body.extensions.e2ee.enabled.unwrap_or(false) { + // Look for device list updates of this account + device_list_changes.extend( + services + .users + .keys_changed(sender_user.as_ref(), globalsince, None) + .map(ToOwned::to_owned) + .collect::>() + .await, + ); + + for room_id in &all_joined_rooms { + let Ok(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id).await else { + error!("Room {room_id} has no state"); + continue; + }; + + let since_shortstatehash = services + .rooms + .user + .get_token_shortstatehash(room_id, globalsince) + .await + .ok(); + + let encrypted_room = services + .rooms + .state_accessor + .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "") + .await + .is_ok(); + + if let Some(since_shortstatehash) = since_shortstatehash { + // Skip if there are only timeline changes + if since_shortstatehash == current_shortstatehash { + continue; + } + + let since_encryption = services + .rooms + .state_accessor + .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "") + .await; + + let since_sender_member: Option = services + .rooms + .state_accessor + .state_get_content(since_shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) + .ok() + .await; + + let joined_since_last_sync = + since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); + + let new_encrypted_room = encrypted_room && since_encryption.is_err(); + + if encrypted_room { + let current_state_ids = services + .rooms + .state_accessor + 
.state_full_ids(current_shortstatehash) + .await?; + + let since_state_ids = services + .rooms + .state_accessor + .state_full_ids(since_shortstatehash) + .await?; + + for (key, id) in current_state_ids { + if since_state_ids.get(&key) != Some(&id) { + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); + continue; + }; + if pdu.kind == RoomMember { + if let Some(state_key) = &pdu.state_key { + let user_id = UserId::parse(state_key.clone()) + .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; + + if user_id == *sender_user { + continue; + } + + let content: RoomMemberEventContent = pdu.get_content()?; + match content.membership { + MembershipState::Join => { + // A new user joined an encrypted room + if !share_encrypted_room(&services, sender_user, &user_id, Some(room_id)) + .await + { + device_list_changes.insert(user_id); + } + }, + MembershipState::Leave => { + // Write down users that have left encrypted rooms we are in + left_encrypted_users.insert(user_id); + }, + _ => {}, + } + } + } + } + } + if joined_since_last_sync || new_encrypted_room { + // If the user is in a new encrypted room, give them all joined users + device_list_changes.extend( + services + .rooms + .state_cache + .room_members(room_id) + // Don't send key updates from the sender to the sender + .ready_filter(|user_id| sender_user != user_id) + // Only send keys if the sender doesn't share an encrypted room with the target + // already + .filter_map(|user_id| { + share_encrypted_room(&services, sender_user, user_id, Some(room_id)) + .map(|res| res.or_some(user_id.to_owned())) + }) + .collect::>() + .await, + ); + } + } + } + // Look for device list updates in this room + device_list_changes.extend( + services + .users + .keys_changed(room_id.as_ref(), globalsince, None) + .map(ToOwned::to_owned) + .collect::>() + .await, + ); + } + + for user_id in left_encrypted_users { + let dont_share_encrypted_room = 
!share_encrypted_room(&services, sender_user, &user_id, None).await; + + // If the user doesn't share an encrypted room with the target anymore, we need + // to tell them + if dont_share_encrypted_room { + device_list_left.insert(user_id); + } + } + } + + let mut lists = BTreeMap::new(); + let mut todo_rooms = BTreeMap::new(); // and required state + + for (list_id, list) in &body.lists { + let active_rooms = match list.filters.clone().and_then(|f| f.is_invite) { + Some(true) => &all_invited_rooms, + Some(false) => &all_joined_rooms, + None => &all_rooms, + }; + + let active_rooms = match list.filters.clone().map(|f| f.not_room_types) { + Some(filter) if filter.is_empty() => active_rooms.clone(), + Some(value) => filter_rooms(&services, active_rooms, &value, true).await, + None => active_rooms.clone(), + }; + + let active_rooms = match list.filters.clone().map(|f| f.room_types) { + Some(filter) if filter.is_empty() => active_rooms.clone(), + Some(value) => filter_rooms(&services, &active_rooms, &value, false).await, + None => active_rooms, + }; + + let mut new_known_rooms = BTreeSet::new(); + + let ranges = list.ranges.clone(); + lists.insert( + list_id.clone(), + sync_events::v4::SyncList { + ops: ranges + .into_iter() + .map(|mut r| { + r.0 = r.0.clamp( + uint!(0), + UInt::try_from(active_rooms.len().saturating_sub(1)).unwrap_or(UInt::MAX), + ); + r.1 = + r.1.clamp(r.0, UInt::try_from(active_rooms.len().saturating_sub(1)).unwrap_or(UInt::MAX)); + + let room_ids = if !active_rooms.is_empty() { + active_rooms[usize_from_ruma(r.0)..=usize_from_ruma(r.1)].to_vec() + } else { + Vec::new() + }; + + new_known_rooms.extend(room_ids.iter().cloned()); + for room_id in &room_ids { + let todo_room = todo_rooms + .entry(room_id.clone()) + .or_insert((BTreeSet::new(), 0, u64::MAX)); + + let limit = list + .room_details + .timeline_limit + .map_or(10, u64::from) + .min(100); + + todo_room + .0 + .extend(list.room_details.required_state.iter().cloned()); + + todo_room.1 = 
todo_room.1.max(limit); + // 0 means unknown because it got out of date + todo_room.2 = todo_room.2.min( + known_rooms + .get(list_id.as_str()) + .and_then(|k| k.get(room_id)) + .copied() + .unwrap_or(0), + ); + } + sync_events::v4::SyncOp { + op: SlidingOp::Sync, + range: Some(r), + index: None, + room_ids, + room_id: None, + } + }) + .collect(), + count: ruma_from_usize(active_rooms.len()), + }, + ); + + if let Some(conn_id) = &body.conn_id { + services.sync.update_sync_known_rooms( + sender_user.clone(), + sender_device.clone(), + conn_id.clone(), + list_id.clone(), + new_known_rooms, + globalsince, + ); + } + } + + let mut known_subscription_rooms = BTreeSet::new(); + for (room_id, room) in &body.room_subscriptions { + if !services.rooms.metadata.exists(room_id).await { + continue; + } + let todo_room = todo_rooms + .entry(room_id.clone()) + .or_insert((BTreeSet::new(), 0, u64::MAX)); + let limit = room.timeline_limit.map_or(10, u64::from).min(100); + todo_room.0.extend(room.required_state.iter().cloned()); + todo_room.1 = todo_room.1.max(limit); + // 0 means unknown because it got out of date + todo_room.2 = todo_room.2.min( + known_rooms + .get("subscriptions") + .and_then(|k| k.get(room_id)) + .copied() + .unwrap_or(0), + ); + known_subscription_rooms.insert(room_id.clone()); + } + + for r in body.unsubscribe_rooms { + known_subscription_rooms.remove(&r); + body.room_subscriptions.remove(&r); + } + + if let Some(conn_id) = &body.conn_id { + services.sync.update_sync_known_rooms( + sender_user.clone(), + sender_device.clone(), + conn_id.clone(), + "subscriptions".to_owned(), + known_subscription_rooms, + globalsince, + ); + } + + if let Some(conn_id) = &body.conn_id { + services.sync.update_sync_subscriptions( + sender_user.clone(), + sender_device.clone(), + conn_id.clone(), + body.room_subscriptions, + ); + } + + let mut rooms = BTreeMap::new(); + for (room_id, (required_state_request, timeline_limit, roomsince)) in &todo_rooms { + let roomsincecount = 
PduCount::Normal(*roomsince); + + let mut timestamp: Option<_> = None; + let mut invite_state = None; + let (timeline_pdus, limited); + if all_invited_rooms.contains(room_id) { + // TODO: figure out a timestamp we can use for remote invites + invite_state = services + .rooms + .state_cache + .invite_state(sender_user, room_id) + .await + .ok(); + + (timeline_pdus, limited) = (Vec::new(), true); + } else { + (timeline_pdus, limited) = + match load_timeline(&services, sender_user, room_id, roomsincecount, *timeline_limit).await { + Ok(value) => value, + Err(err) => { + warn!("Encountered missing timeline in {}, error {}", room_id, err); + continue; + }, + }; + } + + account_data.rooms.insert( + room_id.clone(), + services + .account_data + .changes_since(Some(room_id), sender_user, *roomsince) + .await? + .into_iter() + .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) + .collect(), + ); + + let vector: Vec<_> = services + .rooms + .read_receipt + .readreceipts_since(room_id, *roomsince) + .filter_map(|(read_user, ts, v)| async move { + (!services + .users + .user_is_ignored(&read_user, sender_user) + .await) + .then_some((read_user, ts, v)) + }) + .collect() + .await; + + let receipt_size = vector.len(); + receipts + .rooms + .insert(room_id.clone(), pack_receipts(Box::new(vector.into_iter()))); + + if roomsince != &0 + && timeline_pdus.is_empty() + && account_data.rooms.get(room_id).is_some_and(Vec::is_empty) + && receipt_size == 0 + { + continue; + } + + let prev_batch = timeline_pdus + .first() + .map_or(Ok::<_, Error>(None), |(pdu_count, _)| { + Ok(Some(match pdu_count { + PduCount::Backfilled(_) => { + error!("timeline in backfill state?!"); + "0".to_owned() + }, + PduCount::Normal(c) => c.to_string(), + })) + })? 
+ .or_else(|| { + if roomsince != &0 { + Some(roomsince.to_string()) + } else { + None + } + }); + + let room_events: Vec<_> = timeline_pdus + .iter() + .stream() + .filter_map(|(_, pdu)| async move { + // list of safe and common non-state events to ignore + if matches!( + &pdu.kind, + RoomMessage + | Sticker | CallInvite + | CallNotify | RoomEncrypted + | Image | File | Audio + | Voice | Video | UnstablePollStart + | PollStart | KeyVerificationStart + | Reaction | Emote | Location + ) && services + .users + .user_is_ignored(&pdu.sender, sender_user) + .await + { + return None; + } + + Some(pdu.to_sync_room_event()) + }) + .collect() + .await; + + for (_, pdu) in timeline_pdus { + let ts = MilliSecondsSinceUnixEpoch(pdu.origin_server_ts); + if DEFAULT_BUMP_TYPES.contains(pdu.event_type()) && !timestamp.is_some_and(|time| time > ts) { + timestamp = Some(ts); + } + } + + let required_state = required_state_request + .iter() + .stream() + .filter_map(|state| async move { + services + .rooms + .state_accessor + .room_state_get(room_id, &state.0, &state.1) + .await + .map(|s| s.to_sync_state_event()) + .ok() + }) + .collect() + .await; + + // Heroes + let heroes: Vec<_> = services + .rooms + .state_cache + .room_members(room_id) + .ready_filter(|member| member != sender_user) + .filter_map(|user_id| { + services + .rooms + .state_accessor + .get_member(room_id, user_id) + .map_ok(|memberevent| SlidingSyncRoomHero { + user_id: user_id.into(), + name: memberevent.displayname, + avatar: memberevent.avatar_url, + }) + .ok() + }) + .take(5) + .collect() + .await; + + let name = match heroes.len().cmp(&(1_usize)) { + Ordering::Greater => { + let firsts = heroes[1..] 
+ .iter() + .map(|h| h.name.clone().unwrap_or_else(|| h.user_id.to_string())) + .collect::>() + .join(", "); + + let last = heroes[0] + .name + .clone() + .unwrap_or_else(|| heroes[0].user_id.to_string()); + + Some(format!("{firsts} and {last}")) + }, + Ordering::Equal => Some( + heroes[0] + .name + .clone() + .unwrap_or_else(|| heroes[0].user_id.to_string()), + ), + Ordering::Less => None, + }; + + let heroes_avatar = if heroes.len() == 1 { + heroes[0].avatar.clone() + } else { + None + }; + + rooms.insert( + room_id.clone(), + sync_events::v4::SlidingSyncRoom { + name: services + .rooms + .state_accessor + .get_name(room_id) + .await + .ok() + .or(name), + avatar: if let Some(heroes_avatar) = heroes_avatar { + ruma::JsOption::Some(heroes_avatar) + } else { + match services.rooms.state_accessor.get_avatar(room_id).await { + ruma::JsOption::Some(avatar) => ruma::JsOption::from_option(avatar.url), + ruma::JsOption::Null => ruma::JsOption::Null, + ruma::JsOption::Undefined => ruma::JsOption::Undefined, + } + }, + initial: Some(roomsince == &0), + is_dm: None, + invite_state, + unread_notifications: UnreadNotificationsCount { + highlight_count: Some( + services + .rooms + .user + .highlight_count(sender_user, room_id) + .await + .try_into() + .expect("notification count can't go that high"), + ), + notification_count: Some( + services + .rooms + .user + .notification_count(sender_user, room_id) + .await + .try_into() + .expect("notification count can't go that high"), + ), + }, + timeline: room_events, + required_state, + prev_batch, + limited, + joined_count: Some( + services + .rooms + .state_cache + .room_joined_count(room_id) + .await + .unwrap_or(0) + .try_into() + .unwrap_or_else(|_| uint!(0)), + ), + invited_count: Some( + services + .rooms + .state_cache + .room_invited_count(room_id) + .await + .unwrap_or(0) + .try_into() + .unwrap_or_else(|_| uint!(0)), + ), + num_live: None, // Count events in timeline greater than global sync counter + timestamp, + heroes: 
Some(heroes), + }, + ); + } + + if rooms + .iter() + .all(|(_, r)| r.timeline.is_empty() && r.required_state.is_empty()) + { + // Hang a few seconds so requests are not spammed + // Stop hanging if new info arrives + let default = Duration::from_secs(30); + let duration = cmp::min(body.timeout.unwrap_or(default), default); + _ = tokio::time::timeout(duration, watcher).await; + } + + Ok(sync_events::v4::Response { + initial: globalsince == 0, + txn_id: body.txn_id.clone(), + pos: next_batch.to_string(), + lists, + rooms, + extensions: sync_events::v4::Extensions { + to_device: if body.extensions.to_device.enabled.unwrap_or(false) { + Some(sync_events::v4::ToDevice { + events: services + .users + .get_to_device_events(sender_user, &sender_device) + .collect() + .await, + next_batch: next_batch.to_string(), + }) + } else { + None + }, + e2ee: sync_events::v4::E2EE { + device_lists: DeviceLists { + changed: device_list_changes.into_iter().collect(), + left: device_list_left.into_iter().collect(), + }, + device_one_time_keys_count: services + .users + .count_one_time_keys(sender_user, &sender_device) + .await, + // Fallback keys are not yet supported + device_unused_fallback_key_types: None, + }, + account_data, + receipts, + typing: sync_events::v4::Typing { + rooms: BTreeMap::new(), + }, + }, + delta_token: None, + }) +} + +async fn filter_rooms( + services: &Services, rooms: &[OwnedRoomId], filter: &[RoomTypeFilter], negate: bool, +) -> Vec { + rooms + .iter() + .stream() + .filter_map(|r| async move { + let room_type = services.rooms.state_accessor.get_room_type(r).await; + + if room_type.as_ref().is_err_and(|e| !e.is_not_found()) { + return None; + } + + let room_type_filter = RoomTypeFilter::from(room_type.ok()); + + let include = if negate { + !filter.contains(&room_type_filter) + } else { + filter.is_empty() || filter.contains(&room_type_filter) + }; + + include.then_some(r.to_owned()) + }) + .collect() + .await +} diff --git a/src/core/utils/mod.rs 
b/src/core/utils/mod.rs index 3adecc6c1..96a98537c 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -40,6 +40,16 @@ pub use self::{ #[inline] pub fn exchange(state: &mut T, source: T) -> T { std::mem::replace(state, source) } +#[macro_export] +macro_rules! extract_variant { + ($e:expr, $variant:path) => { + match $e { + $variant(value) => Some(value), + _ => None, + } + }; +} + #[macro_export] macro_rules! at { ($idx:tt) => { From 1fdcab0319f2461ae000a0876823848f9e5af921 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 16 Oct 2024 06:58:37 +0000 Subject: [PATCH 092/245] additional sync cleanup Signed-off-by: Jason Volk --- src/api/client/sync/mod.rs | 51 +++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/src/api/client/sync/mod.rs b/src/api/client/sync/mod.rs index 0cfc7b8b3..ed22010c9 100644 --- a/src/api/client/sync/mod.rs +++ b/src/api/client/sync/mod.rs @@ -14,38 +14,37 @@ use crate::{service::Services, Error, PduEvent, Result}; async fn load_timeline( services: &Services, sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: u64, ) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { - let timeline_pdus; - let limited = if services + let last_timeline_count = services .rooms .timeline .last_timeline_count(sender_user, room_id) + .await?; + + if last_timeline_count <= roomsincecount { + return Ok((Vec::new(), false)); + } + + let mut non_timeline_pdus = services + .rooms + .timeline + .pdus_until(sender_user, room_id, PduCount::max()) .await? - > roomsincecount - { - let mut non_timeline_pdus = services - .rooms - .timeline - .pdus_until(sender_user, room_id, PduCount::max()) - .await? 
- .ready_take_while(|(pducount, _)| pducount > &roomsincecount); + .ready_take_while(|(pducount, _)| pducount > &roomsincecount); + + // Take the last events for the timeline + let timeline_pdus: Vec<_> = non_timeline_pdus + .by_ref() + .take(usize_from_u64_truncated(limit)) + .collect::>() + .await + .into_iter() + .rev() + .collect(); - // Take the last events for the timeline - timeline_pdus = non_timeline_pdus - .by_ref() - .take(usize_from_u64_truncated(limit)) - .collect::>() - .await - .into_iter() - .rev() - .collect::>(); + // The /sync response doesn't always return all messages, so we say the output + // is limited unless there are events in non_timeline_pdus + let limited = non_timeline_pdus.next().await.is_some(); - // They /sync response doesn't always return all messages, so we say the output - // is limited unless there are events in non_timeline_pdus - non_timeline_pdus.next().await.is_some() - } else { - timeline_pdus = Vec::new(); - false - }; Ok((timeline_pdus, limited)) } From 93130fbb85e4f5645d1b98aeddd8e1828b1b936c Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 21 Oct 2024 20:21:00 +0000 Subject: [PATCH 093/245] add is_ok to futures TryExtExt utils Signed-off-by: Jason Volk --- src/core/utils/future/try_ext_ext.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/core/utils/future/try_ext_ext.rs b/src/core/utils/future/try_ext_ext.rs index d30d2cac7..7c0b36a28 100644 --- a/src/core/utils/future/try_ext_ext.rs +++ b/src/core/utils/future/try_ext_ext.rs @@ -10,6 +10,17 @@ pub trait TryExtExt where Self: TryFuture + Send, { + /// Resolves to a bool for whether the TryFuture (Future of a Result) + /// resolved to Ok or Err. + /// + /// is_ok() has to consume *self rather than borrow. The intent of this + /// extension is therefore for a caller only ever caring about result status + /// while discarding all contents. 
+ #[allow(clippy::wrong_self_convention)] + fn is_ok(self) -> MapOkOrElse bool, impl FnOnce(Self::Error) -> bool> + where + Self: Sized; + fn map_ok_or( self, default: U, f: F, ) -> MapOkOrElse U, impl FnOnce(Self::Error) -> U> @@ -32,6 +43,14 @@ impl TryExtExt for Fut where Fut: TryFuture + Send, { + #[inline] + fn is_ok(self) -> MapOkOrElse bool, impl FnOnce(Self::Error) -> bool> + where + Self: Sized, + { + self.map_ok_or(false, |_| true) + } + #[inline] fn map_ok_or( self, default: U, f: F, From ac75ebee8afd9874d02a2c33b86ab388e7ab289b Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 16 Oct 2024 11:33:24 +0000 Subject: [PATCH 094/245] event_handler/timeline service cleanups Signed-off-by: Jason Volk --- src/service/rooms/event_handler/mod.rs | 63 ++++++++++++++------------ src/service/rooms/timeline/data.rs | 29 ++++-------- 2 files changed, 43 insertions(+), 49 deletions(-) diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 0ffd9659b..41ab79f11 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -169,7 +169,7 @@ impl Service { .await?; // Procure the room version - let room_version_id = Self::get_room_version_id(&create_event)?; + let room_version_id = get_room_version_id(&create_event)?; let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?; @@ -178,7 +178,7 @@ impl Service { .boxed() .await?; - Self::check_room_id(room_id, &incoming_pdu)?; + check_room_id(room_id, &incoming_pdu)?; // 8. if not timeline event: stop if !is_timeline_event { @@ -341,7 +341,7 @@ impl Service { // 2. Check signatures, otherwise drop // 3. 
check content hash, redact if doesn't match - let room_version_id = Self::get_room_version_id(create_event)?; + let room_version_id = get_room_version_id(create_event)?; let mut val = match self .services .server_keys @@ -378,7 +378,7 @@ impl Service { ) .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; - Self::check_room_id(room_id, &incoming_pdu)?; + check_room_id(room_id, &incoming_pdu)?; if !auth_events_known { // 4. fetch any missing auth events doing all checks listed here starting at 1. @@ -414,7 +414,7 @@ impl Service { continue; }; - Self::check_room_id(room_id, &auth_event)?; + check_room_id(room_id, &auth_event)?; match auth_events.entry(( auth_event.kind.to_string().into(), @@ -454,7 +454,7 @@ impl Service { }; let auth_check = state_res::event_auth::auth_check( - &Self::to_room_version(&room_version_id), + &to_room_version(&room_version_id), &incoming_pdu, None, // TODO: third party invite state_fetch, @@ -502,8 +502,8 @@ impl Service { } debug!("Upgrading to timeline pdu"); - let timer = tokio::time::Instant::now(); - let room_version_id = Self::get_room_version_id(create_event)?; + let timer = Instant::now(); + let room_version_id = get_room_version_id(create_event)?; // 10. Fetch missing state and auth chain events by calling /state_ids at // backwards extremities doing all the checks in this list starting at 1. @@ -524,7 +524,7 @@ impl Service { } let state_at_incoming_event = state_at_incoming_event.expect("we always set this to some above"); - let room_version = Self::to_room_version(&room_version_id); + let room_version = to_room_version(&room_version_id); debug!("Performing auth check"); // 11. 
Check the auth of the event passes based on the state of the event @@ -1278,7 +1278,7 @@ impl Service { .await .pop() { - Self::check_room_id(room_id, &pdu)?; + check_room_id(room_id, &pdu)?; let limit = self.services.globals.max_fetch_prev_events(); if amount > limit { @@ -1370,31 +1370,34 @@ impl Service { } } - fn check_room_id(room_id: &RoomId, pdu: &PduEvent) -> Result<()> { - if pdu.room_id != room_id { - return Err!(Request(InvalidParam( - warn!(pdu_event_id = ?pdu.event_id, pdu_room_id = ?pdu.room_id, ?room_id, "Found event from room in room") - ))); - } + async fn event_exists(&self, event_id: Arc) -> bool { self.services.timeline.pdu_exists(&event_id).await } - Ok(()) + async fn event_fetch(&self, event_id: Arc) -> Option> { + self.services.timeline.get_pdu(&event_id).await.ok() } +} - fn get_room_version_id(create_event: &PduEvent) -> Result { - let content: RoomCreateEventContent = create_event.get_content()?; - let room_version = content.room_version; - - Ok(room_version) +fn check_room_id(room_id: &RoomId, pdu: &PduEvent) -> Result { + if pdu.room_id != room_id { + return Err!(Request(InvalidParam(error!( + pdu_event_id = ?pdu.event_id, + pdu_room_id = ?pdu.room_id, + ?room_id, + "Found event from room in room", + )))); } - #[inline] - fn to_room_version(room_version_id: &RoomVersionId) -> RoomVersion { - RoomVersion::new(room_version_id).expect("room version is supported") - } + Ok(()) +} - async fn event_exists(&self, event_id: Arc) -> bool { self.services.timeline.pdu_exists(&event_id).await } +fn get_room_version_id(create_event: &PduEvent) -> Result { + let content: RoomCreateEventContent = create_event.get_content()?; + let room_version = content.room_version; - async fn event_fetch(&self, event_id: Arc) -> Option> { - self.services.timeline.get_pdu(&event_id).await.ok() - } + Ok(room_version) +} + +#[inline] +fn to_room_version(room_version_id: &RoomVersionId) -> RoomVersion { + RoomVersion::new(room_version_id).expect("room version is 
supported") } diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index c51b78568..5428a3b9d 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -8,11 +8,11 @@ use conduit::{ err, expected, result::{LogErr, NotFound}, utils, - utils::{stream::TryIgnore, u64_from_u8, ReadyExt}, + utils::{future::TryExtExt, stream::TryIgnore, u64_from_u8, ReadyExt}, Err, PduCount, PduEvent, Result, }; use database::{Database, Deserialized, Json, KeyVal, Map}; -use futures::{FutureExt, Stream, StreamExt}; +use futures::{Stream, StreamExt}; use ruma::{CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; use tokio::sync::Mutex; @@ -115,12 +115,10 @@ impl Data { /// Like get_non_outlier_pdu(), but without the expense of fetching and /// parsing the PduEvent - pub(super) async fn non_outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> { + pub(super) async fn non_outlier_pdu_exists(&self, event_id: &EventId) -> Result { let pduid = self.get_pdu_id(event_id).await?; - self.pduid_pdu.get(&pduid).await?; - - Ok(()) + self.pduid_pdu.get(&pduid).await.map(|_| ()) } /// Returns the pdu. 
@@ -140,16 +138,14 @@ impl Data { /// Like get_non_outlier_pdu(), but without the expense of fetching and /// parsing the PduEvent - pub(super) async fn outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> { - self.eventid_outlierpdu.get(event_id).await?; - - Ok(()) + pub(super) async fn outlier_pdu_exists(&self, event_id: &EventId) -> Result { + self.eventid_outlierpdu.get(event_id).await.map(|_| ()) } /// Like get_pdu(), but without the expense of fetching and parsing the data pub(super) async fn pdu_exists(&self, event_id: &EventId) -> bool { - let non_outlier = self.non_outlier_pdu_exists(event_id).map(|res| res.is_ok()); - let outlier = self.outlier_pdu_exists(event_id).map(|res| res.is_ok()); + let non_outlier = self.non_outlier_pdu_exists(event_id).is_ok(); + let outlier = self.outlier_pdu_exists(event_id).is_ok(); //TODO: parallelize non_outlier.await || outlier.await @@ -169,7 +165,6 @@ impl Data { pub(super) async fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) { self.pduid_pdu.raw_put(pdu_id, Json(json)); - self.lasttimelinecount_cache .lock() .await @@ -181,21 +176,17 @@ impl Data { pub(super) fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) { self.pduid_pdu.raw_put(pdu_id, Json(json)); - self.eventid_pduid.insert(event_id, pdu_id); self.eventid_outlierpdu.remove(event_id); } /// Removes a pdu and creates a new one with the same id. 
- pub(super) async fn replace_pdu( - &self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent, - ) -> Result<()> { + pub(super) async fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result { if self.pduid_pdu.get(pdu_id).await.is_not_found() { return Err!(Request(NotFound("PDU does not exist."))); } - let pdu = serde_json::to_vec(pdu_json)?; - self.pduid_pdu.insert(pdu_id, &pdu); + self.pduid_pdu.raw_put(pdu_id, Json(pdu_json)); Ok(()) } From b505f0d0d7a8ec2accc4b38dfe3391c9f780ba25 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 21 Oct 2024 22:00:39 +0000 Subject: [PATCH 095/245] add (back) query_trusted_key_servers_first w/ additional configuration detail Signed-off-by: Jason Volk --- src/core/config/mod.rs | 29 +++++++++++++++ src/service/server_keys/acquire.rs | 59 +++++++++++++++++++++++------- src/service/server_keys/get.rs | 47 ++++++++++++++++++++---- src/service/server_keys/mod.rs | 4 +- 4 files changed, 116 insertions(+), 23 deletions(-) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 02b277d0b..52ce8a016 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -490,6 +490,35 @@ pub struct Config { #[serde(default = "default_trusted_servers")] pub trusted_servers: Vec, + /// Whether to query the servers listed in trusted_servers first or query + /// the origin server first. For best security, querying the origin server + /// first is advised to minimize the exposure to a compromised trusted + /// server. For maximum performance this can be set to true, however other + /// options exist to query trusted servers first under specific high-load + /// circumstances and should be evaluated before setting this to true. + #[serde(default)] + pub query_trusted_key_servers_first: bool, + + /// Whether to query the servers listed in trusted_servers first + /// specifically on room joins. This option limits the exposure to a + /// compromised trusted server to room joins only. 
The join operation + /// requires gathering keys from many origin servers which can cause + /// significant delays. Therefore this defaults to true to mitigate + /// unexpected delays out-of-the-box. The security-paranoid or those + /// willing to tolerate delays are advised to set this to false. Note that + /// setting query_trusted_key_servers_first to true causes this option to + /// be ignored. + #[serde(default = "true_fn")] + pub query_trusted_key_servers_first_on_join: bool, + + /// Only query trusted servers for keys and never the origin server. This is + /// intended for clusters or custom deployments using their trusted_servers + /// as forwarding-agents to cache and deduplicate requests. Notary servers + /// do not act as forwarding-agents by default, therefore do not enable this + /// unless you know exactly what you are doing. + #[serde(default)] + pub only_query_trusted_key_servers: bool, + /// max log level for conduwuit. allows debug, info, warn, or error /// see also: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives /// **Caveat**: diff --git a/src/service/server_keys/acquire.rs b/src/service/server_keys/acquire.rs index 2b1700400..25b676b8f 100644 --- a/src/service/server_keys/acquire.rs +++ b/src/service/server_keys/acquire.rs @@ -47,35 +47,66 @@ where S: Iterator + Send + Clone, K: Iterator + Send + Clone, { + let notary_only = self.services.server.config.only_query_trusted_key_servers; + let notary_first_always = self.services.server.config.query_trusted_key_servers_first; + let notary_first_on_join = self + .services + .server + .config + .query_trusted_key_servers_first_on_join; + let requested_servers = batch.clone().count(); let requested_keys = batch.clone().flat_map(|(_, key_ids)| key_ids).count(); debug!("acquire {requested_keys} keys from {requested_servers}"); - let missing = self.acquire_locals(batch).await; - let missing_keys = keys_count(&missing); - let missing_servers = 
missing.len(); + let mut missing = self.acquire_locals(batch).await; + let mut missing_keys = keys_count(&missing); + let mut missing_servers = missing.len(); if missing_servers == 0 { return; } debug!("missing {missing_keys} keys for {missing_servers} servers locally"); - let missing = self.acquire_origins(missing.into_iter()).await; - let missing_keys = keys_count(&missing); - let missing_servers = missing.len(); - if missing_servers == 0 { - return; + if notary_first_always || notary_first_on_join { + missing = self.acquire_notary(missing.into_iter()).await; + missing_keys = keys_count(&missing); + missing_servers = missing.len(); + if missing_keys == 0 { + return; + } + + debug_warn!("missing {missing_keys} keys for {missing_servers} servers from all notaries first"); + } + + if !notary_only { + missing = self.acquire_origins(missing.into_iter()).await; + missing_keys = keys_count(&missing); + missing_servers = missing.len(); + if missing_keys == 0 { + return; + } + + debug_warn!("missing {missing_keys} keys for {missing_servers} servers unreachable"); } - debug_warn!("missing {missing_keys} keys for {missing_servers} servers unreachable"); + if !notary_first_always && !notary_first_on_join { + missing = self.acquire_notary(missing.into_iter()).await; + missing_keys = keys_count(&missing); + missing_servers = missing.len(); + if missing_keys == 0 { + return; + } + + debug_warn!("still missing {missing_keys} keys for {missing_servers} servers from all notaries."); + } - let missing = self.acquire_notary(missing.into_iter()).await; - let missing_keys = keys_count(&missing); - let missing_servers = missing.len(); if missing_keys > 0 { - debug_warn!("still missing {missing_keys} keys for {missing_servers} servers from all notaries"); - warn!("did not obtain {missing_keys} of {requested_keys} keys; some events may not be accepted"); + warn!( + "did not obtain {missing_keys} keys for {missing_servers} servers out of {requested_keys} total keys for \ + 
{requested_servers} total servers; some events may not be verifiable" + ); } } diff --git a/src/service/server_keys/get.rs b/src/service/server_keys/get.rs index 0f449b46b..441e33d45 100644 --- a/src/service/server_keys/get.rs +++ b/src/service/server_keys/get.rs @@ -53,17 +53,40 @@ where #[implement(super::Service)] pub async fn get_verify_key(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result { + let notary_first = self.services.server.config.query_trusted_key_servers_first; + let notary_only = self.services.server.config.only_query_trusted_key_servers; + if let Some(result) = self.verify_keys_for(origin).await.remove(key_id) { return Ok(result); } - if let Ok(server_key) = self.server_request(origin).await { - self.add_signing_keys(server_key.clone()).await; - if let Some(result) = extract_key(server_key, key_id) { + if notary_first { + if let Ok(result) = self.get_verify_key_from_notaries(origin, key_id).await { + return Ok(result); + } + } + + if !notary_only { + if let Ok(result) = self.get_verify_key_from_origin(origin, key_id).await { + return Ok(result); + } + } + + if !notary_first { + if let Ok(result) = self.get_verify_key_from_notaries(origin, key_id).await { return Ok(result); } } + Err!(BadServerResponse(debug_error!( + ?key_id, + ?origin, + "Failed to fetch federation signing-key" + ))) +} + +#[implement(super::Service)] +async fn get_verify_key_from_notaries(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result { for notary in self.services.globals.trusted_servers() { if let Ok(server_keys) = self.notary_request(notary, origin).await { for server_key in &server_keys { @@ -78,9 +101,17 @@ pub async fn get_verify_key(&self, origin: &ServerName, key_id: &ServerSigningKe } } - Err!(BadServerResponse(debug_error!( - ?key_id, - ?origin, - "Failed to fetch federation signing-key" - ))) + Err!(Request(NotFound("Failed to fetch signing-key from notaries"))) +} + +#[implement(super::Service)] +async fn 
get_verify_key_from_origin(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result { + if let Ok(server_key) = self.server_request(origin).await { + self.add_signing_keys(server_key.clone()).await; + if let Some(result) = extract_key(server_key, key_id) { + return Ok(result); + } + } + + Err!(Request(NotFound("Failed to fetch signing-key from origin"))) } diff --git a/src/service/server_keys/mod.rs b/src/service/server_keys/mod.rs index c3b84cb33..dc09703ce 100644 --- a/src/service/server_keys/mod.rs +++ b/src/service/server_keys/mod.rs @@ -7,7 +7,7 @@ mod verify; use std::{collections::BTreeMap, sync::Arc, time::Duration}; -use conduit::{implement, utils::time::timepoint_from_now, Result}; +use conduit::{implement, utils::time::timepoint_from_now, Result, Server}; use database::{Deserialized, Json, Map}; use ruma::{ api::federation::discovery::{ServerSigningKeys, VerifyKey}, @@ -30,6 +30,7 @@ pub struct Service { struct Services { globals: Dep, sending: Dep, + server: Arc, } struct Data { @@ -52,6 +53,7 @@ impl crate::Service for Service { services: Services { globals: args.depend::("globals"), sending: args.depend::("sending"), + server: args.server.clone(), }, db: Data { server_signingkeys: args.db["server_signingkeys"].clone(), From 0e55fa2de24a945e55469ec496f85e29e5f10d5b Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 21 Oct 2024 23:54:54 +0000 Subject: [PATCH 096/245] add ready_try_for_each to TryReadyExt extension utils Signed-off-by: Jason Volk --- src/core/utils/stream/try_ready.rs | 18 +++++++++++++++++- src/service/sending/mod.rs | 14 +++++++++----- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/core/utils/stream/try_ready.rs b/src/core/utils/stream/try_ready.rs index ab37d9b30..df3564565 100644 --- a/src/core/utils/stream/try_ready.rs +++ b/src/core/utils/stream/try_ready.rs @@ -2,7 +2,7 @@ use futures::{ future::{ready, Ready}, - stream::{AndThen, TryStream, TryStreamExt}, + stream::{AndThen, TryForEach, TryStream, 
TryStreamExt}, }; use crate::Result; @@ -18,6 +18,12 @@ where fn ready_and_then(self, f: F) -> AndThen>, impl FnMut(S::Ok) -> Ready>> where F: Fn(S::Ok) -> Result; + + fn ready_try_for_each( + self, f: F, + ) -> TryForEach>, impl FnMut(S::Ok) -> Ready>> + where + F: Fn(S::Ok) -> Result<(), E>; } impl TryReadyExt for S @@ -32,4 +38,14 @@ where { self.and_then(move |t| ready(f(t))) } + + #[inline] + fn ready_try_for_each( + self, f: F, + ) -> TryForEach>, impl FnMut(S::Ok) -> Ready>> + where + F: Fn(S::Ok) -> Result<(), E>, + { + self.try_for_each(move |t| ready(f(t))) + } } diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 63c5e655a..a1d5f6922 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -7,8 +7,12 @@ mod sender; use std::{fmt::Debug, sync::Arc}; use async_trait::async_trait; -use conduit::{err, utils::ReadyExt, warn, Result, Server}; -use futures::{future::ready, Stream, StreamExt, TryStreamExt}; +use conduit::{ + err, + utils::{ReadyExt, TryReadyExt}, + warn, Result, Server, +}; +use futures::{Stream, StreamExt}; use ruma::{ api::{appservice::Registration, OutgoingRequest}, RoomId, ServerName, UserId, @@ -235,12 +239,12 @@ impl Service { .map(ToOwned::to_owned) .map(Destination::Normal) .map(Ok) - .try_for_each(|dest| { - ready(self.dispatch(Msg { + .ready_try_for_each(|dest| { + self.dispatch(Msg { dest, event: SendingEvent::Flush, queue_id: Vec::::new(), - })) + }) }) .await } From 167807e0a6e333a4a8f7be9b8ed0da46831ce234 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 22 Oct 2024 00:09:55 +0000 Subject: [PATCH 097/245] de-wrapper max_fetch_prev_event; increase default config Signed-off-by: Jason Volk --- src/core/config/mod.rs | 2 +- src/service/globals/mod.rs | 2 -- src/service/rooms/event_handler/mod.rs | 6 ++++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 52ce8a016..23d35424a 100644 --- a/src/core/config/mod.rs +++ 
b/src/core/config/mod.rs @@ -1822,7 +1822,7 @@ fn default_appservice_idle_timeout() -> u64 { 300 } fn default_pusher_idle_timeout() -> u64 { 15 } -fn default_max_fetch_prev_events() -> u16 { 100_u16 } +fn default_max_fetch_prev_events() -> u16 { 192_u16 } fn default_tracing_flame_filter() -> String { cfg!(debug_assertions) diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 7680007d4..329a6583c 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -171,8 +171,6 @@ impl Service { #[inline] pub fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() } - pub fn max_fetch_prev_events(&self) -> u16 { self.config.max_fetch_prev_events } - pub fn allow_registration(&self) -> bool { self.config.allow_registration } pub fn allow_guest_registration(&self) -> bool { self.config.allow_guest_registration } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 41ab79f11..8f96f68e5 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -13,7 +13,7 @@ use conduit::{ result::LogErr, trace, utils::{math::continue_exponential_backoff_secs, IterStream, MutexMap}, - warn, Err, Error, PduEvent, Result, + warn, Err, Error, PduEvent, Result, Server, }; use futures::{future, future::ready, FutureExt, StreamExt, TryFutureExt}; use ruma::{ @@ -55,6 +55,7 @@ struct Services { state_accessor: Dep, state_compressor: Dep, timeline: Dep, + server: Arc, } type RoomMutexMap = MutexMap; @@ -76,6 +77,7 @@ impl crate::Service for Service { state_accessor: args.depend::("rooms::state_accessor"), state_compressor: args.depend::("rooms::state_compressor"), timeline: args.depend::("rooms::timeline"), + server: args.server.clone(), }, federation_handletime: HandleTimeMap::new().into(), mutex_federation: RoomMutexMap::new(), @@ -1280,7 +1282,7 @@ impl Service { { check_room_id(room_id, &pdu)?; - let limit = self.services.globals.max_fetch_prev_events(); + let 
limit = self.services.server.config.max_fetch_prev_events; if amount > limit { debug_warn!("Max prev event limit reached! Limit: {limit}"); graph.insert(prev_event_id.clone(), HashSet::new()); From c06f560913ce637419f9a825ca8b6ffaca698bc8 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 22 Oct 2024 03:21:56 +0000 Subject: [PATCH 098/245] add some additional database::de test cases Signed-off-by: Jason Volk --- src/database/tests.rs | 62 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/src/database/tests.rs b/src/database/tests.rs index 47dfb32c3..bfab99ef0 100644 --- a/src/database/tests.rs +++ b/src/database/tests.rs @@ -10,7 +10,7 @@ use serde::Serialize; use crate::{ de, ser, ser::{serialize_to_vec, Json}, - Interfix, + Ignore, Interfix, }; #[test] @@ -187,6 +187,66 @@ fn de_tuple() { assert_eq!(b, room_id, "deserialized room_id does not match"); } +#[test] +#[should_panic(expected = "failed to deserialize")] +fn de_tuple_invalid() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let room_id: &RoomId = "!room:example.com".try_into().unwrap(); + + let raw: &[u8] = b"@user:example.com\xFF@user:example.com"; + let (a, b): (&UserId, &RoomId) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); + assert_eq!(b, room_id, "deserialized room_id does not match"); +} + +#[test] +#[should_panic(expected = "failed to deserialize")] +fn de_tuple_incomplete() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + + let raw: &[u8] = b"@user:example.com"; + let (a, _): (&UserId, &RoomId) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); +} + +#[test] +#[should_panic(expected = "failed to deserialize")] +fn de_tuple_incomplete_with_sep() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + + let raw: &[u8] = 
b"@user:example.com\xFF"; + let (a, _): (&UserId, &RoomId) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); +} + +#[test] +#[should_panic(expected = "deserialization failed to consume trailing bytes")] +fn de_tuple_unfinished() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let room_id: &RoomId = "!room:example.com".try_into().unwrap(); + + let raw: &[u8] = b"@user:example.com\xFF!room:example.com\xFF@user:example.com"; + let (a, b): (&UserId, &RoomId) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); + assert_eq!(b, room_id, "deserialized room_id does not match"); +} + +#[test] +fn de_tuple_ignore() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let room_id: &RoomId = "!room:example.com".try_into().unwrap(); + + let raw: &[u8] = b"@user:example.com\xFF@user2:example.net\xFF!room:example.com"; + let (a, _, c): (&UserId, Ignore, &RoomId) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); + assert_eq!(c, room_id, "deserialized room_id does not match"); +} + #[test] fn de_json_array() { let a = &["foo", "bar", "baz"]; From 0e0438e1f9b49a3fa1b8fc0dece769d91c2bafbf Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 22 Oct 2024 03:28:45 +0000 Subject: [PATCH 099/245] further optimize presence_since iteration Signed-off-by: Jason Volk --- src/admin/query/presence.rs | 10 +++++++--- src/api/client/sync/v3.rs | 8 ++++---- src/service/presence/data.rs | 11 ++++++----- src/service/presence/mod.rs | 3 +-- src/service/sending/sender.rs | 8 +++++--- 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/admin/query/presence.rs b/src/admin/query/presence.rs index 6189270cc..0963429e8 100644 --- a/src/admin/query/presence.rs +++ b/src/admin/query/presence.rs @@ -42,12 +42,16 @@ pub(super) async fn 
process(subcommand: PresenceCommand, context: &Command<'_>) since, } => { let timer = tokio::time::Instant::now(); - let results = services.presence.db.presence_since(since); - let presence_since: Vec<(_, _, _)> = results.collect().await; + let results: Vec<(_, _, _)> = services + .presence + .presence_since(since) + .map(|(user_id, count, bytes)| (user_id.to_owned(), count, bytes.to_vec())) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( - "Query completed in {query_time:?}:\n\n```rs\n{presence_since:#?}\n```" + "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, } diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index f29fe220e..2bd318df4 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -488,7 +488,7 @@ async fn process_presence_updates( if !services .rooms .state_cache - .user_sees_user(syncing_user, &user_id) + .user_sees_user(syncing_user, user_id) .await { continue; @@ -496,10 +496,10 @@ async fn process_presence_updates( let presence_event = services .presence - .from_json_bytes_to_event(&presence_bytes, &user_id) + .from_json_bytes_to_event(presence_bytes, user_id) .await?; - match presence_updates.entry(user_id) { + match presence_updates.entry(user_id.into()) { Entry::Vacant(slot) => { slot.insert(presence_event); }, @@ -524,7 +524,7 @@ async fn process_presence_updates( .currently_active .or(curr_content.currently_active); }, - } + }; } Ok(()) diff --git a/src/service/presence/data.rs b/src/service/presence/data.rs index 8522746fd..68b2c3feb 100644 --- a/src/service/presence/data.rs +++ b/src/service/presence/data.rs @@ -7,7 +7,7 @@ use conduit::{ }; use database::{Deserialized, Json, Map}; use futures::Stream; -use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; +use ruma::{events::presence::PresenceEvent, presence::PresenceState, UInt, UserId}; use super::Presence; use crate::{globals, 
users, Dep}; @@ -137,13 +137,14 @@ impl Data { self.userid_presenceid.remove(user_id); } - pub fn presence_since(&self, since: u64) -> impl Stream)> + Send + '_ { + #[inline] + pub(super) fn presence_since(&self, since: u64) -> impl Stream + Send + '_ { self.presenceid_presence .raw_stream() .ignore_err() - .ready_filter_map(move |(key, presence_bytes)| { - let (count, user_id) = presenceid_parse(key).expect("invalid presenceid_parse"); - (count > since).then(|| (user_id.to_owned(), count, presence_bytes.to_vec())) + .ready_filter_map(move |(key, presence)| { + let (count, user_id) = presenceid_parse(key).ok()?; + (count > since).then_some((user_id, count, presence)) }) } } diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index 82a99bd56..b2106f3f7 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -162,8 +162,7 @@ impl Service { /// Returns the most recent presence updates that happened after the event /// with id `since`. - #[inline] - pub fn presence_since(&self, since: u64) -> impl Stream)> + Send + '_ { + pub fn presence_since(&self, since: u64) -> impl Stream + Send + '_ { self.db.presence_since(since) } diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 5c0a324bc..a57d4aeae 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -7,7 +7,9 @@ use std::{ use base64::{engine::general_purpose, Engine as _}; use conduit::{ - debug, debug_warn, err, trace, + debug, debug_warn, err, + result::LogErr, + trace, utils::{calculate_hash, math::continue_exponential_backoff_secs, ReadyExt}, warn, Error, Result, }; @@ -315,14 +317,14 @@ impl Service { while let Some((user_id, count, presence_bytes)) = presence_since.next().await { *max_edu_count = cmp::max(count, *max_edu_count); - if !self.services.globals.user_is_local(&user_id) { + if !self.services.globals.user_is_local(user_id) { continue; } if !self .services .state_cache - .server_sees_user(server_name, &user_id) 
+ .server_sees_user(server_name, user_id) .await { continue; From a74461fc9a9bd8f5a237662b399431cacc3f29e6 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 22 Oct 2024 04:03:07 +0000 Subject: [PATCH 100/245] split keys_changed for stronger-type overloads Signed-off-by: Jason Volk --- src/api/client/keys.rs | 5 +++-- src/api/client/sync/v3.rs | 5 +++-- src/api/client/sync/v4.rs | 5 +++-- src/api/mod.rs | 1 - src/service/users/mod.rs | 21 ++++++++++++++++++--- 5 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index 254d92ccd..44d9164c9 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -232,7 +232,7 @@ pub(crate) async fn get_key_changes_route( device_list_updates.extend( services .users - .keys_changed(sender_user.as_str(), from, Some(to)) + .keys_changed(sender_user, from, Some(to)) .map(ToOwned::to_owned) .collect::>() .await, @@ -244,7 +244,8 @@ pub(crate) async fn get_key_changes_route( device_list_updates.extend( services .users - .keys_changed(room_id.as_str(), from, Some(to)) + .room_keys_changed(room_id, from, Some(to)) + .map(|(user_id, _)| user_id) .map(ToOwned::to_owned) .collect::>() .await, diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index 2bd318df4..ccca1f85d 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -138,7 +138,7 @@ pub(crate) async fn sync_events_route( device_list_updates.extend( services .users - .keys_changed(sender_user.as_ref(), since, None) + .keys_changed(&sender_user, since, None) .map(ToOwned::to_owned) .collect::>() .await, @@ -917,7 +917,8 @@ async fn load_joined_room( device_list_updates.extend( services .users - .keys_changed(room_id.as_ref(), since, None) + .room_keys_changed(room_id, since, None) + .map(|(user_id, _)| user_id) .map(ToOwned::to_owned) .collect::>() .await, diff --git a/src/api/client/sync/v4.rs b/src/api/client/sync/v4.rs index 2adb3b71a..4f8323e66 100644 --- 
a/src/api/client/sync/v4.rs +++ b/src/api/client/sync/v4.rs @@ -162,7 +162,7 @@ pub(crate) async fn sync_events_v4_route( device_list_changes.extend( services .users - .keys_changed(sender_user.as_ref(), globalsince, None) + .keys_changed(sender_user, globalsince, None) .map(ToOwned::to_owned) .collect::>() .await, @@ -285,7 +285,8 @@ pub(crate) async fn sync_events_v4_route( device_list_changes.extend( services .users - .keys_changed(room_id.as_ref(), globalsince, None) + .room_keys_changed(room_id, globalsince, None) + .map(|(user_id, _)| user_id) .map(ToOwned::to_owned) .collect::>() .await, diff --git a/src/api/mod.rs b/src/api/mod.rs index 96837470b..ed8aacf23 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -6,7 +6,6 @@ extern crate conduit_core as conduit; extern crate conduit_service as service; pub(crate) use conduit::{debug_info, pdu::PduEvent, utils, Error, Result}; -pub(crate) use service::services; pub(crate) use self::router::{Ruma, RumaResponse, State}; diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 589aee8a1..b9183e128 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -13,7 +13,7 @@ use ruma::{ events::{ignored_user_list::IgnoredUserListEvent, AnyToDeviceEvent, GlobalAccountDataEventType}, serde::Raw, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId, - OwnedMxcUri, OwnedUserId, UInt, UserId, + OwnedMxcUri, OwnedUserId, RoomId, UInt, UserId, }; use serde_json::json; @@ -585,9 +585,24 @@ impl Service { Ok(()) } + #[inline] pub fn keys_changed<'a>( - &'a self, user_or_room_id: &'a str, from: u64, to: Option, + &'a self, user_id: &'a UserId, from: u64, to: Option, ) -> impl Stream + Send + 'a { + self.keys_changed_user_or_room(user_id.as_str(), from, to) + .map(|(user_id, ..)| user_id) + } + + #[inline] + pub fn room_keys_changed<'a>( + &'a self, room_id: &'a RoomId, from: u64, to: Option, + ) -> impl Stream + Send + 'a { + 
self.keys_changed_user_or_room(room_id.as_str(), from, to) + } + + fn keys_changed_user_or_room<'a>( + &'a self, user_or_room_id: &'a str, from: u64, to: Option, + ) -> impl Stream + Send + 'a { type KeyVal<'a> = ((&'a str, u64), &'a UserId); let to = to.unwrap_or(u64::MAX); @@ -597,7 +612,7 @@ impl Service { .stream_from(&start) .ignore_err() .ready_take_while(move |((prefix, count), _): &KeyVal<'_>| *prefix == user_or_room_id && *count <= to) - .map(|((..), user_id): KeyVal<'_>| user_id) + .map(|((_, count), user_id): KeyVal<'_>| (user_id, count)) } pub async fn mark_device_key_update(&self, user_id: &UserId) { From d35376a90cb521b578aff75a1441699e10695bac Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 22 Oct 2024 05:30:28 +0000 Subject: [PATCH 101/245] aggregate receipts into single edu; dedup presence; refactor selection limits etc Signed-off-by: Jason Volk --- src/api/server/send.rs | 8 +- src/service/sending/mod.rs | 5 +- src/service/sending/sender.rs | 301 ++++++++++++++++++++-------------- 3 files changed, 187 insertions(+), 127 deletions(-) diff --git a/src/api/server/send.rs b/src/api/server/send.rs index 40f9403b2..e2100a0f5 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -21,16 +21,16 @@ use ruma::{ OwnedEventId, ServerName, }; use serde_json::value::RawValue as RawJsonValue; +use service::{ + sending::{EDU_LIMIT, PDU_LIMIT}, + Services, +}; use crate::{ - services::Services, utils::{self}, Ruma, }; -const PDU_LIMIT: usize = 50; -const EDU_LIMIT: usize = 100; - type ResolvedMap = BTreeMap>; /// # `PUT /_matrix/federation/v1/send/{txnId}` diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index a1d5f6922..ea2668837 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -20,7 +20,10 @@ use ruma::{ use tokio::sync::Mutex; use self::data::Data; -pub use self::dest::Destination; +pub use self::{ + dest::Destination, + sender::{EDU_LIMIT, PDU_LIMIT}, +}; use crate::{account_data, client, 
globals, presence, pusher, resolver, rooms, server_keys, users, Dep}; pub struct Service { diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index a57d4aeae..d9087d443 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -7,7 +7,7 @@ use std::{ use base64::{engine::general_purpose, Engine as _}; use conduit::{ - debug, debug_warn, err, + debug, debug_warn, err, error, result::LogErr, trace, utils::{calculate_hash, math::continue_exponential_backoff_secs, ReadyExt}, @@ -26,8 +26,8 @@ use ruma::{ }, device_id, events::{push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent, GlobalAccountDataEventType}, - push, uint, CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedUserId, RoomId, RoomVersionId, - ServerName, UInt, + push, uint, CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, + RoomVersionId, ServerName, UInt, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::time::sleep_until; @@ -47,10 +47,16 @@ type SendingFuture<'a> = BoxFuture<'a, SendingResult>; type SendingFutures<'a> = FuturesUnordered>; type CurTransactionStatus = HashMap; -const DEQUEUE_LIMIT: usize = 48; -const SELECT_EDU_LIMIT: usize = 16; const CLEANUP_TIMEOUT_MS: u64 = 3500; +const SELECT_PRESENCE_LIMIT: usize = 256; +const SELECT_RECEIPT_LIMIT: usize = 256; +const SELECT_EDU_LIMIT: usize = EDU_LIMIT - 2; +const DEQUEUE_LIMIT: usize = 48; + +pub const PDU_LIMIT: usize = 50; +pub const EDU_LIMIT: usize = 100; + impl Service { #[tracing::instrument(skip_all, name = "sender")] pub(super) async fn sender(&self) -> Result<()> { @@ -216,6 +222,7 @@ impl Service { // Add EDU's into the transaction if let Destination::Normal(server_name) = dest { if let Ok((select_edus, last_count)) = self.select_edus(server_name).await { + debug_assert!(select_edus.len() <= EDU_LIMIT, "exceeded edus limit"); 
events.extend(select_edus.into_iter().map(SendingEvent::Edu)); self.db.set_latest_educount(server_name, last_count); } @@ -254,69 +261,176 @@ impl Service { async fn select_edus(&self, server_name: &ServerName) -> Result<(Vec>, u64)> { // u64: count of last edu let since = self.db.get_latest_educount(server_name).await; - let mut events = Vec::new(); let mut max_edu_count = since; - let mut device_list_changes = HashSet::new(); + let mut events = Vec::new(); + + self.select_edus_device_changes(server_name, since, &mut max_edu_count, &mut events) + .await; + + if self.server.config.allow_outgoing_read_receipts { + self.select_edus_receipts(server_name, since, &mut max_edu_count, &mut events) + .await; + } + + if self.server.config.allow_outgoing_presence { + self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events) + .await; + } + + Ok((events, max_edu_count)) + } + + /// Look for presence + async fn select_edus_device_changes( + &self, server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec>, + ) { + debug_assert!(events.len() < SELECT_EDU_LIMIT, "called when edu limit reached"); let server_rooms = self.services.state_cache.server_rooms(server_name); pin_mut!(server_rooms); + let mut device_list_changes = HashSet::::new(); while let Some(room_id) = server_rooms.next().await { - // Look for device list updates in this room - device_list_changes.extend( - self.services - .users - .keys_changed(room_id.as_str(), since, None) - .ready_filter(|user_id| self.services.globals.user_is_local(user_id)) - .map(ToOwned::to_owned) - .collect::>() - .await, - ); - - if self.server.config.allow_outgoing_read_receipts - && !self - .select_edus_receipts(room_id, since, &mut max_edu_count, &mut events) - .await? 
- { - break; + let keys_changed = self + .services + .users + .room_keys_changed(room_id, since, None) + .ready_filter(|(user_id, _)| self.services.globals.user_is_local(user_id)); + + pin_mut!(keys_changed); + while let Some((user_id, count)) = keys_changed.next().await { + *max_edu_count = cmp::max(count, *max_edu_count); + if !device_list_changes.insert(user_id.into()) { + continue; + } + + // Empty prev id forces synapse to resync; because synapse resyncs, + // we can just insert placeholder data + let edu = Edu::DeviceListUpdate(DeviceListUpdateContent { + user_id: user_id.into(), + device_id: device_id!("placeholder").to_owned(), + device_display_name: Some("Placeholder".to_owned()), + stream_id: uint!(1), + prev_id: Vec::new(), + deleted: None, + keys: None, + }); + + let edu = serde_json::to_vec(&edu).expect("failed to serialize device list update to JSON"); + + events.push(edu); + if events.len() >= SELECT_EDU_LIMIT { + return; + } } } + } + + /// Look for read receipts in this room + async fn select_edus_receipts( + &self, server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec>, + ) { + debug_assert!(events.len() < EDU_LIMIT, "called when edu limit reached"); + + let server_rooms = self.services.state_cache.server_rooms(server_name); - for user_id in device_list_changes { - // Empty prev id forces synapse to resync; because synapse resyncs, - // we can just insert placeholder data - let edu = Edu::DeviceListUpdate(DeviceListUpdateContent { - user_id, - device_id: device_id!("placeholder").to_owned(), - device_display_name: Some("Placeholder".to_owned()), - stream_id: uint!(1), - prev_id: Vec::new(), - deleted: None, - keys: None, - }); - - events.push(serde_json::to_vec(&edu).expect("json can be serialized")); + pin_mut!(server_rooms); + let mut num = 0; + let mut receipts = BTreeMap::::new(); + while let Some(room_id) = server_rooms.next().await { + let receipt_map = self + .select_edus_receipts_room(room_id, since, 
max_edu_count, &mut num) + .await; + + if !receipt_map.read.is_empty() { + receipts.insert(room_id.into(), receipt_map); + } } - if self.server.config.allow_outgoing_presence { - self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events) - .await?; + if receipts.is_empty() { + return; } - Ok((events, max_edu_count)) + let receipt_content = Edu::Receipt(ReceiptContent { + receipts, + }); + + let receipt_content = + serde_json::to_vec(&receipt_content).expect("Failed to serialize Receipt EDU to JSON vec"); + + events.push(receipt_content); + } + + /// Look for read receipts in this room + async fn select_edus_receipts_room( + &self, room_id: &RoomId, since: u64, max_edu_count: &mut u64, num: &mut usize, + ) -> ReceiptMap { + let receipts = self + .services + .read_receipt + .readreceipts_since(room_id, since); + + pin_mut!(receipts); + let mut read = BTreeMap::::new(); + while let Some((user_id, count, read_receipt)) = receipts.next().await { + *max_edu_count = cmp::max(count, *max_edu_count); + if !self.services.globals.user_is_local(&user_id) { + continue; + } + + let Ok(event) = serde_json::from_str(read_receipt.json().get()) else { + error!(?user_id, ?count, ?read_receipt, "Invalid edu event in read_receipts."); + continue; + }; + + let AnySyncEphemeralRoomEvent::Receipt(r) = event else { + error!(?user_id, ?count, ?event, "Invalid event type in read_receipts"); + continue; + }; + + let (event_id, mut receipt) = r + .content + .0 + .into_iter() + .next() + .expect("we only use one event per read receipt"); + + let receipt = receipt + .remove(&ReceiptType::Read) + .expect("our read receipts always set this") + .remove(&user_id) + .expect("our read receipts always have the user here"); + + let receipt_data = ReceiptData { + data: receipt, + event_ids: vec![event_id.clone()], + }; + + if read.insert(user_id, receipt_data).is_none() { + *num = num.saturating_add(1); + if *num >= SELECT_RECEIPT_LIMIT { + break; + } + } + } + + ReceiptMap { + read, 
+ } } /// Look for presence async fn select_edus_presence( &self, server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec>, - ) -> Result { + ) { + debug_assert!(events.len() < EDU_LIMIT, "called when edu limit reached"); + let presence_since = self.services.presence.presence_since(since); pin_mut!(presence_since); - let mut presence_updates = Vec::new(); + let mut presence_updates = HashMap::::new(); while let Some((user_id, count, presence_bytes)) = presence_since.next().await { *max_edu_count = cmp::max(count, *max_edu_count); - if !self.services.globals.user_is_local(user_id) { continue; } @@ -330,101 +444,44 @@ impl Service { continue; } - let presence_event = self + let Ok(presence_event) = self .services .presence - .from_json_bytes_to_event(&presence_bytes, &user_id) - .await?; + .from_json_bytes_to_event(presence_bytes, user_id) + .await + .log_err() + else { + continue; + }; - presence_updates.push(PresenceUpdate { - user_id, + let update = PresenceUpdate { + user_id: user_id.into(), presence: presence_event.content.presence, currently_active: presence_event.content.currently_active.unwrap_or(false), + status_msg: presence_event.content.status_msg, last_active_ago: presence_event .content .last_active_ago .unwrap_or_else(|| uint!(0)), - status_msg: presence_event.content.status_msg, - }); + }; - if presence_updates.len() >= SELECT_EDU_LIMIT { + presence_updates.insert(user_id.into(), update); + if presence_updates.len() >= SELECT_PRESENCE_LIMIT { break; } } - if !presence_updates.is_empty() { - let presence_content = Edu::Presence(PresenceContent::new(presence_updates)); - events.push(serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized")); + if presence_updates.is_empty() { + return; } - Ok(true) - } - - /// Look for read receipts in this room - async fn select_edus_receipts( - &self, room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec>, - ) -> Result { - let receipts = self - .services 
- .read_receipt - .readreceipts_since(room_id, since); - - pin_mut!(receipts); - while let Some((user_id, count, read_receipt)) = receipts.next().await { - *max_edu_count = cmp::max(count, *max_edu_count); - if !self.services.globals.user_is_local(&user_id) { - continue; - } - - let event = serde_json::from_str(read_receipt.json().get()) - .map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?; - - let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event { - let mut read = BTreeMap::new(); - let (event_id, mut receipt) = r - .content - .0 - .into_iter() - .next() - .expect("we only use one event per read receipt"); - - let receipt = receipt - .remove(&ReceiptType::Read) - .expect("our read receipts always set this") - .remove(&user_id) - .expect("our read receipts always have the user here"); - - read.insert( - user_id, - ReceiptData { - data: receipt.clone(), - event_ids: vec![event_id.clone()], - }, - ); - - let receipt_map = ReceiptMap { - read, - }; - - let mut receipts = BTreeMap::new(); - receipts.insert(room_id.to_owned(), receipt_map); - - Edu::Receipt(ReceiptContent { - receipts, - }) - } else { - Error::bad_database("Invalid event type in read_receipts"); - continue; - }; - - events.push(serde_json::to_vec(&federation_event).expect("json can be serialized")); + let presence_content = Edu::Presence(PresenceContent { + push: presence_updates.into_values().collect(), + }); - if events.len() >= SELECT_EDU_LIMIT { - return Ok(false); - } - } + let presence_content = serde_json::to_vec(&presence_content).expect("failed to serialize Presence EDU to JSON"); - Ok(true) + events.push(presence_content); } async fn send_events(&self, dest: Destination, events: Vec) -> SendingResult { From ca57dc79288e563e5e090f6699b115bc49b9d27f Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 22 Oct 2024 06:21:57 +0000 Subject: [PATCH 102/245] optimize config denylists Signed-off-by: Jason Volk --- src/api/client/directory.rs | 20 
++++++++------------ src/core/config/mod.rs | 14 +++++++------- src/service/globals/mod.rs | 4 ---- src/service/media/remote.rs | 3 +-- src/service/sending/send.rs | 9 ++++----- 5 files changed, 20 insertions(+), 30 deletions(-) diff --git a/src/api/client/directory.rs b/src/api/client/directory.rs index ea499545c..6cf7b13f5 100644 --- a/src/api/client/directory.rs +++ b/src/api/client/directory.rs @@ -37,14 +37,12 @@ pub(crate) async fn get_public_rooms_filtered_route( ) -> Result { if let Some(server) = &body.server { if services - .globals - .forbidden_remote_room_directory_server_names() + .server + .config + .forbidden_remote_room_directory_server_names .contains(server) { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); } } @@ -77,14 +75,12 @@ pub(crate) async fn get_public_rooms_route( ) -> Result { if let Some(server) = &body.server { if services - .globals - .forbidden_remote_room_directory_server_names() + .server + .config + .forbidden_remote_room_directory_server_names .contains(server) { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); } } diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 23d35424a..59ddd7c78 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -2,7 +2,7 @@ pub mod check; pub mod proxy; use std::{ - collections::{BTreeMap, BTreeSet}, + collections::{BTreeMap, BTreeSet, HashSet}, fmt, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, path::PathBuf, @@ -983,8 +983,8 @@ pub struct Config { /// Vector list of servers that conduwuit will refuse to download remote /// media from. No default. 
- #[serde(default = "Vec::new")] - pub prevent_media_downloads_from: Vec, + #[serde(default = "HashSet::new")] + pub prevent_media_downloads_from: HashSet, /// List of forbidden server names that we will block incoming AND outgoing /// federation with, and block client room joins / remote user invites. @@ -994,14 +994,14 @@ pub struct Config { /// outbound federation handler. /// /// Basically "global" ACLs. No default. - #[serde(default = "Vec::new")] - pub forbidden_remote_server_names: Vec, + #[serde(default = "HashSet::new")] + pub forbidden_remote_server_names: HashSet, /// List of forbidden server names that we will block all outgoing federated /// room directory requests for. Useful for preventing our users from /// wandering into bad servers or spaces. No default. - #[serde(default = "Vec::new")] - pub forbidden_remote_room_directory_server_names: Vec, + #[serde(default = "HashSet::new")] + pub forbidden_remote_room_directory_server_names: HashSet, /// Vector list of IPv4 and IPv6 CIDR ranges / subnets *in quotes* that you /// do not want conduwuit to send outbound requests to. 
Defaults to diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 329a6583c..157c39440 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -252,10 +252,6 @@ impl Service { pub fn allow_outgoing_read_receipts(&self) -> bool { self.config.allow_outgoing_read_receipts } - pub fn forbidden_remote_room_directory_server_names(&self) -> &[OwnedServerName] { - &self.config.forbidden_remote_room_directory_server_names - } - pub fn well_known_support_page(&self) -> &Option { &self.config.well_known.support_page } pub fn well_known_support_role(&self) -> &Option { &self.config.well_known.support_role } diff --git a/src/service/media/remote.rs b/src/service/media/remote.rs index 59846b8ee..1c6c9ca02 100644 --- a/src/service/media/remote.rs +++ b/src/service/media/remote.rs @@ -382,8 +382,7 @@ fn check_fetch_authorized(&self, mxc: &Mxc<'_>) -> Result<()> { .server .config .prevent_media_downloads_from - .iter() - .any(|entry| entry == mxc.server_name) + .contains(mxc.server_name) { // we'll lie to the client and say the blocked server's media was not found and // log. the client has no way of telling anyways so this is a security bonus. 
diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 73b6a468f..62da59ef2 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -1,8 +1,8 @@ use std::{fmt::Debug, mem}; use conduit::{ - debug, debug_error, debug_info, debug_warn, err, error::inspect_debug_log, implement, trace, utils::string::EMPTY, - Err, Error, Result, + debug, debug_error, debug_warn, err, error::inspect_debug_log, implement, trace, utils::string::EMPTY, Err, Error, + Result, }; use http::{header::AUTHORIZATION, HeaderValue}; use ipaddress::IPAddress; @@ -36,10 +36,9 @@ impl super::Service { .server .config .forbidden_remote_server_names - .contains(&dest.to_owned()) + .contains(dest) { - debug_info!("Refusing to send outbound federation request to {dest}"); - return Err!(Request(Forbidden("Federation with this homeserver is not allowed."))); + return Err!(Request(Forbidden(debug_warn!("Federation with this {dest} is not allowed.")))); } let actual = self.services.resolver.get_actual_dest(dest).await?; From b8260e0104860eee3b8dfcb6f9091e3ad87ae2de Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 22 Oct 2024 06:37:09 +0000 Subject: [PATCH 103/245] optimize for pdu_exists; remove a yield thing Signed-off-by: Jason Volk --- src/service/rooms/event_handler/mod.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 8f96f68e5..24c2692d0 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -359,7 +359,7 @@ impl Service { }; // Skip the PDU if it is redacted and we already have it as an outlier event - if self.services.timeline.get_pdu_json(event_id).await.is_ok() { + if self.services.timeline.pdu_exists(event_id).await { return Err!(Request(InvalidParam("Event was redacted and we already knew about it"))); } @@ -1123,7 +1123,6 @@ impl Service { let mut todo_auth_events = vec![Arc::clone(id)]; let 
mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len()); let mut events_all = HashSet::with_capacity(todo_auth_events.len()); - let mut i: u64 = 0; while let Some(next_id) = todo_auth_events.pop() { if let Some((time, tries)) = self .services @@ -1146,12 +1145,7 @@ impl Service { continue; } - i = i.saturating_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } - - if self.services.timeline.get_pdu(&next_id).await.is_ok() { + if self.services.timeline.pdu_exists(&next_id).await { trace!("Found {next_id} in db"); continue; } From dd6621a720b03ca18a7e0fca6881923c373e3cac Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 22 Oct 2024 07:07:42 +0000 Subject: [PATCH 104/245] reduce unnecessary clone in pdu handler Signed-off-by: Jason Volk --- src/api/server/send.rs | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/api/server/send.rs b/src/api/server/send.rs index e2100a0f5..4f5260521 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -118,15 +118,12 @@ async fn handle_pdus( .lock(&room_id) .await; - resolved_map.insert( - event_id.clone(), - services - .rooms - .event_handler - .handle_incoming_pdu(origin, &room_id, &event_id, value, true) - .await - .map(|_| ()), - ); + let result = services + .rooms + .event_handler + .handle_incoming_pdu(origin, &room_id, &event_id, value, true) + .await + .map(|_| ()); drop(mutex_lock); debug!( @@ -134,12 +131,14 @@ async fn handle_pdus( txn_elapsed = ?txn_start_time.elapsed(), "Finished PDU {event_id}", ); + + resolved_map.insert(event_id, result); } - for pdu in &resolved_map { - if let Err(e) = pdu.1 { + for (id, result) in &resolved_map { + if let Err(e) = result { if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) { - warn!("Incoming PDU failed {pdu:?}"); + warn!("Incoming PDU failed {id}: {e:?}"); } } } From b08c1241a89514046f666fc21f817af2feb8bce2 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 22 Oct 2024 07:15:28 
+0000 Subject: [PATCH 105/245] add some interruption points in recursive event handling to prevent shutdown hangs Signed-off-by: Jason Volk --- src/api/server/send.rs | 7 ++++--- src/core/server.rs | 9 ++++++++- src/service/rooms/event_handler/mod.rs | 3 +++ 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/api/server/send.rs b/src/api/server/send.rs index 4f5260521..d5d3ffbbf 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -71,7 +71,7 @@ pub(crate) async fn send_transaction_message_route( "Starting txn", ); - let resolved_map = handle_pdus(&services, &client, &body.pdus, origin, &txn_start_time).await; + let resolved_map = handle_pdus(&services, &client, &body.pdus, origin, &txn_start_time).await?; handle_edus(&services, &client, &body.edus, origin).await; debug!( @@ -93,7 +93,7 @@ pub(crate) async fn send_transaction_message_route( async fn handle_pdus( services: &Services, _client: &IpAddr, pdus: &[Box], origin: &ServerName, txn_start_time: &Instant, -) -> ResolvedMap { +) -> Result { let mut parsed_pdus = Vec::with_capacity(pdus.len()); for pdu in pdus { parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu).await { @@ -110,6 +110,7 @@ async fn handle_pdus( let mut resolved_map = BTreeMap::new(); for (event_id, value, room_id) in parsed_pdus { + services.server.check_running()?; let pdu_start_time = Instant::now(); let mutex_lock = services .rooms @@ -143,7 +144,7 @@ async fn handle_pdus( } } - resolved_map + Ok(resolved_map) } async fn handle_edus(services: &Services, client: &IpAddr, edus: &[Raw], origin: &ServerName) { diff --git a/src/core/server.rs b/src/core/server.rs index 89f1dea58..627e125d6 100644 --- a/src/core/server.rs +++ b/src/core/server.rs @@ -5,7 +5,7 @@ use std::{ use tokio::{runtime, sync::broadcast}; -use crate::{config::Config, log::Log, metrics::Metrics, Err, Result}; +use crate::{config::Config, err, log::Log, metrics::Metrics, Err, Result}; /// Server runtime state; public 
portion pub struct Server { @@ -107,6 +107,13 @@ impl Server { .expect("runtime handle available in Server") } + #[inline] + pub fn check_running(&self) -> Result { + self.running() + .then_some(()) + .ok_or_else(|| err!(debug_warn!("Server is shutting down."))) + } + #[inline] pub fn running(&self) -> bool { !self.stopping.load(Ordering::Acquire) } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 24c2692d0..0b2bbf731 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -205,6 +205,7 @@ impl Service { debug!(events = ?sorted_prev_events, "Got previous events"); for prev_id in sorted_prev_events { + self.services.server.check_running()?; match self .handle_prev_pdu( origin, @@ -1268,6 +1269,8 @@ impl Service { let mut amount = 0; while let Some(prev_event_id) = todo_outlier_stack.pop() { + self.services.server.check_running()?; + if let Some((pdu, mut json_opt)) = self .fetch_and_handle_outliers(origin, &[prev_event_id.clone()], create_event, room_id, room_version_id) .boxed() From 339654216857dc5caa492f0b9a0aa442af56c0f9 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 22 Oct 2024 09:09:20 +0000 Subject: [PATCH 106/245] complete the example-config generator macro Signed-off-by: Jason Volk --- src/macros/config.rs | 155 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 148 insertions(+), 7 deletions(-) diff --git a/src/macros/config.rs b/src/macros/config.rs index 6d29c21fa..3c93bd087 100644 --- a/src/macros/config.rs +++ b/src/macros/config.rs @@ -1,11 +1,19 @@ -use std::fmt::Write; +use std::{fmt::Write as _, fs::File, io::Write as _}; use proc_macro::TokenStream; +use proc_macro2::Span; use quote::ToTokens; -use syn::{Expr, ExprLit, Field, Fields, FieldsNamed, ItemStruct, Lit, Meta, MetaNameValue, Type, TypePath}; +use syn::{ + parse::Parser, punctuated::Punctuated, Error, Expr, ExprLit, Field, Fields, FieldsNamed, ItemStruct, Lit, Meta, + MetaList, 
MetaNameValue, Type, TypePath, +}; use crate::{utils::is_cargo_build, Result}; +const UNDOCUMENTED: &str = "# This item is undocumented. Please contribute documentation for it."; +const HEADER: &str = "## Conduwuit Configuration\n##\n## THIS FILE IS GENERATED. Changes to documentation and \ + defaults must\n## be made within the code found at src/core/config/\n"; + #[allow(clippy::needless_pass_by_value)] pub(super) fn example_generator(input: ItemStruct, args: &[Meta]) -> Result { if is_cargo_build() { @@ -18,6 +26,12 @@ pub(super) fn example_generator(input: ItemStruct, args: &[Meta]) -> Result Result<()> { + let mut file = File::create("conduwuit-example.toml") + .map_err(|e| Error::new(Span::call_site(), format!("Failed to open config file for generation: {e}")))?; + + file.write_all(HEADER.as_bytes()) + .expect("written to config file"); + if let Fields::Named(FieldsNamed { named, .. @@ -28,21 +42,143 @@ fn generate_example(input: &ItemStruct, _args: &[Meta]) -> Result<()> { continue; }; - let Some(doc) = get_doc_comment(field) else { + let Some(type_name) = get_type_name(field) else { continue; }; - let Some(type_name) = get_type_name(field) else { - continue; + let doc = get_doc_comment(field) + .unwrap_or_else(|| UNDOCUMENTED.into()) + .trim_end() + .to_owned(); + + let doc = if doc.ends_with('#') { + format!("{doc}\n") + } else { + format!("{doc}\n#\n") + }; + + let default = get_doc_default(field) + .or_else(|| get_default(field)) + .unwrap_or_default(); + + let default = if !default.is_empty() { + format!(" {default}") + } else { + default }; - //println!("{:?} {type_name:?}\n{doc}", ident.to_string()); + file.write_fmt(format_args!("\n{doc}")) + .expect("written to config file"); + + file.write_fmt(format_args!("#{ident} ={default}\n")) + .expect("written to config file"); } } Ok(()) } +fn get_default(field: &Field) -> Option { + for attr in &field.attrs { + let Meta::List(MetaList { + path, + tokens, + .. 
+ }) = &attr.meta + else { + continue; + }; + + if !path + .segments + .iter() + .next() + .is_some_and(|s| s.ident == "serde") + { + continue; + } + + let Some(arg) = Punctuated::::parse_terminated + .parse(tokens.clone().into()) + .ok()? + .iter() + .next() + .cloned() + else { + continue; + }; + + match arg { + Meta::NameValue(MetaNameValue { + value: Expr::Lit(ExprLit { + lit: Lit::Str(str), + .. + }), + .. + }) => { + match str.value().as_str() { + "HashSet::new" | "Vec::new" | "RegexSet::empty" => Some("[]".to_owned()), + "true_fn" => return Some("true".to_owned()), + _ => return None, + }; + }, + Meta::Path { + .. + } => return Some("false".to_owned()), + _ => return None, + }; + } + + None +} + +fn get_doc_default(field: &Field) -> Option { + for attr in &field.attrs { + let Meta::NameValue(MetaNameValue { + path, + value, + .. + }) = &attr.meta + else { + continue; + }; + + if !path + .segments + .iter() + .next() + .is_some_and(|s| s.ident == "doc") + { + continue; + } + + let Expr::Lit(ExprLit { + lit, + .. 
+ }) = &value + else { + continue; + }; + + let Lit::Str(token) = &lit else { + continue; + }; + + let value = token.value(); + if !value.trim().starts_with("default:") { + continue; + } + + return value + .split_once(':') + .map(|(_, v)| v) + .map(str::trim) + .map(ToOwned::to_owned); + } + + None +} + fn get_doc_comment(field: &Field) -> Option { let mut out = String::new(); for attr in &field.attrs { @@ -76,7 +212,12 @@ fn get_doc_comment(field: &Field) -> Option { continue; }; - writeln!(&mut out, "# {}", token.value()).expect("wrote to output string buffer"); + let value = token.value(); + if value.trim().starts_with("default:") { + continue; + } + + writeln!(&mut out, "#{value}").expect("wrote to output string buffer"); } (!out.is_empty()).then_some(out) From 367d1533801d8d8c0b53aa07992cbac5f267db5c Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 22 Oct 2024 10:09:14 +0000 Subject: [PATCH 107/245] add default-directives to config document comments Signed-off-by: Jason Volk --- src/core/config/mod.rs | 249 ++++++++++++++++++++++++----------------- 1 file changed, 144 insertions(+), 105 deletions(-) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 59ddd7c78..919bb4862 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -54,9 +54,9 @@ pub struct Config { /// want this to be localhost (127.0.0.1 / ::1). If you are using Docker or /// a container NAT networking setup, you likely need this to be 0.0.0.0. /// To listen multiple addresses, specify a vector e.g. ["127.0.0.1", "::1"] + /// Default if unspecified is both IPv4 and IPv6 localhost. /// - /// default if unspecified is both IPv4 and IPv6 localhost: ["127.0.0.1", - /// "::1"] + /// default: ["127.0.0.1", "::1"] #[serde(default = "default_address")] address: ListeningAddr, @@ -67,7 +67,7 @@ pub struct Config { /// port to this. To listen on multiple ports, specify a vector e.g. 
[8080, /// 8448] /// - /// default if unspecified is 8008 + /// default: 8008 #[serde(default = "default_port")] port: ListeningPort, @@ -80,9 +80,11 @@ pub struct Config { /// (666 minimum). pub unix_socket_path: Option, + /// default: 660 #[serde(default = "default_unix_socket_perms")] pub unix_socket_perms: u32, + /// default: rocksdb #[serde(default = "default_database_backend")] pub database_backend: String, @@ -98,7 +100,9 @@ pub struct Config { /// Set this to any float value in megabytes for conduwuit to tell the /// database engine that this much memory is available for database-related /// caches. May be useful if you have significant memory to spare to - /// increase performance. Defaults to 256.0 + /// increase performance. + /// + /// default: 256.0 #[serde(default = "default_db_cache_capacity_mb")] pub db_cache_capacity_mb: f64, @@ -107,6 +111,8 @@ pub struct Config { /// lightning bolt emoji option, just replaced with support for adding your /// own custom text or emojis. To disable, set this to "" (an empty string) /// Defaults to "🏳️⚧️" (trans pride flag) + /// + /// default: 🏳️⚧️ #[serde(default = "default_new_user_displayname_suffix")] pub new_user_displayname_suffix: String, @@ -123,11 +129,10 @@ pub struct Config { /// Set this to any float value to multiply conduwuit's in-memory LRU caches /// with. May be useful if you have significant memory to spare to increase - /// performance. - /// - /// This was previously called `conduit_cache_capacity_modifier` + /// performance. This was previously called + /// `conduit_cache_capacity_modifier`. /// - /// Defaults to 1.0. + /// default: 1.0. #[serde(default = "default_cache_capacity_modifier", alias = "conduit_cache_capacity_modifier")] pub cache_capacity_modifier: f64, @@ -197,11 +202,9 @@ pub struct Config { pub dns_tcp_fallback: bool, /// Enable to query all nameservers until the domain is found. 
Referred to - /// as "trust_negative_responses" in hickory_resolver. This can avoid useless - /// DNS queries if the first nameserver responds with NXDOMAIN or an empty - /// NOERROR response. - /// - /// The default is to query one nameserver and stop (false). + /// as "trust_negative_responses" in hickory_resolver. This can avoid + /// useless DNS queries if the first nameserver responds with NXDOMAIN or + /// an empty NOERROR response. #[serde(default = "true_fn")] pub query_all_nameservers: bool, @@ -230,116 +233,121 @@ pub struct Config { /// /// Defaults to 5 - Ipv4ThenIpv6 as this is the most compatible and IPv4 /// networking is currently the most prevalent. + /// + /// default: 5 #[serde(default = "default_ip_lookup_strategy")] pub ip_lookup_strategy: u8, /// Max request size for file uploads + /// + /// default: 20971520 #[serde(default = "default_max_request_size")] pub max_request_size: usize, #[serde(default = "default_max_fetch_prev_events")] pub max_fetch_prev_events: u16, - /// Default/base connection timeout. - /// This is used only by URL previews and update/news endpoint checks + /// Default/base connection timeout (seconds). This is used only by URL + /// previews and update/news endpoint checks. /// - /// Defaults to 10 seconds + /// default: 10 #[serde(default = "default_request_conn_timeout")] pub request_conn_timeout: u64, - /// Default/base request timeout. The time waiting to receive more data from - /// another server. This is used only by URL previews, update/news, and - /// misc endpoint checks + /// Default/base request timeout (seconds). The time waiting to receive more + /// data from another server. This is used only by URL previews, + /// update/news, and misc endpoint checks. /// - /// Defaults to 35 seconds + /// default: 35 #[serde(default = "default_request_timeout")] pub request_timeout: u64, - /// Default/base request total timeout. The time limit for a whole request. 
- /// This is set very high to not cancel healthy requests while serving as a - /// backstop. This is used only by URL previews and update/news endpoint - /// checks + /// Default/base request total timeout (seconds). The time limit for a whole + /// request. This is set very high to not cancel healthy requests while + /// serving as a backstop. This is used only by URL previews and + /// update/news endpoint checks. /// - /// Defaults to 320 seconds + /// default: 320 #[serde(default = "default_request_total_timeout")] pub request_total_timeout: u64, - /// Default/base idle connection pool timeout - /// This is used only by URL previews and update/news endpoint checks + /// Default/base idle connection pool timeout (seconds). This is used only + /// by URL previews and update/news endpoint checks. /// - /// Defaults to 5 seconds + /// default: 5 #[serde(default = "default_request_idle_timeout")] pub request_idle_timeout: u64, - /// Default/base max idle connections per host - /// This is used only by URL previews and update/news endpoint checks + /// Default/base max idle connections per host. This is used only by URL + /// previews and update/news endpoint checks. Defaults to 1 as generally the + /// same open connection can be re-used. 
/// - /// Defaults to 1 as generally the same open connection can be re-used + /// default: 1 #[serde(default = "default_request_idle_per_host")] pub request_idle_per_host: u16, - /// Federation well-known resolution connection timeout + /// Federation well-known resolution connection timeout (seconds) /// - /// Defaults to 6 seconds + /// default: 6 #[serde(default = "default_well_known_conn_timeout")] pub well_known_conn_timeout: u64, - /// Federation HTTP well-known resolution request timeout + /// Federation HTTP well-known resolution request timeout (seconds) /// - /// Defaults to 10 seconds + /// default: 10 #[serde(default = "default_well_known_timeout")] pub well_known_timeout: u64, - /// Federation client request timeout - /// You most definitely want this to be high to account for extremely large - /// room joins, slow homeservers, your own resources etc. + /// Federation client request timeout (seconds). You most definitely want + /// this to be high to account for extremely large room joins, slow + /// homeservers, your own resources etc. /// - /// Defaults to 300 seconds + /// default: 300 #[serde(default = "default_federation_timeout")] pub federation_timeout: u64, - /// Federation client idle connection pool timeout + /// Federation client idle connection pool timeout (seconds) /// - /// Defaults to 25 seconds + /// default: 25 #[serde(default = "default_federation_idle_timeout")] pub federation_idle_timeout: u64, - /// Federation client max idle connections per host + /// Federation client max idle connections per host. Defaults to 1 as + /// generally the same open connection can be re-used /// - /// Defaults to 1 as generally the same open connection can be re-used + /// default: 1 #[serde(default = "default_federation_idle_per_host")] pub federation_idle_per_host: u16, - /// Federation sender request timeout - /// The time it takes for the remote server to process sent transactions can - /// take a while. 
+ /// Federation sender request timeout (seconds). The time it takes for the + /// remote server to process sent transactions can take a while. /// - /// Defaults to 180 seconds + /// default: 180 #[serde(default = "default_sender_timeout")] pub sender_timeout: u64, - /// Federation sender idle connection pool timeout + /// Federation sender idle connection pool timeout (seconds) /// - /// Defaults to 180 seconds + /// default: 180 #[serde(default = "default_sender_idle_timeout")] pub sender_idle_timeout: u64, - /// Federation sender transaction retry backoff limit + /// Federation sender transaction retry backoff limit (seconds) /// - /// Defaults to 86400 seconds + /// default: 86400 #[serde(default = "default_sender_retry_backoff_limit")] pub sender_retry_backoff_limit: u64, - /// Appservice URL request connection timeout + /// Appservice URL request connection timeout. Defaults to 35 seconds as + /// generally appservices are hosted within the same network. /// - /// Defaults to 35 seconds as generally appservices are hosted within the - /// same network + /// default: 35 #[serde(default = "default_appservice_timeout")] pub appservice_timeout: u64, - /// Appservice URL idle connection pool timeout + /// Appservice URL idle connection pool timeout (seconds) /// - /// Defaults to 300 seconds + /// default: 300 #[serde(default = "default_appservice_idle_timeout")] pub appservice_idle_timeout: u64, @@ -377,12 +385,11 @@ pub struct Config { /// no default pub registration_token_file: Option, - /// controls whether encrypted rooms and events are allowed (default true) + /// Controls whether encrypted rooms and events are allowed. #[serde(default = "true_fn")] pub allow_encryption: bool, - /// controls whether federation is allowed or not - /// defaults to true + /// Controls whether federation is allowed or not. 
#[serde(default = "true_fn")] pub allow_federation: bool, @@ -487,6 +494,8 @@ pub struct Config { /// /// (Currently, conduwuit doesn't support batched key requests, so this list /// should only contain other Synapse servers) Defaults to `matrix.org` + /// + /// default: ["matrix.org"] #[serde(default = "default_trusted_servers")] pub trusted_servers: Vec, @@ -527,13 +536,13 @@ pub struct Config { /// binary from trace macros. For debug builds, this restriction is not /// applied. /// - /// Defaults to "info" + /// default: "info" #[serde(default = "default_log")] pub log: String, /// controls whether logs will be outputted with ANSI colours /// - /// defaults to true + /// default: true #[serde(default = "true_fn", alias = "log_colours")] pub log_colors: bool, @@ -542,7 +551,7 @@ pub struct Config { /// These are the OpenID tokens that are primarily used for Matrix account /// integrations, *not* OIDC/OpenID Connect/etc /// - /// Defaults to 3600 (1 hour) + /// default: 3600 #[serde(default = "default_openid_token_ttl")] pub openid_token_ttl: u64, @@ -585,9 +594,9 @@ pub struct Config { /// no default pub turn_secret_file: Option, - /// TURN TTL + /// TURN TTL in seconds /// - /// Default is 86400 seconds + /// default: 86400 #[serde(default = "default_turn_ttl")] pub turn_ttl: u64, @@ -629,10 +638,14 @@ pub struct Config { pub rocksdb_log_stderr: bool, /// Max RocksDB `LOG` file size before rotating in bytes. Defaults to 4MB. + /// + /// default: 4194304 #[serde(default = "default_rocksdb_max_log_file_size")] pub rocksdb_max_log_file_size: usize, - /// Time in seconds before RocksDB will forcibly rotate logs. Defaults to 0. + /// Time in seconds before RocksDB will forcibly rotate logs. 
+ /// + /// default: 0 #[serde(default = "default_rocksdb_log_time_to_roll")] pub rocksdb_log_time_to_roll: usize, @@ -649,8 +662,6 @@ pub struct Config { /// RocksDB issues, try enabling this option as it turns off Direct IO and /// feel free to report in the conduwuit Matrix room if this option fixes /// your DB issues. See https://github.com/facebook/rocksdb/wiki/Direct-IO for more information. - /// - /// Defaults to false #[serde(default)] pub rocksdb_optimize_for_spinning_disks: bool, @@ -662,14 +673,16 @@ pub struct Config { /// Amount of threads that RocksDB will use for parallelism on database /// operatons such as cleanup, sync, flush, compaction, etc. Set to 0 to use - /// all your logical threads. + /// all your logical threads. Defaults to your CPU logical thread count. /// - /// Defaults to your CPU logical thread count. + /// default: 0 #[serde(default = "default_rocksdb_parallelism_threads")] pub rocksdb_parallelism_threads: usize, /// Maximum number of LOG files RocksDB will keep. This must *not* be set to /// 0. It must be at least 1. Defaults to 3 as these are not very useful. + /// + /// default: 3 #[serde(default = "default_rocksdb_max_log_files")] pub rocksdb_max_log_files: usize, @@ -682,7 +695,7 @@ pub struct Config { /// /// "none" will disable compression. /// - /// Defaults to "zstd" + /// default: "zstd" #[serde(default = "default_rocksdb_compression_algo")] pub rocksdb_compression_algo: String, @@ -746,6 +759,8 @@ pub struct Config { /// See https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes for more information /// /// Defaults to 1 (TolerateCorruptedTailRecords) + /// + /// default: 1 #[serde(default = "default_rocksdb_recovery_mode")] pub rocksdb_recovery_mode: u8, @@ -760,8 +775,6 @@ pub struct Config { /// repair. /// - Disabling repair mode and restarting the server is recommended after /// running the repair. 
- /// - /// Defaults to false #[serde(default)] pub rocksdb_repair: bool, @@ -798,6 +811,8 @@ pub struct Config { /// 6 = All statistics. /// /// Defaults to 1 (No statistics, except in debug-mode) + /// + /// default: 1 #[serde(default = "default_rocksdb_stats_level")] pub rocksdb_stats_level: u8, @@ -831,11 +846,15 @@ pub struct Config { /// Config option to control how many seconds before presence updates that /// you are idle. Defaults to 5 minutes. + /// + /// default: 300 #[serde(default = "default_presence_idle_timeout_s")] pub presence_idle_timeout_s: u64, /// Config option to control how many seconds before presence updates that /// you are offline. Defaults to 30 minutes. + /// + /// default: 1800 #[serde(default = "default_presence_offline_timeout_s")] pub presence_offline_timeout_s: u64, @@ -843,42 +862,46 @@ pub struct Config { /// Disabling is offered as an optimization for servers participating in /// many large rooms or when resources are limited. Disabling it may cause /// incorrect presence states (i.e. stuck online) to be seen for some - /// remote users. Defaults to true. + /// remote users. #[serde(default = "true_fn")] pub presence_timeout_remote_users: bool, /// Config option to control whether we should receive remote incoming read - /// receipts. Defaults to true. + /// receipts. #[serde(default = "true_fn")] pub allow_incoming_read_receipts: bool, /// Config option to control whether we should send read receipts to remote - /// servers. Defaults to true. + /// servers. #[serde(default = "true_fn")] pub allow_outgoing_read_receipts: bool, - /// Config option to control outgoing typing updates to federation. Defaults - /// to true. + /// Config option to control outgoing typing updates to federation. #[serde(default = "true_fn")] pub allow_outgoing_typing: bool, /// Config option to control incoming typing updates from federation. - /// Defaults to true. 
#[serde(default = "true_fn")] pub allow_incoming_typing: bool, /// Config option to control maximum time federation user can indicate /// typing. + /// + /// default: 30 #[serde(default = "default_typing_federation_timeout_s")] pub typing_federation_timeout_s: u64, /// Config option to control minimum time local client can indicate typing. /// This does not override a client's request to stop typing. It only /// enforces a minimum value in case of no stop request. + /// + /// default: 15 #[serde(default = "default_typing_client_timeout_min_s")] pub typing_client_timeout_min_s: u64, /// Config option to control maximum time local client can indicate typing. + /// + /// default: 45 #[serde(default = "default_typing_client_timeout_max_s")] pub typing_client_timeout_max_s: u64, @@ -910,7 +933,7 @@ pub struct Config { pub brotli_compression: bool, /// Set to true to allow user type "guest" registrations. Element attempts - /// to register guest users automatically. Defaults to false + /// to register guest users automatically. Defaults to false. #[serde(default)] pub allow_guest_registration: bool, @@ -920,7 +943,7 @@ pub struct Config { pub log_guest_registrations: bool, /// Set to true to allow guest registrations/users to auto join any rooms - /// specified in `auto_join_rooms` Defaults to false + /// specified in `auto_join_rooms` Defaults to false. #[serde(default)] pub allow_guests_auto_join_rooms: bool, @@ -964,9 +987,7 @@ pub struct Config { /// is now disabled by default. You may still return to upstream Conduit /// but you have to run Conduwuit at least once with this set to true and /// allow the media_startup_check to take place before shutting - /// down to return to Conduit. - /// - /// Disabled by default. + /// down to return to Conduit. Disabled by default. #[serde(default)] pub media_compat_file_link: bool, @@ -975,9 +996,7 @@ pub struct Config { /// corresponding entries will be removed from the database. 
This is /// disabled by default because if the media directory is accidentally moved /// or inaccessible the metadata entries in the database will be lost with - /// sadness. - /// - /// Disabled by default. + /// sadness. Disabled by default. #[serde(default)] pub prune_missing_media: bool, @@ -1008,12 +1027,35 @@ pub struct Config { /// RFC1918, unroutable, loopback, multicast, and testnet addresses for /// security. /// - /// To disable, set this to be an empty vector (`[]`). /// Please be aware that this is *not* a guarantee. You should be using a /// firewall with zones as doing this on the application layer may have /// bypasses. /// /// Currently this does not account for proxies in use like Synapse does. + /// + /// To disable, set this to be an empty vector (`[]`). + /// The default is: + /// [ + /// "127.0.0.0/8", + /// "10.0.0.0/8", + /// "172.16.0.0/12", + /// "192.168.0.0/16", + /// "100.64.0.0/10", + /// "192.0.0.0/24", + /// "169.254.0.0/16", + /// "192.88.99.0/24", + /// "198.18.0.0/15", + /// "192.0.2.0/24", + /// "198.51.100.0/24", + /// "203.0.113.0/24", + /// "224.0.0.0/4", + /// "::1/128", + /// "fe80::/10", + /// "fc00::/7", + /// "2001:db8::/32", + /// "ff00::/8", + /// "fec0::/10", + /// ] #[serde(default = "default_ip_range_denylist")] pub ip_range_denylist: Vec, @@ -1060,7 +1102,9 @@ pub struct Config { pub url_preview_url_contains_allowlist: Vec, /// Maximum amount of bytes allowed in a URL preview body size when - /// spidering. Defaults to 384KB (384_000 bytes) + /// spidering. Defaults to 384KB. + /// + /// default: 384000 #[serde(default = "default_url_preview_max_spider_size")] pub url_preview_max_spider_size: usize, @@ -1109,27 +1153,27 @@ pub struct Config { /// reattempt every message without trimming the queues; this may consume /// significant disk. Set this value to 0 to drop all messages without any /// attempt at redelivery. 
+ /// + /// default: 50 #[serde(default = "default_startup_netburst_keep")] pub startup_netburst_keep: i64, /// controls whether non-admin local users are forbidden from sending room /// invites (local and remote), and if non-admin users can receive remote /// room invites. admins are always allowed to send and receive all room - /// invites. defaults to false + /// invites. #[serde(default)] pub block_non_admin_invites: bool, /// Allows admins to enter commands in rooms other than #admins by prefixing /// with \!admin. The reply will be publicly visible to the room, - /// originating from the sender. defaults to true + /// originating from the sender. #[serde(default = "true_fn")] pub admin_escape_commands: bool, /// Controls whether the conduwuit admin room console / CLI will immediately /// activate on startup. This option can also be enabled with `--console` - /// conduwuit argument - /// - /// Defaults to false + /// conduwuit argument. #[serde(default)] pub admin_console_automatic: bool, @@ -1145,21 +1189,20 @@ pub struct Config { /// Such example could be: `./conduwuit --execute "server admin-notice /// conduwuit has started up at $(date)"` /// - /// Defaults to nothing. + /// default: [] #[serde(default)] pub admin_execute: Vec, /// Controls whether conduwuit should error and fail to start if an admin - /// execute command (`--execute` / `admin_execute`) fails - /// - /// Defaults to false + /// execute command (`--execute` / `admin_execute`) fails. #[serde(default)] pub admin_execute_errors_ignore: bool, /// Controls the max log level for admin command log captures (logs - /// generated from running admin commands) + /// generated from running admin commands). Defaults to "info" on release + /// builds, else "debug" on debug builds. 
/// - /// Defaults to "info" on release builds, else "debug" on debug builds + /// default: "info" #[serde(default = "default_admin_log_capture")] pub admin_log_capture: String, @@ -1169,8 +1212,6 @@ pub struct Config { /// Sentry.io crash/panic reporting, performance monitoring/metrics, etc. /// This is NOT enabled by default. conduwuit's default Sentry reporting /// endpoint is o4506996327251968.ingest.us.sentry.io - /// - /// Defaults to *false* #[serde(default)] pub sentry: bool, @@ -1182,8 +1223,6 @@ pub struct Config { pub sentry_endpoint: Option, /// Report your Conduwuit server_name in Sentry.io crash reports and metrics - /// - /// Defaults to false #[serde(default)] pub sentry_send_server_name: bool, @@ -1191,9 +1230,9 @@ pub struct Config { /// /// Note that too high values may impact performance, and can be disabled by /// setting it to 0.0 (0%) This value is read as a percentage to Sentry, - /// represented as a decimal + /// represented as a decimal. Defaults to 15% of traces (0.15) /// - /// Defaults to 15% of traces (0.15) + /// default: 0.15 #[serde(default = "default_sentry_traces_sample_rate")] pub sentry_traces_sample_rate: f32, From 5cb0a5f67668828b7c47b8a8efc3f8c834c1d7f2 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 22 Oct 2024 22:16:59 +0000 Subject: [PATCH 108/245] add config generator controls via attribute metadatas Signed-off-by: Jason Volk --- src/core/config/mod.rs | 45 +++++++++++++++++++- src/macros/config.rs | 93 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 127 insertions(+), 11 deletions(-) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 919bb4862..ff2144200 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -28,10 +28,19 @@ use self::proxy::ProxyConfig; use crate::{err, error::Error, utils::sys, Result}; /// all the config options for conduwuit -#[config_example_generator] -#[derive(Clone, Debug, Deserialize)] #[allow(clippy::struct_excessive_bools)] 
#[allow(rustdoc::broken_intra_doc_links, rustdoc::bare_urls)] +#[derive(Clone, Debug, Deserialize)] +#[config_example_generator( + filename = "conduwuit-example.toml", + section = "global", + undocumented = "# This item is undocumented. Please contribute documentation for it.", + header = "### Conduwuit Configuration\n###\n### THIS FILE IS GENERATED. YOUR CHANGES WILL BE OVERWRITTEN!\n### \ + You should rename this file before configuring your server. Changes\n### to documentation and defaults \ + can be contributed in sourcecode at\n### src/core/config/mod.rs. This file is generated when \ + building.\n###\n", + ignore = "catchall well_known tls" +)] pub struct Config { /// The server_name is the pretty name of this server. It is used as a /// suffix for user and room ids. Examples: matrix.org, conduit.rs @@ -71,6 +80,7 @@ pub struct Config { #[serde(default = "default_port")] port: ListeningPort, + // external structure; separate section pub tls: Option, /// Uncomment unix_socket_path to listen on a UNIX socket at the specified @@ -458,15 +468,18 @@ pub struct Config { #[serde(default = "true_fn")] pub allow_unstable_room_versions: bool, + /// default: 10 #[serde(default = "default_default_room_version")] pub default_room_version: RoomVersionId, + // external structure; separate section #[serde(default)] pub well_known: WellKnownConfig, #[serde(default)] pub allow_jaeger: bool, + /// default: "info" #[serde(default = "default_jaeger_filter")] pub jaeger_filter: String, @@ -478,12 +491,38 @@ pub struct Config { #[serde(default)] pub tracing_flame: bool, + /// default: "info" #[serde(default = "default_tracing_flame_filter")] pub tracing_flame_filter: String, + /// default: "./tracing.folded" #[serde(default = "default_tracing_flame_output_path")] pub tracing_flame_output_path: String, + /// Examples: + /// - No proxy (default): + /// proxy ="none" + /// + /// - For global proxy, create the section at the bottom of this file: + /// [global.proxy] + /// global = { 
url = "socks5h://localhost:9050" } + /// + /// - To proxy some domains: + /// [global.proxy] + /// [[global.proxy.by_domain]] + /// url = "socks5h://localhost:9050" + /// include = ["*.onion", "matrix.myspecial.onion"] + /// exclude = ["*.myspecial.onion"] + /// + /// Include vs. Exclude: + /// - If include is an empty list, it is assumed to be `["*"]`. + /// - If a domain matches both the exclude and include list, the proxy will + /// only be used if it was included because of a more specific rule than + /// it was excluded. In the above example, the proxy would be used for + /// `ordinary.onion`, `matrix.myspecial.onion`, but not + /// `hello.myspecial.onion`. + /// + /// default: "none" #[serde(default)] pub proxy: ProxyConfig, @@ -1278,6 +1317,7 @@ pub struct Config { } #[derive(Clone, Debug, Deserialize)] +#[config_example_generator(filename = "conduwuit-example.toml", section = "global.tls")] pub struct TlsConfig { pub certs: String, pub key: String, @@ -1287,6 +1327,7 @@ pub struct TlsConfig { } #[derive(Clone, Debug, Deserialize, Default)] +#[config_example_generator(filename = "conduwuit-example.toml", section = "global.well_known")] pub struct WellKnownConfig { pub client: Option, pub server: Option, diff --git a/src/macros/config.rs b/src/macros/config.rs index 3c93bd087..f86163520 100644 --- a/src/macros/config.rs +++ b/src/macros/config.rs @@ -1,18 +1,21 @@ -use std::{fmt::Write as _, fs::File, io::Write as _}; +use std::{ + collections::{HashMap, HashSet}, + fmt::Write as _, + fs::OpenOptions, + io::Write as _, +}; use proc_macro::TokenStream; use proc_macro2::Span; use quote::ToTokens; use syn::{ - parse::Parser, punctuated::Punctuated, Error, Expr, ExprLit, Field, Fields, FieldsNamed, ItemStruct, Lit, Meta, - MetaList, MetaNameValue, Type, TypePath, + parse::Parser, punctuated::Punctuated, spanned::Spanned, Error, Expr, ExprLit, Field, Fields, FieldsNamed, + ItemStruct, Lit, Meta, MetaList, MetaNameValue, Type, TypePath, }; use 
crate::{utils::is_cargo_build, Result}; const UNDOCUMENTED: &str = "# This item is undocumented. Please contribute documentation for it."; -const HEADER: &str = "## Conduwuit Configuration\n##\n## THIS FILE IS GENERATED. Changes to documentation and \ - defaults must\n## be made within the code found at src/core/config/\n"; #[allow(clippy::needless_pass_by_value)] pub(super) fn example_generator(input: ItemStruct, args: &[Meta]) -> Result { @@ -25,11 +28,41 @@ pub(super) fn example_generator(input: ItemStruct, args: &[Meta]) -> Result Result<()> { - let mut file = File::create("conduwuit-example.toml") +fn generate_example(input: &ItemStruct, args: &[Meta]) -> Result<()> { + let settings = get_settings(args); + + let filename = settings + .get("filename") + .ok_or_else(|| Error::new(args[0].span(), "missing required 'filename' attribute argument"))?; + + let undocumented = settings + .get("undocumented") + .map_or(UNDOCUMENTED, String::as_str); + + let ignore: HashSet<&str> = settings + .get("ignore") + .map_or("", String::as_str) + .split(' ') + .collect(); + + let section = settings + .get("section") + .ok_or_else(|| Error::new(args[0].span(), "missing required 'section' attribute argument"))?; + + let mut file = OpenOptions::new() + .write(true) + .create(section == "global") + .truncate(section == "global") + .append(section != "global") + .open(filename) .map_err(|e| Error::new(Span::call_site(), format!("Failed to open config file for generation: {e}")))?; - file.write_all(HEADER.as_bytes()) + if let Some(header) = settings.get("header") { + file.write_all(header.as_bytes()) + .expect("written to config file"); + } + + file.write_fmt(format_args!("\n[{section}]\n")) .expect("written to config file"); if let Fields::Named(FieldsNamed { @@ -42,12 +75,16 @@ fn generate_example(input: &ItemStruct, _args: &[Meta]) -> Result<()> { continue; }; + if ignore.contains(ident.to_string().as_str()) { + continue; + } + let Some(type_name) = get_type_name(field) else { 
continue; }; let doc = get_doc_comment(field) - .unwrap_or_else(|| UNDOCUMENTED.into()) + .unwrap_or_else(|| undocumented.into()) .trim_end() .to_owned(); @@ -75,9 +112,47 @@ fn generate_example(input: &ItemStruct, _args: &[Meta]) -> Result<()> { } } + if let Some(footer) = settings.get("footer") { + file.write_all(footer.as_bytes()) + .expect("written to config file"); + } + Ok(()) } +fn get_settings(args: &[Meta]) -> HashMap { + let mut map = HashMap::new(); + for arg in args { + let Meta::NameValue(MetaNameValue { + path, + value, + .. + }) = arg + else { + continue; + }; + + let Expr::Lit( + ExprLit { + lit: Lit::Str(str), + .. + }, + .., + ) = value + else { + continue; + }; + + let Some(key) = path.segments.iter().next().map(|s| s.ident.clone()) else { + continue; + }; + + map.insert(key.to_string(), str.value()); + } + + map +} + fn get_default(field: &Field) -> Option { for attr in &field.attrs { let Meta::List(MetaList { From c769fcc3471dcc4c976569be57aea65109105f92 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 24 Oct 2024 01:31:30 +0000 Subject: [PATCH 109/245] move core result into core utils Signed-off-by: Jason Volk --- src/core/mod.rs | 4 +--- src/core/utils/mod.rs | 1 + src/core/{ => utils}/result.rs | 0 src/core/{ => utils}/result/debug_inspect.rs | 0 src/core/{ => utils}/result/flat_ok.rs | 0 src/core/{ => utils}/result/inspect_log.rs | 0 src/core/{ => utils}/result/into_is_ok.rs | 0 src/core/{ => utils}/result/log_debug_err.rs | 0 src/core/{ => utils}/result/log_err.rs | 0 src/core/{ => utils}/result/map_expect.rs | 0 src/core/{ => utils}/result/not_found.rs | 0 src/core/{ => utils}/result/unwrap_infallible.rs | 0 12 files changed, 2 insertions(+), 3 deletions(-) rename src/core/{ => utils}/result.rs (100%) rename src/core/{ => utils}/result/debug_inspect.rs (100%) rename src/core/{ => utils}/result/flat_ok.rs (100%) rename src/core/{ => utils}/result/inspect_log.rs (100%) rename src/core/{ => utils}/result/into_is_ok.rs (100%) rename 
src/core/{ => utils}/result/log_debug_err.rs (100%) rename src/core/{ => utils}/result/log_err.rs (100%) rename src/core/{ => utils}/result/map_expect.rs (100%) rename src/core/{ => utils}/result/not_found.rs (100%) rename src/core/{ => utils}/result/unwrap_infallible.rs (100%) diff --git a/src/core/mod.rs b/src/core/mod.rs index 491d8b4ce..790525549 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -7,7 +7,6 @@ pub mod log; pub mod metrics; pub mod mods; pub mod pdu; -pub mod result; pub mod server; pub mod utils; @@ -19,9 +18,8 @@ pub use config::Config; pub use error::Error; pub use info::{rustc_flags_capture, version, version::version}; pub use pdu::{PduBuilder, PduCount, PduEvent}; -pub use result::Result; pub use server::Server; -pub use utils::{ctor, dtor, implement}; +pub use utils::{ctor, dtor, implement, result, result::Result}; pub use crate as conduit_core; diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index 96a98537c..3943a8daa 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -10,6 +10,7 @@ pub mod json; pub mod math; pub mod mutex_map; pub mod rand; +pub mod result; pub mod set; pub mod stream; pub mod string; diff --git a/src/core/result.rs b/src/core/utils/result.rs similarity index 100% rename from src/core/result.rs rename to src/core/utils/result.rs diff --git a/src/core/result/debug_inspect.rs b/src/core/utils/result/debug_inspect.rs similarity index 100% rename from src/core/result/debug_inspect.rs rename to src/core/utils/result/debug_inspect.rs diff --git a/src/core/result/flat_ok.rs b/src/core/utils/result/flat_ok.rs similarity index 100% rename from src/core/result/flat_ok.rs rename to src/core/utils/result/flat_ok.rs diff --git a/src/core/result/inspect_log.rs b/src/core/utils/result/inspect_log.rs similarity index 100% rename from src/core/result/inspect_log.rs rename to src/core/utils/result/inspect_log.rs diff --git a/src/core/result/into_is_ok.rs b/src/core/utils/result/into_is_ok.rs similarity index 
100% rename from src/core/result/into_is_ok.rs rename to src/core/utils/result/into_is_ok.rs diff --git a/src/core/result/log_debug_err.rs b/src/core/utils/result/log_debug_err.rs similarity index 100% rename from src/core/result/log_debug_err.rs rename to src/core/utils/result/log_debug_err.rs diff --git a/src/core/result/log_err.rs b/src/core/utils/result/log_err.rs similarity index 100% rename from src/core/result/log_err.rs rename to src/core/utils/result/log_err.rs diff --git a/src/core/result/map_expect.rs b/src/core/utils/result/map_expect.rs similarity index 100% rename from src/core/result/map_expect.rs rename to src/core/utils/result/map_expect.rs diff --git a/src/core/result/not_found.rs b/src/core/utils/result/not_found.rs similarity index 100% rename from src/core/result/not_found.rs rename to src/core/utils/result/not_found.rs diff --git a/src/core/result/unwrap_infallible.rs b/src/core/utils/result/unwrap_infallible.rs similarity index 100% rename from src/core/result/unwrap_infallible.rs rename to src/core/utils/result/unwrap_infallible.rs From aa768b5dec1338a6dbdd1b3a9a00bc6ec9d53090 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 24 Oct 2024 05:03:31 +0000 Subject: [PATCH 110/245] distill active and old keys for federation key/server response Signed-off-by: Jason Volk --- src/api/server/key.rs | 40 +++++++++++++++++++++++++--------- src/service/server_keys/mod.rs | 17 +++++++++++++++ 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/src/api/server/key.rs b/src/api/server/key.rs index 3913ce43f..5284593d2 100644 --- a/src/api/server/key.rs +++ b/src/api/server/key.rs @@ -1,10 +1,14 @@ -use std::{collections::BTreeMap, time::Duration}; +use std::{ + collections::BTreeMap, + mem::take, + time::{Duration, SystemTime}, +}; use axum::{extract::State, response::IntoResponse, Json}; use conduit::{utils::timepoint_from_now, Result}; use ruma::{ api::{ - federation::discovery::{get_server_keys, ServerSigningKeys}, + 
federation::discovery::{get_server_keys, OldVerifyKey, ServerSigningKeys}, OutgoingResponse, }, serde::Raw, @@ -21,21 +25,32 @@ use ruma::{ // signature for the response pub(crate) async fn get_server_keys_route(State(services): State) -> Result { let server_name = services.globals.server_name(); - let verify_keys = services.server_keys.verify_keys_for(server_name).await; + let active_key_id = services.server_keys.active_key_id(); + let mut all_keys = services.server_keys.verify_keys_for(server_name).await; + + let verify_keys = all_keys + .remove_entry(active_key_id) + .expect("active verify_key is missing"); + + let old_verify_keys = all_keys + .into_iter() + .map(|(id, key)| (id, OldVerifyKey::new(expires_ts(), key.key))) + .collect(); + let server_key = ServerSigningKeys { - verify_keys, + verify_keys: [verify_keys].into(), + old_verify_keys, server_name: server_name.to_owned(), valid_until_ts: valid_until_ts(), - old_verify_keys: BTreeMap::new(), signatures: BTreeMap::new(), }; - let response = get_server_keys::v2::Response { - server_key: Raw::new(&server_key)?, - } - .try_into_http_response::>()?; + let server_key = Raw::new(&server_key)?; + let mut response = get_server_keys::v2::Response::new(server_key) + .try_into_http_response::>() + .map(|mut response| take(response.body_mut())) + .and_then(|body| serde_json::from_slice(&body).map_err(Into::into))?; - let mut response = serde_json::from_slice(response.body())?; services.server_keys.sign_json(&mut response)?; Ok(Json(response)) @@ -47,6 +62,11 @@ fn valid_until_ts() -> MilliSecondsSinceUnixEpoch { MilliSecondsSinceUnixEpoch::from_system_time(timepoint).expect("UInt should not overflow") } +fn expires_ts() -> MilliSecondsSinceUnixEpoch { + let timepoint = SystemTime::now(); + MilliSecondsSinceUnixEpoch::from_system_time(timepoint).expect("UInt should not overflow") +} + /// # `GET /_matrix/key/v2/server/{keyId}` /// /// Gets the public signing keys of this server. 
diff --git a/src/service/server_keys/mod.rs b/src/service/server_keys/mod.rs index dc09703ce..dae45a51c 100644 --- a/src/service/server_keys/mod.rs +++ b/src/service/server_keys/mod.rs @@ -44,7 +44,9 @@ pub type PubKeys = PublicKeySet; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { let minimum_valid = Duration::from_secs(3600); + let (keypair, verify_keys) = keypair::init(args.db)?; + debug_assert!(verify_keys.len() == 1, "only one active verify_key supported"); Ok(Arc::new(Self { keypair, @@ -68,6 +70,21 @@ impl crate::Service for Service { #[inline] pub fn keypair(&self) -> &Ed25519KeyPair { &self.keypair } +#[implement(Service)] +#[inline] +pub fn active_key_id(&self) -> &ServerSigningKeyId { self.active_verify_key().0 } + +#[implement(Service)] +#[inline] +pub fn active_verify_key(&self) -> (&ServerSigningKeyId, &VerifyKey) { + debug_assert!(self.verify_keys.len() <= 1, "more than one active verify_key"); + self.verify_keys + .iter() + .next() + .map(|(id, key)| (id.as_ref(), key)) + .expect("missing active verify_key") +} + #[implement(Service)] async fn add_signing_keys(&self, new_keys: ServerSigningKeys) { let origin = &new_keys.server_name; From 89cc865868102697415802e45f4ce19bbaad33d2 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 25 Oct 2024 14:45:22 -0400 Subject: [PATCH 111/245] bump conduwuit to 0.5.0 Signed-off-by: strawberry --- Cargo.lock | 16 ++++++++-------- Cargo.toml | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4ac7cc35f..31339b279 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -592,7 +592,7 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" [[package]] name = "conduit" -version = "0.4.7" +version = "0.5.0" dependencies = [ "clap", "conduit_admin", @@ -621,7 +621,7 @@ dependencies = [ [[package]] name = "conduit_admin" -version = "0.4.7" +version = "0.5.0" dependencies = [ "clap", "conduit_api", @@ -642,7 +642,7 @@ dependencies 
= [ [[package]] name = "conduit_api" -version = "0.4.7" +version = "0.5.0" dependencies = [ "axum", "axum-client-ip", @@ -674,7 +674,7 @@ dependencies = [ [[package]] name = "conduit_core" -version = "0.4.7" +version = "0.5.0" dependencies = [ "argon2", "arrayvec", @@ -725,7 +725,7 @@ dependencies = [ [[package]] name = "conduit_database" -version = "0.4.7" +version = "0.5.0" dependencies = [ "arrayvec", "conduit_core", @@ -741,7 +741,7 @@ dependencies = [ [[package]] name = "conduit_macros" -version = "0.4.7" +version = "0.5.0" dependencies = [ "itertools 0.13.0", "proc-macro2", @@ -751,7 +751,7 @@ dependencies = [ [[package]] name = "conduit_router" -version = "0.4.7" +version = "0.5.0" dependencies = [ "axum", "axum-client-ip", @@ -784,7 +784,7 @@ dependencies = [ [[package]] name = "conduit_service" -version = "0.4.7" +version = "0.5.0" dependencies = [ "async-trait", "base64 0.22.1", diff --git a/Cargo.toml b/Cargo.toml index 966c28183..64cd8ba37 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ license = "Apache-2.0" readme = "README.md" repository = "https://github.com/girlbossceo/conduwuit" rust-version = "1.82.0" -version = "0.4.7" +version = "0.5.0" [workspace.metadata.crane] name = "conduit" From f29879288d00e24ec04d6a42bab6ef91e8bafda7 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 25 Oct 2024 20:47:30 -0400 Subject: [PATCH 112/245] document conduwuit k8s helm chart Signed-off-by: strawberry --- docs/deploying/kubernetes.md | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 docs/deploying/kubernetes.md diff --git a/docs/deploying/kubernetes.md b/docs/deploying/kubernetes.md new file mode 100644 index 000000000..2a1bcb51a --- /dev/null +++ b/docs/deploying/kubernetes.md @@ -0,0 +1,4 @@ +# conduwuit for Kubernetes + +conduwuit doesn't support horizontal scalability or distributed loading natively, however a community maintained Helm Chart is available here to run conduwuit on Kubernetes: + From 
652b04b9b6bc30f55b286645bb8cd706d429056c Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 25 Oct 2024 20:48:14 -0400 Subject: [PATCH 113/245] update conduwuit freebsd docs Signed-off-by: strawberry --- docs/deploying/freebsd.md | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/docs/deploying/freebsd.md b/docs/deploying/freebsd.md index 4ac83515b..65b40204b 100644 --- a/docs/deploying/freebsd.md +++ b/docs/deploying/freebsd.md @@ -1,11 +1,5 @@ # conduwuit for FreeBSD -conduwuit at the moment does not provide FreeBSD builds. Building conduwuit on -FreeBSD requires a specific environment variable to use the system prebuilt -RocksDB library instead of rust-rocksdb / rust-librocksdb-sys which does *not* -work and will cause a build error or coredump. +conduwuit at the moment does not provide FreeBSD builds or have FreeBSD packaging, however conduwuit does build and work on FreeBSD using the system-provided RocksDB. -Use the following environment variable: `ROCKSDB_LIB_DIR=/usr/local/lib` - -Such example commandline with it can be: `ROCKSDB_LIB_DIR=/usr/local/lib cargo -build --release` +Contributions for getting conduwuit packaged are welcome. 
From 2ce91f33afbd08a722684f6d0e3928cc9a497696 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 25 Oct 2024 21:08:34 -0400 Subject: [PATCH 114/245] log method on tracing req spans, fix path sometimes being truncated Signed-off-by: strawberry --- src/api/router/auth.rs | 4 +--- src/router/layers.rs | 20 ++++++++++++++------ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 6b90c5ff9..28d6bc551 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -5,7 +5,6 @@ use axum_extra::{ TypedHeader, }; use conduit::{debug_error, err, warn, Err, Error, Result}; -use http::uri::PathAndQuery; use ruma::{ api::{client::error::ErrorKind, AuthScheme, Metadata}, server_util::authorization::XMatrix, @@ -190,12 +189,11 @@ async fn auth_server(services: &Services, request: &mut Request, body: Option<&C let destination = services.globals.server_name(); let origin = &x_matrix.origin; - #[allow(clippy::or_fun_call)] let signature_uri = request .parts .uri .path_and_query() - .unwrap_or(&PathAndQuery::from_static("/")) + .expect("all requests have a path") .to_string(); let signature: [Member; 1] = [(x_matrix.key.to_string(), Value::String(x_matrix.sig.to_string()))]; diff --git a/src/router/layers.rs b/src/router/layers.rs index a1a70bb86..908105d85 100644 --- a/src/router/layers.rs +++ b/src/router/layers.rs @@ -184,12 +184,20 @@ fn catch_panic(err: Box) -> http::Response(request: &http::Request) -> tracing::Span { - let path = request - .extensions() - .get::() - .map_or_else(|| request.uri().path(), truncated_matched_path); - - tracing::info_span!("router:", %path) + let path = request.extensions().get::().map_or_else( + || { + request + .uri() + .path_and_query() + .expect("all requests have a path") + .as_str() + }, + truncated_matched_path, + ); + + let method = request.method(); + + tracing::info_span!("router:", %method, %path) } fn truncated_matched_path(path: &MatchedPath) -> &str { From 
0efe24a028f5954e9aa4969f533ed89a51115bbc Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 25 Oct 2024 21:13:14 -0400 Subject: [PATCH 115/245] remove spaces from CSP header to save a few bytes Signed-off-by: strawberry --- src/router/layers.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/router/layers.rs b/src/router/layers.rs index 908105d85..fd68cc367 100644 --- a/src/router/layers.rs +++ b/src/router/layers.rs @@ -24,15 +24,15 @@ use tracing::Level; use crate::{request, router}; -const CONDUWUIT_CSP: &[&str] = &[ - "sandbox", +const CONDUWUIT_CSP: &[&str; 5] = &[ "default-src 'none'", "frame-ancestors 'none'", "form-action 'none'", "base-uri 'none'", + "sandbox", ]; -const CONDUWUIT_PERMISSIONS_POLICY: &[&str] = &["interest-cohort=()", "browsing-topics=()"]; +const CONDUWUIT_PERMISSIONS_POLICY: &[&str; 2] = &["interest-cohort=()", "browsing-topics=()"]; pub(crate) fn build(services: &Arc) -> Result<(Router, Guard)> { let server = &services.server; @@ -78,7 +78,7 @@ pub(crate) fn build(services: &Arc) -> Result<(Router, Guard)> { )) .layer(SetResponseHeaderLayer::if_not_present( header::CONTENT_SECURITY_POLICY, - HeaderValue::from_str(&CONDUWUIT_CSP.join("; "))?, + HeaderValue::from_str(&CONDUWUIT_CSP.join(";"))?, )) .layer(cors_layer(server)) .layer(body_limit_layer(server)) From d6991611f0d79d1ad4a1e3cdb5d1372a79b87ac7 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 26 Oct 2024 12:32:47 -0400 Subject: [PATCH 116/245] add `require_auth_for_profile_requests` config option, check endpoint metadata instead of request string Signed-off-by: strawberry --- src/api/router/auth.rs | 36 ++++++++++++++++++++++++++++++------ src/core/config/mod.rs | 11 ++++++++++- 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 28d6bc551..6b1bb1a9f 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -6,7 +6,15 @@ use axum_extra::{ }; use conduit::{debug_error, err, 
warn, Err, Error, Result}; use ruma::{ - api::{client::error::ErrorKind, AuthScheme, Metadata}, + api::{ + client::{ + directory::get_public_rooms, + error::ErrorKind, + profile::{get_avatar_url, get_display_name, get_profile, get_profile_key, get_timezone_key}, + voip::get_turn_server_info, + }, + AuthScheme, IncomingRequest, Metadata, + }, server_util::authorization::XMatrix, CanonicalJsonObject, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, }; @@ -54,9 +62,8 @@ pub(super) async fn auth( }; if metadata.authentication == AuthScheme::None { - match request.parts.uri.path() { - // TODO: can we check this better? - "/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => { + match metadata { + &get_public_rooms::v3::Request::METADATA => { if !services .globals .config @@ -73,6 +80,23 @@ pub(super) async fn auth( } } }, + &get_profile::v3::Request::METADATA + | &get_profile_key::unstable::Request::METADATA + | &get_display_name::v3::Request::METADATA + | &get_avatar_url::v3::Request::METADATA + | &get_timezone_key::unstable::Request::METADATA => { + if services.globals.config.require_auth_for_profile_requests { + match token { + Token::Appservice(_) | Token::User(_) => { + // we should have validated the token above + // already + }, + Token::None | Token::Invalid => { + return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing or invalid access token.")); + }, + } + } + }, _ => {}, }; } @@ -107,9 +131,9 @@ pub(super) async fn auth( appservice_info: Some(*info), }) }, - (AuthScheme::AccessToken, Token::None) => match request.parts.uri.path() { + (AuthScheme::AccessToken, Token::None) => match metadata { // TODO: can we check this better? 
- "/_matrix/client/v3/voip/turnServer" | "/_matrix/client/r0/voip/turnServer" => { + &get_turn_server_info::v3::Request::METADATA => { if services.globals.config.turn_allow_guests { Ok(Auth { origin: None, diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index ff2144200..04e44fd76 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -58,7 +58,6 @@ pub struct Config { /// YOU NEED TO EDIT THIS pub server_name: OwnedServerName, - /// Database backend: Only rocksdb is supported. /// default address (IPv4 or IPv6) conduwuit will listen on. Generally you /// want this to be localhost (127.0.0.1 / ::1). If you are using Docker or /// a container NAT networking setup, you likely need this to be 0.0.0.0. @@ -94,6 +93,8 @@ pub struct Config { #[serde(default = "default_unix_socket_perms")] pub unix_socket_perms: u32, + /// Database backend: Only rocksdb is supported. + /// /// default: rocksdb #[serde(default = "default_database_backend")] pub database_backend: String, @@ -406,6 +407,14 @@ pub struct Config { #[serde(default)] pub federation_loopback: bool, + /// Set this to true to require authentication on the normally + /// unauthenticated profile retrieval endpoints (GET) + /// "/_matrix/client/v3/profile/{userId}". + /// + /// This can prevent profile scraping. + #[serde(default)] + pub require_auth_for_profile_requests: bool, + /// Set this to true to allow your server's public room directory to be /// federated. 
Set this to false to protect against /publicRooms spiders, /// but will forbid external users from viewing your server's public room From 60d84195c51c523b965c17d75ebca861290260e5 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 26 Oct 2024 17:26:50 -0400 Subject: [PATCH 117/245] implement MSC4210, bump ruwuma Signed-off-by: strawberry --- Cargo.lock | 26 +++---- Cargo.toml | 3 +- src/admin/debug/commands.rs | 1 + src/api/client/push.rs | 100 +++++++++++++------------ src/service/rooms/event_handler/mod.rs | 1 + 5 files changed, 69 insertions(+), 62 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 31339b279..c64d3cc67 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2976,7 +2976,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" +source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" dependencies = [ "assign", "js_int", @@ -2998,7 +2998,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" +source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" dependencies = [ "js_int", "ruma-common", @@ -3010,7 +3010,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" +source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" dependencies = [ "as_variant", "assign", @@ -3033,7 +3033,7 @@ dependencies = [ [[package]] name = "ruma-common" version = 
"0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" +source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" dependencies = [ "as_variant", "base64 0.22.1", @@ -3063,7 +3063,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" +source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" dependencies = [ "as_variant", "indexmap 2.6.0", @@ -3087,7 +3087,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" +source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" dependencies = [ "bytes", "http", @@ -3105,7 +3105,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" +source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" dependencies = [ "js_int", "thiserror", @@ -3114,7 +3114,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" +source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" dependencies = [ 
"js_int", "ruma-common", @@ -3124,7 +3124,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" +source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" dependencies = [ "cfg-if", "once_cell", @@ -3140,7 +3140,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" +source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" dependencies = [ "js_int", "ruma-common", @@ -3152,7 +3152,7 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.3.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" +source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" dependencies = [ "headers", "http", @@ -3165,7 +3165,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" +source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -3181,7 +3181,7 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.11.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73#d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" +source = 
"git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" dependencies = [ "futures-util", "itertools 0.13.0", diff --git a/Cargo.toml b/Cargo.toml index 64cd8ba37..73f16daf0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -315,7 +315,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = "https://github.com/girlbossceo/ruwuma" #branch = "conduwuit-changes" -rev = "d7baeb7e5c3ae28e79ad3fe81c5e8b207a26cc73" +rev = "39c1addd37a4eed612ac1135edc2cccd9d331d5e" features = [ "compat", "rand", @@ -346,6 +346,7 @@ features = [ "unstable-msc4121", "unstable-msc4125", "unstable-msc4186", + "unstable-msc4210", # remove legacy mentions "unstable-extensible-events", ] diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 7fe8addfa..0fd3c91bf 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -203,6 +203,7 @@ pub(super) async fn get_remote_pdu( &server, ruma::api::federation::event::get_event::v1::Request { event_id: event_id.clone().into(), + include_unredacted_content: None, }, ) .await diff --git a/src/api/client/push.rs b/src/api/client/push.rs index 103c0c5e1..de280b32f 100644 --- a/src/api/client/push.rs +++ b/src/api/client/push.rs @@ -1,18 +1,18 @@ use axum::extract::State; -use conduit::err; +use conduit::{err, Err}; use ruma::{ api::client::{ error::ErrorKind, push::{ delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled, get_pushrules_all, - set_pusher, set_pushrule, set_pushrule_actions, set_pushrule_enabled, RuleScope, + set_pusher, set_pushrule, set_pushrule_actions, set_pushrule_enabled, }, }, events::{ push_rules::{PushRulesEvent, PushRulesEventContent}, GlobalAccountDataEventType, }, - push::{InsertPushRuleError, RemovePushRuleError, Ruleset}, + push::{InsertPushRuleError, PredefinedContentRuleId, PredefinedOverrideRuleId, RemovePushRuleError, Ruleset}, CanonicalJsonObject, CanonicalJsonValue, }; use 
service::Services; @@ -43,7 +43,24 @@ pub(crate) async fn get_pushrules_all_route( let account_data_content = serde_json::from_value::(content_value.into()) .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; - let global_ruleset: Ruleset = account_data_content.global; + let mut global_ruleset = account_data_content.global; + + // remove old deprecated mentions push rules as per MSC4210 + #[allow(deprecated)] + { + use ruma::push::RuleKind::*; + + global_ruleset + .remove(Override, PredefinedOverrideRuleId::ContainsDisplayName) + .ok(); + global_ruleset + .remove(Override, PredefinedOverrideRuleId::RoomNotif) + .ok(); + + global_ruleset + .remove(Content, PredefinedContentRuleId::ContainsUserName) + .ok(); + }; Ok(get_pushrules_all::v3::Response { global: global_ruleset, @@ -58,6 +75,15 @@ pub(crate) async fn get_pushrule_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + // remove old deprecated mentions push rules as per MSC4210 + #[allow(deprecated)] + if body.rule_id.as_str() == PredefinedContentRuleId::ContainsUserName.as_str() + || body.rule_id.as_str() == PredefinedOverrideRuleId::ContainsDisplayName.as_str() + || body.rule_id.as_str() == PredefinedOverrideRuleId::RoomNotif.as_str() + { + return Err!(Request(NotFound("Push rule not found."))); + } + let event: PushRulesEvent = services .account_data .get_global(sender_user, GlobalAccountDataEventType::PushRules) @@ -79,7 +105,7 @@ pub(crate) async fn get_pushrule_route( } } -/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}` +/// # `PUT /_matrix/client/r0/pushrules/global/{kind}/{ruleId}` /// /// Creates a single specified push rule for this user. 
pub(crate) async fn set_pushrule_route( @@ -88,13 +114,6 @@ pub(crate) async fn set_pushrule_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let body = body.body; - if body.scope != RuleScope::Global { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Scopes other than 'global' are not supported.", - )); - } - let mut account_data: PushRulesEvent = services .account_data .get_global(sender_user, GlobalAccountDataEventType::PushRules) @@ -145,7 +164,7 @@ pub(crate) async fn set_pushrule_route( Ok(set_pushrule::v3::Response {}) } -/// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions` +/// # `GET /_matrix/client/r0/pushrules/global/{kind}/{ruleId}/actions` /// /// Gets the actions of a single specified push rule for this user. pub(crate) async fn get_pushrule_actions_route( @@ -153,11 +172,13 @@ pub(crate) async fn get_pushrule_actions_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if body.scope != RuleScope::Global { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Scopes other than 'global' are not supported.", - )); + // remove old deprecated mentions push rules as per MSC4210 + #[allow(deprecated)] + if body.rule_id.as_str() == PredefinedContentRuleId::ContainsUserName.as_str() + || body.rule_id.as_str() == PredefinedOverrideRuleId::ContainsDisplayName.as_str() + || body.rule_id.as_str() == PredefinedOverrideRuleId::RoomNotif.as_str() + { + return Err!(Request(NotFound("Push rule not found."))); } let event: PushRulesEvent = services @@ -178,7 +199,7 @@ pub(crate) async fn get_pushrule_actions_route( }) } -/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions` +/// # `PUT /_matrix/client/r0/pushrules/global/{kind}/{ruleId}/actions` /// /// Sets the actions of a single specified push rule for this user. 
pub(crate) async fn set_pushrule_actions_route( @@ -186,13 +207,6 @@ pub(crate) async fn set_pushrule_actions_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if body.scope != RuleScope::Global { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Scopes other than 'global' are not supported.", - )); - } - let mut account_data: PushRulesEvent = services .account_data .get_global(sender_user, GlobalAccountDataEventType::PushRules) @@ -221,7 +235,7 @@ pub(crate) async fn set_pushrule_actions_route( Ok(set_pushrule_actions::v3::Response {}) } -/// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled` +/// # `GET /_matrix/client/r0/pushrules/global/{kind}/{ruleId}/enabled` /// /// Gets the enabled status of a single specified push rule for this user. pub(crate) async fn get_pushrule_enabled_route( @@ -229,11 +243,15 @@ pub(crate) async fn get_pushrule_enabled_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if body.scope != RuleScope::Global { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Scopes other than 'global' are not supported.", - )); + // remove old deprecated mentions push rules as per MSC4210 + #[allow(deprecated)] + if body.rule_id.as_str() == PredefinedContentRuleId::ContainsUserName.as_str() + || body.rule_id.as_str() == PredefinedOverrideRuleId::ContainsDisplayName.as_str() + || body.rule_id.as_str() == PredefinedOverrideRuleId::RoomNotif.as_str() + { + return Ok(get_pushrule_enabled::v3::Response { + enabled: false, + }); } let event: PushRulesEvent = services @@ -254,7 +272,7 @@ pub(crate) async fn get_pushrule_enabled_route( }) } -/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled` +/// # `PUT /_matrix/client/r0/pushrules/global/{kind}/{ruleId}/enabled` /// /// Sets the enabled status of a single specified push rule for this user. 
pub(crate) async fn set_pushrule_enabled_route( @@ -262,13 +280,6 @@ pub(crate) async fn set_pushrule_enabled_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if body.scope != RuleScope::Global { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Scopes other than 'global' are not supported.", - )); - } - let mut account_data: PushRulesEvent = services .account_data .get_global(sender_user, GlobalAccountDataEventType::PushRules) @@ -297,7 +308,7 @@ pub(crate) async fn set_pushrule_enabled_route( Ok(set_pushrule_enabled::v3::Response {}) } -/// # `DELETE /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}` +/// # `DELETE /_matrix/client/r0/pushrules/global/{kind}/{ruleId}` /// /// Deletes a single specified push rule for this user. pub(crate) async fn delete_pushrule_route( @@ -305,13 +316,6 @@ pub(crate) async fn delete_pushrule_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if body.scope != RuleScope::Global { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Scopes other than 'global' are not supported.", - )); - } - let mut account_data: PushRulesEvent = services .account_data .get_global(sender_user, GlobalAccountDataEventType::PushRules) diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 0b2bbf731..026c5a4c0 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1159,6 +1159,7 @@ impl Service { origin, get_event::v1::Request { event_id: (*next_id).to_owned(), + include_unredacted_content: None, }, ) .await From b921983a795f042dac0f348f1a832c73bd44de7f Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 26 Oct 2024 17:39:27 -0400 Subject: [PATCH 118/245] send room alias on pusher notification Signed-off-by: strawberry --- src/service/pusher/mod.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/service/pusher/mod.rs 
b/src/service/pusher/mod.rs index af15e332d..2b90319e9 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -332,6 +332,13 @@ impl Service { .await .ok(); + notifi.room_alias = self + .services + .state_accessor + .get_canonical_alias(&event.room_id) + .await + .ok(); + self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) .await?; } From 49343281d477cf414cfec737a00c150d5db34ba3 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 24 Oct 2024 06:01:53 +0000 Subject: [PATCH 119/245] additional bool extensions Signed-off-by: Jason Volk --- src/core/utils/bool.rs | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/core/utils/bool.rs b/src/core/utils/bool.rs index d7ce78fe3..d5fa85aa8 100644 --- a/src/core/utils/bool.rs +++ b/src/core/utils/bool.rs @@ -2,12 +2,49 @@ /// Boolean extensions and chain.starters pub trait BoolExt { + fn map T>(self, f: F) -> T + where + Self: Sized; + + fn map_ok_or T>(self, err: E, f: F) -> Result; + + fn map_or T>(self, err: T, f: F) -> T; + + fn map_or_else T>(self, err: F, f: F) -> T; + + fn ok_or(self, err: E) -> Result<(), E>; + + fn ok_or_else E>(self, err: F) -> Result<(), E>; + fn or T>(self, f: F) -> Option; fn or_some(self, t: T) -> Option; } impl BoolExt for bool { + #[inline] + fn map T>(self, f: F) -> T + where + Self: Sized, + { + f(self) + } + + #[inline] + fn map_ok_or T>(self, err: E, f: F) -> Result { self.ok_or(err).map(|()| f()) } + + #[inline] + fn map_or T>(self, err: T, f: F) -> T { self.then(f).unwrap_or(err) } + + #[inline] + fn map_or_else T>(self, err: F, f: F) -> T { self.then(f).unwrap_or_else(err) } + + #[inline] + fn ok_or(self, err: E) -> Result<(), E> { self.then_some(()).ok_or(err) } + + #[inline] + fn ok_or_else E>(self, err: F) -> Result<(), E> { self.then_some(()).ok_or_else(err) } + #[inline] fn or T>(self, f: F) -> Option { (!self).then(f) } From efb28c1a9944840143af37bff65bd475f38df717 Mon Sep 17 00:00:00 2001 
From: Jason Volk Date: Thu, 24 Oct 2024 06:03:45 +0000 Subject: [PATCH 120/245] add a Map::contains suite to db Signed-off-by: Jason Volk --- src/database/map.rs | 1 + src/database/map/contains.rs | 88 ++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) create mode 100644 src/database/map/contains.rs diff --git a/src/database/map.rs b/src/database/map.rs index cac20d6a6..d6b8bf38c 100644 --- a/src/database/map.rs +++ b/src/database/map.rs @@ -1,3 +1,4 @@ +mod contains; mod count; mod get; mod insert; diff --git a/src/database/map/contains.rs b/src/database/map/contains.rs new file mode 100644 index 000000000..a98fe7c53 --- /dev/null +++ b/src/database/map/contains.rs @@ -0,0 +1,88 @@ +use std::{convert::AsRef, fmt::Debug, future::Future, io::Write}; + +use arrayvec::ArrayVec; +use conduit::{implement, utils::TryFutureExtExt, Err, Result}; +use futures::future::ready; +use serde::Serialize; + +use crate::{ser, util}; + +/// Returns true if the map contains the key. +/// - key is serialized into allocated buffer +/// - harder errors may not be reported +#[implement(super::Map)] +pub fn contains(&self, key: &K) -> impl Future + Send +where + K: Serialize + ?Sized + Debug, +{ + let mut buf = Vec::::with_capacity(64); + self.bcontains(key, &mut buf) +} + +/// Returns true if the map contains the key. +/// - key is serialized into stack-buffer +/// - harder errors will panic +#[implement(super::Map)] +pub fn acontains(&self, key: &K) -> impl Future + Send +where + K: Serialize + ?Sized + Debug, +{ + let mut buf = ArrayVec::::new(); + self.bcontains(key, &mut buf) +} + +/// Returns true if the map contains the key. 
+/// - key is serialized into provided buffer +/// - harder errors will panic +#[implement(super::Map)] +#[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] +pub fn bcontains(&self, key: &K, buf: &mut B) -> impl Future + Send +where + K: Serialize + ?Sized + Debug, + B: Write + AsRef<[u8]>, +{ + let key = ser::serialize(buf, key).expect("failed to serialize query key"); + self.exists(key).is_ok() +} + +/// Returns Ok if the map contains the key. +/// - key is raw +#[implement(super::Map)] +pub fn exists(&self, key: &K) -> impl Future> + Send +where + K: AsRef<[u8]> + ?Sized + Debug, +{ + ready(self.exists_blocking(key)) +} + +/// Returns Ok if the map contains the key; NotFound otherwise. Harder errors +/// may not always be reported properly. +#[implement(super::Map)] +#[tracing::instrument(skip(self, key), fields(%self), level = "trace")] +pub fn exists_blocking(&self, key: &K) -> Result<()> +where + K: AsRef<[u8]> + ?Sized + Debug, +{ + if self.maybe_exists_blocking(key) + && self + .db + .db + .get_pinned_cf_opt(&self.cf(), key, &self.read_options) + .map_err(util::map_err)? 
+ .is_some() + { + Ok(()) + } else { + Err!(Request(NotFound("Not found in database"))) + } +} + +#[implement(super::Map)] +fn maybe_exists_blocking(&self, key: &K) -> bool +where + K: AsRef<[u8]> + ?Sized, +{ + self.db + .db + .key_may_exist_cf_opt(&self.cf(), key, &self.read_options) +} From 9438dc89e612ada5e5e44b48315877055498313b Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 24 Oct 2024 10:58:08 +0000 Subject: [PATCH 121/245] merge and resplit/cleanup appservice service Signed-off-by: Jason Volk --- src/admin/query/appservice.rs | 6 +- src/service/appservice/data.rs | 50 ------ src/service/appservice/mod.rs | 173 ++++++-------------- src/service/appservice/namespace_regex.rs | 70 ++++++++ src/service/appservice/registration_info.rs | 39 +++++ 5 files changed, 157 insertions(+), 181 deletions(-) delete mode 100644 src/service/appservice/data.rs create mode 100644 src/service/appservice/namespace_regex.rs create mode 100644 src/service/appservice/registration_info.rs diff --git a/src/admin/query/appservice.rs b/src/admin/query/appservice.rs index 4b97ef4eb..02e89e7a1 100644 --- a/src/admin/query/appservice.rs +++ b/src/admin/query/appservice.rs @@ -26,11 +26,7 @@ pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_> appservice_id, } => { let timer = tokio::time::Instant::now(); - let results = services - .appservice - .db - .get_registration(appservice_id.as_ref()) - .await; + let results = services.appservice.get_registration(&appservice_id).await; let query_time = timer.elapsed(); diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs deleted file mode 100644 index 8fb7d9582..000000000 --- a/src/service/appservice/data.rs +++ /dev/null @@ -1,50 +0,0 @@ -use std::sync::Arc; - -use conduit::{err, utils::stream::TryIgnore, Result}; -use database::{Database, Map}; -use futures::Stream; -use ruma::api::appservice::Registration; - -pub struct Data { - id_appserviceregistrations: Arc, -} - -impl Data { - 
pub(super) fn new(db: &Arc) -> Self { - Self { - id_appserviceregistrations: db["id_appserviceregistrations"].clone(), - } - } - - /// Registers an appservice and returns the ID to the caller - pub(super) fn register_appservice(&self, yaml: &Registration) -> Result { - let id = yaml.id.as_str(); - self.id_appserviceregistrations - .insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes()); - - Ok(id.to_owned()) - } - - /// Remove an appservice registration - /// - /// # Arguments - /// - /// * `service_name` - the name you send to register the service previously - pub(super) fn unregister_appservice(&self, service_name: &str) -> Result<()> { - self.id_appserviceregistrations - .remove(service_name.as_bytes()); - Ok(()) - } - - pub async fn get_registration(&self, id: &str) -> Result { - self.id_appserviceregistrations - .get(id) - .await - .and_then(|ref bytes| serde_yaml::from_slice(bytes).map_err(Into::into)) - .map_err(|e| err!(Database("Invalid appservice {id:?} registration: {e:?}"))) - } - - pub(super) fn iter_ids(&self) -> impl Stream + Send + '_ { - self.id_appserviceregistrations.keys().ignore_err() - } -} diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 7e2dc7387..1617e6e6e 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -1,147 +1,49 @@ -mod data; +mod namespace_regex; +mod registration_info; use std::{collections::BTreeMap, sync::Arc}; use async_trait::async_trait; -use conduit::{err, Result}; -use data::Data; +use conduit::{err, utils::stream::TryIgnore, Result}; +use database::Map; use futures::{Future, StreamExt, TryStreamExt}; -use regex::RegexSet; -use ruma::{ - api::appservice::{Namespace, Registration}, - RoomAliasId, RoomId, UserId, -}; +use ruma::{api::appservice::Registration, RoomAliasId, RoomId, UserId}; use tokio::sync::RwLock; +pub use self::{namespace_regex::NamespaceRegex, registration_info::RegistrationInfo}; use crate::{sending, Dep}; -/// Compiled regular 
expressions for a namespace -#[derive(Clone, Debug)] -pub struct NamespaceRegex { - pub exclusive: Option, - pub non_exclusive: Option, -} - -impl NamespaceRegex { - /// Checks if this namespace has rights to a namespace - #[inline] - #[must_use] - pub fn is_match(&self, heystack: &str) -> bool { - if self.is_exclusive_match(heystack) { - return true; - } - - if let Some(non_exclusive) = &self.non_exclusive { - if non_exclusive.is_match(heystack) { - return true; - } - } - false - } - - /// Checks if this namespace has exlusive rights to a namespace - #[inline] - #[must_use] - pub fn is_exclusive_match(&self, heystack: &str) -> bool { - if let Some(exclusive) = &self.exclusive { - if exclusive.is_match(heystack) { - return true; - } - } - false - } -} - -impl RegistrationInfo { - #[must_use] - pub fn is_user_match(&self, user_id: &UserId) -> bool { - self.users.is_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() - } - - #[inline] - #[must_use] - pub fn is_exclusive_user_match(&self, user_id: &UserId) -> bool { - self.users.is_exclusive_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() - } -} - -impl TryFrom> for NamespaceRegex { - type Error = regex::Error; - - fn try_from(value: Vec) -> Result { - let mut exclusive = Vec::with_capacity(value.len()); - let mut non_exclusive = Vec::with_capacity(value.len()); - - for namespace in value { - if namespace.exclusive { - exclusive.push(namespace.regex); - } else { - non_exclusive.push(namespace.regex); - } - } - - Ok(Self { - exclusive: if exclusive.is_empty() { - None - } else { - Some(RegexSet::new(exclusive)?) - }, - non_exclusive: if non_exclusive.is_empty() { - None - } else { - Some(RegexSet::new(non_exclusive)?) - }, - }) - } -} - -/// Appservice registration combined with its compiled regular expressions. 
-#[derive(Clone, Debug)] -pub struct RegistrationInfo { - pub registration: Registration, - pub users: NamespaceRegex, - pub aliases: NamespaceRegex, - pub rooms: NamespaceRegex, -} - -impl TryFrom for RegistrationInfo { - type Error = regex::Error; - - fn try_from(value: Registration) -> Result { - Ok(Self { - users: value.namespaces.users.clone().try_into()?, - aliases: value.namespaces.aliases.clone().try_into()?, - rooms: value.namespaces.rooms.clone().try_into()?, - registration: value, - }) - } -} - pub struct Service { - pub db: Data, - services: Services, registration_info: RwLock>, + services: Services, + db: Data, } struct Services { sending: Dep, } +struct Data { + id_appserviceregistrations: Arc, +} + #[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + registration_info: RwLock::new(BTreeMap::new()), services: Services { sending: args.depend::("sending"), }, - registration_info: RwLock::new(BTreeMap::new()), + db: Data { + id_appserviceregistrations: args.db["id_appserviceregistrations"].clone(), + }, })) } async fn worker(self: Arc) -> Result<()> { // Inserting registrations into cache - for appservice in iter_ids(&self.db).await? { + for appservice in self.iter_db_ids().await? 
{ self.registration_info.write().await.insert( appservice.0, appservice @@ -158,9 +60,6 @@ impl crate::Service for Service { } impl Service { - #[inline] - pub async fn all(&self) -> Result> { iter_ids(&self.db).await } - /// Registers an appservice and returns the ID to the caller pub async fn register_appservice(&self, yaml: Registration) -> Result { //TODO: Check for collisions between exclusive appservice namespaces @@ -169,7 +68,11 @@ impl Service { .await .insert(yaml.id.clone(), yaml.clone().try_into()?); - self.db.register_appservice(&yaml) + let id = yaml.id.as_str(); + let yaml = serde_yaml::to_string(&yaml)?; + self.db.id_appserviceregistrations.insert(id, yaml); + + Ok(id.to_owned()) } /// Remove an appservice registration @@ -186,7 +89,7 @@ impl Service { .ok_or(err!("Appservice not found"))?; // remove the appservice from the database - self.db.unregister_appservice(service_name)?; + self.db.id_appserviceregistrations.remove(service_name); // deletes all active requests for the appservice if there are any so we stop // sending to the URL @@ -254,11 +157,29 @@ impl Service { pub fn read(&self) -> impl Future>> { self.registration_info.read() } -} -async fn iter_ids(db: &Data) -> Result> { - db.iter_ids() - .then(|id| async move { Ok((id.clone(), db.get_registration(&id).await?)) }) - .try_collect() - .await + #[inline] + pub async fn all(&self) -> Result> { self.iter_db_ids().await } + + pub async fn get_db_registration(&self, id: &str) -> Result { + self.db + .id_appserviceregistrations + .get(id) + .await + .and_then(|ref bytes| serde_yaml::from_slice(bytes).map_err(Into::into)) + .map_err(|e| err!(Database("Invalid appservice {id:?} registration: {e:?}"))) + } + + async fn iter_db_ids(&self) -> Result> { + self.db + .id_appserviceregistrations + .keys() + .ignore_err() + .then(|id: String| async move { + let reg = self.get_db_registration(&id).await?; + Ok((id, reg)) + }) + .try_collect() + .await + } } diff --git 
a/src/service/appservice/namespace_regex.rs b/src/service/appservice/namespace_regex.rs new file mode 100644 index 000000000..3529fc0ef --- /dev/null +++ b/src/service/appservice/namespace_regex.rs @@ -0,0 +1,70 @@ +use conduit::Result; +use regex::RegexSet; +use ruma::api::appservice::Namespace; + +/// Compiled regular expressions for a namespace +#[derive(Clone, Debug)] +pub struct NamespaceRegex { + pub exclusive: Option<RegexSet>, + pub non_exclusive: Option<RegexSet>, +} + +impl NamespaceRegex { + /// Checks if this namespace has rights to a namespace + #[inline] + #[must_use] + pub fn is_match(&self, heystack: &str) -> bool { + if self.is_exclusive_match(heystack) { + return true; + } + + if let Some(non_exclusive) = &self.non_exclusive { + if non_exclusive.is_match(heystack) { + return true; + } + } + false + } + + /// Checks if this namespace has exclusive rights to a namespace + #[inline] + #[must_use] + pub fn is_exclusive_match(&self, heystack: &str) -> bool { + if let Some(exclusive) = &self.exclusive { + if exclusive.is_match(heystack) { + return true; + } + } + false + } +} + +impl TryFrom<Vec<Namespace>> for NamespaceRegex { + type Error = regex::Error; + + fn try_from(value: Vec<Namespace>) -> Result<Self, regex::Error> { + let mut exclusive = Vec::with_capacity(value.len()); + let mut non_exclusive = Vec::with_capacity(value.len()); + + for namespace in value { + if namespace.exclusive { + exclusive.push(namespace.regex); + } else { + non_exclusive.push(namespace.regex); + } + } + + Ok(Self { + exclusive: if exclusive.is_empty() { + None + } else { + Some(RegexSet::new(exclusive)?) + }, + non_exclusive: if non_exclusive.is_empty() { + None + } else { + Some(RegexSet::new(non_exclusive)?)
+ }, + }) + } +} diff --git a/src/service/appservice/registration_info.rs b/src/service/appservice/registration_info.rs new file mode 100644 index 000000000..2c8595b1b --- /dev/null +++ b/src/service/appservice/registration_info.rs @@ -0,0 +1,39 @@ +use conduit::Result; +use ruma::{api::appservice::Registration, UserId}; + +use super::NamespaceRegex; + +/// Appservice registration combined with its compiled regular expressions. +#[derive(Clone, Debug)] +pub struct RegistrationInfo { + pub registration: Registration, + pub users: NamespaceRegex, + pub aliases: NamespaceRegex, + pub rooms: NamespaceRegex, +} + +impl RegistrationInfo { + #[must_use] + pub fn is_user_match(&self, user_id: &UserId) -> bool { + self.users.is_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() + } + + #[inline] + #[must_use] + pub fn is_exclusive_user_match(&self, user_id: &UserId) -> bool { + self.users.is_exclusive_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() + } +} + +impl TryFrom for RegistrationInfo { + type Error = regex::Error; + + fn try_from(value: Registration) -> Result { + Ok(Self { + users: value.namespaces.users.clone().try_into()?, + aliases: value.namespaces.aliases.clone().try_into()?, + rooms: value.namespaces.rooms.clone().try_into()?, + registration: value, + }) + } +} From 0e616f1d1267481ed97e9adc6d779d29fcf9ade2 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 24 Oct 2024 11:16:44 +0000 Subject: [PATCH 122/245] add event macro log wrapper suite Signed-off-by: Jason Volk --- src/core/log/mod.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/core/log/mod.rs b/src/core/log/mod.rs index 04d250a6d..1cba236f0 100644 --- a/src/core/log/mod.rs +++ b/src/core/log/mod.rs @@ -27,6 +27,11 @@ pub struct Log { // necessary but discouraged. Remember debug_ log macros are also exported to // the crate namespace like these. +#[macro_export] +macro_rules! 
event { + ( $level:expr, $($x:tt)+ ) => { ::tracing::event!( $level, $($x)+ ) } +} + #[macro_export] macro_rules! error { ( $($x:tt)+ ) => { ::tracing::error!( $($x)+ ) } From e175b7d28dffbea663c25e66babd3184e7fc5b1f Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 24 Oct 2024 11:24:03 +0000 Subject: [PATCH 123/245] slightly cleanup prev_event eval loop Signed-off-by: Jason Volk --- src/service/rooms/event_handler/mod.rs | 41 +++++++++++++------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 026c5a4c0..ec04e748e 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -206,7 +206,7 @@ impl Service { debug!(events = ?sorted_prev_events, "Got previous events"); for prev_id in sorted_prev_events { self.services.server.check_running()?; - match self + if let Err(e) = self .handle_prev_pdu( origin, event_id, @@ -218,25 +218,26 @@ impl Service { ) .await { - Ok(()) => continue, - Err(e) => { - warn!("Prev event {prev_id} failed: {e}"); - match self - .services - .globals - .bad_event_ratelimiter - .write() - .expect("locked") - .entry((*prev_id).to_owned()) - { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - }, - hash_map::Entry::Occupied(mut e) => { - *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)); - }, - }; - }, + use hash_map::Entry; + + let now = Instant::now(); + warn!("Prev event {prev_id} failed: {e}"); + + match self + .services + .globals + .bad_event_ratelimiter + .write() + .expect("locked") + .entry(prev_id.into()) + { + Entry::Vacant(e) => { + e.insert((now, 1)); + }, + Entry::Occupied(mut e) => { + *e.get_mut() = (now, e.get().1.saturating_add(1)); + }, + }; } } From 60cc07134f3d80f0ba25d4bc1b6736c30494f947 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 24 Oct 2024 11:24:37 +0000 Subject: [PATCH 124/245] log error for auth_chain corruption immediately 
Signed-off-by: Jason Volk --- src/service/rooms/auth_chain/mod.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index f3861ca3f..1387bc7d7 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -167,10 +167,12 @@ impl Service { Err(e) => debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events"), Ok(pdu) => { if pdu.room_id != room_id { - return Err!(Request(Forbidden( - "auth event {event_id:?} for incorrect room {} which is not {room_id}", - pdu.room_id, - ))); + return Err!(Request(Forbidden(error!( + ?event_id, + ?room_id, + wrong_room_id = ?pdu.room_id, + "auth event for incorrect room" + )))); } for auth_event in &pdu.auth_events { From ee92a33a4de8db924ee5e203f5b3c64dade8dcc6 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 24 Oct 2024 12:03:56 +0000 Subject: [PATCH 125/245] add some accessors to Ar for common patterns Signed-off-by: Jason Volk --- src/api/client/backup.rs | 107 ++++++++++++--------------- src/api/router/args.rs | 18 +++-- src/api/router/handler.rs | 2 +- src/api/server/backfill.rs | 7 +- src/api/server/event.rs | 6 +- src/api/server/event_auth.rs | 6 +- src/api/server/get_missing_events.rs | 8 +- src/api/server/hierarchy.rs | 4 +- src/api/server/invite.rs | 9 +-- src/api/server/make_join.rs | 13 ++-- src/api/server/make_leave.rs | 5 +- src/api/server/send.rs | 12 ++- src/api/server/send_join.rs | 19 ++--- src/api/server/send_leave.rs | 22 +----- src/api/server/state.rs | 6 +- src/api/server/state_ids.rs | 6 +- src/api/server/user.rs | 6 +- 17 files changed, 109 insertions(+), 147 deletions(-) diff --git a/src/api/client/backup.rs b/src/api/client/backup.rs index d52da80a2..f435e0869 100644 --- a/src/api/client/backup.rs +++ b/src/api/client/backup.rs @@ -18,10 +18,9 @@ use crate::{Result, Ruma}; pub(crate) async fn create_backup_version_route( State(services): State, body: Ruma, 
) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let version = services .key_backups - .create_backup(sender_user, &body.algorithm)?; + .create_backup(body.sender_user(), &body.algorithm)?; Ok(create_backup_version::v3::Response { version, @@ -35,10 +34,9 @@ pub(crate) async fn create_backup_version_route( pub(crate) async fn update_backup_version_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); services .key_backups - .update_backup(sender_user, &body.version, &body.algorithm) + .update_backup(body.sender_user(), &body.version, &body.algorithm) .await?; Ok(update_backup_version::v3::Response {}) @@ -50,19 +48,25 @@ pub(crate) async fn update_backup_version_route( pub(crate) async fn get_latest_backup_info_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let (version, algorithm) = services .key_backups - .get_latest_backup(sender_user) + .get_latest_backup(body.sender_user()) .await .map_err(|_| err!(Request(NotFound("Key backup does not exist."))))?; Ok(get_latest_backup_info::v3::Response { algorithm, - count: (UInt::try_from(services.key_backups.count_keys(sender_user, &version).await) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &version).await, + count: (UInt::try_from( + services + .key_backups + .count_keys(body.sender_user(), &version) + .await, + ) + .expect("user backup keys count should not be that high")), + etag: services + .key_backups + .get_etag(body.sender_user(), &version) + .await, version, }) } @@ -73,10 +77,9 @@ pub(crate) async fn get_latest_backup_info_route( pub(crate) async fn get_backup_info_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let algorithm = services 
.key_backups - .get_backup(sender_user, &body.version) + .get_backup(body.sender_user(), &body.version) .await .map_err(|_| err!(Request(NotFound("Key backup does not exist at version {:?}", body.version))))?; @@ -84,12 +87,12 @@ pub(crate) async fn get_backup_info_route( algorithm, count: services .key_backups - .count_keys(sender_user, &body.version) + .count_keys(body.sender_user(), &body.version) .await .try_into()?, etag: services .key_backups - .get_etag(sender_user, &body.version) + .get_etag(body.sender_user(), &body.version) .await, version: body.version.clone(), }) @@ -104,11 +107,9 @@ pub(crate) async fn get_backup_info_route( pub(crate) async fn delete_backup_version_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services .key_backups - .delete_backup(sender_user, &body.version) + .delete_backup(body.sender_user(), &body.version) .await; Ok(delete_backup_version::v3::Response {}) @@ -125,11 +126,9 @@ pub(crate) async fn delete_backup_version_route( pub(crate) async fn add_backup_keys_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if services .key_backups - .get_latest_backup_version(sender_user) + .get_latest_backup_version(body.sender_user()) .await .is_ok_and(|version| version != body.version) { @@ -142,7 +141,7 @@ pub(crate) async fn add_backup_keys_route( for (session_id, key_data) in &room.sessions { services .key_backups - .add_key(sender_user, &body.version, room_id, session_id, key_data) + .add_key(body.sender_user(), &body.version, room_id, session_id, key_data) .await?; } } @@ -150,12 +149,12 @@ pub(crate) async fn add_backup_keys_route( Ok(add_backup_keys::v3::Response { count: services .key_backups - .count_keys(sender_user, &body.version) + .count_keys(body.sender_user(), &body.version) .await .try_into()?, etag: services .key_backups - .get_etag(sender_user, 
&body.version) + .get_etag(body.sender_user(), &body.version) .await, }) } @@ -171,11 +170,9 @@ pub(crate) async fn add_backup_keys_route( pub(crate) async fn add_backup_keys_for_room_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if services .key_backups - .get_latest_backup_version(sender_user) + .get_latest_backup_version(body.sender_user()) .await .is_ok_and(|version| version != body.version) { @@ -187,19 +184,19 @@ pub(crate) async fn add_backup_keys_for_room_route( for (session_id, key_data) in &body.sessions { services .key_backups - .add_key(sender_user, &body.version, &body.room_id, session_id, key_data) + .add_key(body.sender_user(), &body.version, &body.room_id, session_id, key_data) .await?; } Ok(add_backup_keys_for_room::v3::Response { count: services .key_backups - .count_keys(sender_user, &body.version) + .count_keys(body.sender_user(), &body.version) .await .try_into()?, etag: services .key_backups - .get_etag(sender_user, &body.version) + .get_etag(body.sender_user(), &body.version) .await, }) } @@ -215,11 +212,9 @@ pub(crate) async fn add_backup_keys_for_room_route( pub(crate) async fn add_backup_keys_for_session_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if services .key_backups - .get_latest_backup_version(sender_user) + .get_latest_backup_version(body.sender_user()) .await .is_ok_and(|version| version != body.version) { @@ -230,18 +225,24 @@ pub(crate) async fn add_backup_keys_for_session_route( services .key_backups - .add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data) + .add_key( + body.sender_user(), + &body.version, + &body.room_id, + &body.session_id, + &body.session_data, + ) .await?; Ok(add_backup_keys_for_session::v3::Response { count: services .key_backups - .count_keys(sender_user, &body.version) + 
.count_keys(body.sender_user(), &body.version) .await .try_into()?, etag: services .key_backups - .get_etag(sender_user, &body.version) + .get_etag(body.sender_user(), &body.version) .await, }) } @@ -252,11 +253,9 @@ pub(crate) async fn add_backup_keys_for_session_route( pub(crate) async fn get_backup_keys_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let rooms = services .key_backups - .get_all(sender_user, &body.version) + .get_all(body.sender_user(), &body.version) .await; Ok(get_backup_keys::v3::Response { @@ -270,11 +269,9 @@ pub(crate) async fn get_backup_keys_route( pub(crate) async fn get_backup_keys_for_room_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sessions = services .key_backups - .get_room(sender_user, &body.version, &body.room_id) + .get_room(body.sender_user(), &body.version, &body.room_id) .await; Ok(get_backup_keys_for_room::v3::Response { @@ -288,11 +285,9 @@ pub(crate) async fn get_backup_keys_for_room_route( pub(crate) async fn get_backup_keys_for_session_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let key_data = services .key_backups - .get_session(sender_user, &body.version, &body.room_id, &body.session_id) + .get_session(body.sender_user(), &body.version, &body.room_id, &body.session_id) .await .map_err(|_| err!(Request(NotFound(debug_error!("Backup key not found for this user's session.")))))?; @@ -307,22 +302,20 @@ pub(crate) async fn get_backup_keys_for_session_route( pub(crate) async fn delete_backup_keys_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services .key_backups - .delete_all_keys(sender_user, &body.version) + .delete_all_keys(body.sender_user(), 
&body.version) .await; Ok(delete_backup_keys::v3::Response { count: services .key_backups - .count_keys(sender_user, &body.version) + .count_keys(body.sender_user(), &body.version) .await .try_into()?, etag: services .key_backups - .get_etag(sender_user, &body.version) + .get_etag(body.sender_user(), &body.version) .await, }) } @@ -333,22 +326,20 @@ pub(crate) async fn delete_backup_keys_route( pub(crate) async fn delete_backup_keys_for_room_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services .key_backups - .delete_room_keys(sender_user, &body.version, &body.room_id) + .delete_room_keys(body.sender_user(), &body.version, &body.room_id) .await; Ok(delete_backup_keys_for_room::v3::Response { count: services .key_backups - .count_keys(sender_user, &body.version) + .count_keys(body.sender_user(), &body.version) .await .try_into()?, etag: services .key_backups - .get_etag(sender_user, &body.version) + .get_etag(body.sender_user(), &body.version) .await, }) } @@ -359,22 +350,20 @@ pub(crate) async fn delete_backup_keys_for_room_route( pub(crate) async fn delete_backup_keys_for_session_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services .key_backups - .delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id) + .delete_room_key(body.sender_user(), &body.version, &body.room_id, &body.session_id) .await; Ok(delete_backup_keys_for_session::v3::Response { count: services .key_backups - .count_keys(sender_user, &body.version) + .count_keys(body.sender_user(), &body.version) .await .try_into()?, etag: services .key_backups - .get_etag(sender_user, &body.version) + .get_etag(body.sender_user(), &body.version) .await, }) } diff --git a/src/api/router/args.rs b/src/api/router/args.rs index 746e1cfc6..cefacac1c 100644 --- a/src/api/router/args.rs +++ b/src/api/router/args.rs 
@@ -3,17 +3,14 @@ use std::{mem, ops::Deref}; use axum::{async_trait, body::Body, extract::FromRequest}; use bytes::{BufMut, BytesMut}; use conduit::{debug, err, trace, utils::string::EMPTY, Error, Result}; -use ruma::{api::IncomingRequest, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId}; +use ruma::{api::IncomingRequest, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, ServerName, UserId}; use service::Services; use super::{auth, auth::Auth, request, request::Request}; use crate::{service::appservice::RegistrationInfo, State}; /// Extractor for Ruma request structs -pub(crate) struct Args -where - T: IncomingRequest + Send + Sync + 'static, -{ +pub(crate) struct Args { /// Request struct body pub(crate) body: T, @@ -38,6 +35,17 @@ where pub(crate) json_body: Option, } +impl Args +where + T: IncomingRequest + Send + Sync + 'static, +{ + #[inline] + pub(crate) fn sender_user(&self) -> &UserId { self.sender_user.as_deref().expect("user is authenticated") } + + #[inline] + pub(crate) fn origin(&self) -> &ServerName { self.origin.as_deref().expect("server is authenticated") } +} + #[async_trait] impl FromRequest for Args where diff --git a/src/api/router/handler.rs b/src/api/router/handler.rs index 3b7b1eeb0..0022f06a9 100644 --- a/src/api/router/handler.rs +++ b/src/api/router/handler.rs @@ -38,7 +38,7 @@ macro_rules! 
ruma_handler { where Fun: Fn($($tx,)* Ruma,) -> Fut + Send + Sync + 'static, Fut: Future> + Send, - Req: IncomingRequest + Send + Sync, + Req: IncomingRequest + Send + Sync + 'static, Err: IntoResponse + Send, ::OutgoingResponse: Send, $( $tx: FromRequestParts + Send + Sync + 'static, )* diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index 2bbc95ca9..088b891a2 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -18,12 +18,10 @@ use crate::Ruma; pub(crate) async fn get_backfill_route( State(services): State, body: Ruma, ) -> Result { - let origin = body.origin.as_ref().expect("server is authenticated"); - services .rooms .event_handler - .acl_check(origin, &body.room_id) + .acl_check(body.origin(), &body.room_id) .await?; if !services @@ -33,7 +31,7 @@ pub(crate) async fn get_backfill_route( .await && !services .rooms .state_cache - .server_in_room(origin, &body.room_id) + .server_in_room(body.origin(), &body.room_id) .await { return Err!(Request(Forbidden("Server is not in room."))); @@ -59,6 +57,7 @@ pub(crate) async fn get_backfill_route( .try_into() .expect("UInt could not be converted to usize"); + let origin = body.origin(); let pdus = services .rooms .timeline diff --git a/src/api/server/event.rs b/src/api/server/event.rs index e4eac794f..64ce3e401 100644 --- a/src/api/server/event.rs +++ b/src/api/server/event.rs @@ -13,8 +13,6 @@ use crate::Ruma; pub(crate) async fn get_event_route( State(services): State, body: Ruma, ) -> Result { - let origin = body.origin.as_ref().expect("server is authenticated"); - let event = services .rooms .timeline @@ -37,7 +35,7 @@ pub(crate) async fn get_event_route( .await && !services .rooms .state_cache - .server_in_room(origin, room_id) + .server_in_room(body.origin(), room_id) .await { return Err!(Request(Forbidden("Server is not in room."))); @@ -46,7 +44,7 @@ pub(crate) async fn get_event_route( if !services .rooms .state_accessor - .server_can_see_event(origin, room_id, 
&body.event_id) + .server_can_see_event(body.origin(), room_id, &body.event_id) .await? { return Err!(Request(Forbidden("Server is not allowed to see event."))); diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index 8307a4ad3..8fe96f813 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -18,12 +18,10 @@ use crate::Ruma; pub(crate) async fn get_event_authorization_route( State(services): State, body: Ruma, ) -> Result { - let origin = body.origin.as_ref().expect("server is authenticated"); - services .rooms .event_handler - .acl_check(origin, &body.room_id) + .acl_check(body.origin(), &body.room_id) .await?; if !services @@ -33,7 +31,7 @@ pub(crate) async fn get_event_authorization_route( .await && !services .rooms .state_cache - .server_in_room(origin, &body.room_id) + .server_in_room(body.origin(), &body.room_id) .await { return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); diff --git a/src/api/server/get_missing_events.rs b/src/api/server/get_missing_events.rs index e267898fe..aee4fbe90 100644 --- a/src/api/server/get_missing_events.rs +++ b/src/api/server/get_missing_events.rs @@ -13,12 +13,10 @@ use crate::Ruma; pub(crate) async fn get_missing_events_route( State(services): State, body: Ruma, ) -> Result { - let origin = body.origin.as_ref().expect("server is authenticated"); - services .rooms .event_handler - .acl_check(origin, &body.room_id) + .acl_check(body.origin(), &body.room_id) .await?; if !services @@ -28,7 +26,7 @@ pub(crate) async fn get_missing_events_route( .await && !services .rooms .state_cache - .server_in_room(origin, &body.room_id) + .server_in_room(body.origin(), &body.room_id) .await { return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room")); @@ -71,7 +69,7 @@ pub(crate) async fn get_missing_events_route( if !services .rooms .state_accessor - .server_can_see_event(origin, &body.room_id, &queued_events[i]) + 
.server_can_see_event(body.origin(), &body.room_id, &queued_events[i]) .await? { i = i.saturating_add(1); diff --git a/src/api/server/hierarchy.rs b/src/api/server/hierarchy.rs index 002bd7633..e3ce71084 100644 --- a/src/api/server/hierarchy.rs +++ b/src/api/server/hierarchy.rs @@ -10,13 +10,11 @@ use crate::{Error, Result, Ruma}; pub(crate) async fn get_hierarchy_route( State(services): State, body: Ruma, ) -> Result { - let origin = body.origin.as_ref().expect("server is authenticated"); - if services.rooms.metadata.exists(&body.room_id).await { services .rooms .spaces - .get_federation_hierarchy(&body.room_id, origin, body.suggested_only) + .get_federation_hierarchy(&body.room_id, body.origin(), body.suggested_only) .await } else { Err(Error::BadRequest(ErrorKind::NotFound, "Room does not exist.")) diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index a9e404c52..b30a1b584 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -18,13 +18,11 @@ pub(crate) async fn create_invite_route( State(services): State, InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { - let origin = body.origin.as_ref().expect("server is authenticated"); - // ACL check origin services .rooms .event_handler - .acl_check(origin, &body.room_id) + .acl_check(body.origin(), &body.room_id) .await?; if !services @@ -55,10 +53,11 @@ pub(crate) async fn create_invite_route( .globals .config .forbidden_remote_server_names - .contains(origin) + .contains(body.origin()) { warn!( - "Received federated/remote invite from banned server {origin} for room ID {}. Rejecting.", + "Received federated/remote invite from banned server {} for room ID {}. 
Rejecting.", + body.origin(), body.room_id ); diff --git a/src/api/server/make_join.rs b/src/api/server/make_join.rs index 856680382..c3524f0e4 100644 --- a/src/api/server/make_join.rs +++ b/src/api/server/make_join.rs @@ -30,8 +30,7 @@ pub(crate) async fn create_join_event_template_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } - let origin = body.origin.as_ref().expect("server is authenticated"); - if body.user_id.server_name() != origin { + if body.user_id.server_name() != body.origin() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Not allowed to join on behalf of another server/user", @@ -42,19 +41,21 @@ pub(crate) async fn create_join_event_template_route( services .rooms .event_handler - .acl_check(origin, &body.room_id) + .acl_check(body.origin(), &body.room_id) .await?; if services .globals .config .forbidden_remote_server_names - .contains(origin) + .contains(body.origin()) { warn!( - "Server {origin} for remote user {} tried joining room ID {} which has a server name that is globally \ + "Server {} for remote user {} tried joining room ID {} which has a server name that is globally \ forbidden. 
Rejecting.", - &body.user_id, &body.room_id, + body.origin(), + &body.user_id, + &body.room_id, ); return Err(Error::BadRequest( ErrorKind::forbidden(), diff --git a/src/api/server/make_leave.rs b/src/api/server/make_leave.rs index 81a32c865..33a945603 100644 --- a/src/api/server/make_leave.rs +++ b/src/api/server/make_leave.rs @@ -19,8 +19,7 @@ pub(crate) async fn create_leave_event_template_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } - let origin = body.origin.as_ref().expect("server is authenticated"); - if body.user_id.server_name() != origin { + if body.user_id.server_name() != body.origin() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Not allowed to leave on behalf of another server/user", @@ -31,7 +30,7 @@ pub(crate) async fn create_leave_event_template_route( services .rooms .event_handler - .acl_check(origin, &body.room_id) + .acl_check(body.origin(), &body.room_id) .await?; let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; diff --git a/src/api/server/send.rs b/src/api/server/send.rs index d5d3ffbbf..2da99c936 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -41,9 +41,7 @@ pub(crate) async fn send_transaction_message_route( State(services): State, InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { - let origin = body.origin.as_ref().expect("server is authenticated"); - - if *origin != body.body.origin { + if body.origin() != body.body.origin { return Err!(Request(Forbidden( "Not allowed to send transactions on behalf of other servers" ))); @@ -67,19 +65,19 @@ pub(crate) async fn send_transaction_message_route( edus = ?body.edus.len(), elapsed = ?txn_start_time.elapsed(), id = ?body.transaction_id, - origin =?body.origin, + origin =?body.origin(), "Starting txn", ); - let resolved_map = handle_pdus(&services, &client, &body.pdus, origin, &txn_start_time).await?; - handle_edus(&services, &client, &body.edus, origin).await; + 
let resolved_map = handle_pdus(&services, &client, &body.pdus, body.origin(), &txn_start_time).await?; + handle_edus(&services, &client, &body.edus, body.origin()).await; debug!( pdus = ?body.pdus.len(), edus = ?body.edus.len(), elapsed = ?txn_start_time.elapsed(), id = ?body.transaction_id, - origin =?body.origin, + origin =?body.origin(), "Finished txn", ); diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index d888d75e8..c3273bafb 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -217,16 +217,15 @@ async fn create_join_event( pub(crate) async fn create_join_event_v1_route( State(services): State, body: Ruma, ) -> Result { - let origin = body.origin.as_ref().expect("server is authenticated"); - if services .globals .config .forbidden_remote_server_names - .contains(origin) + .contains(body.origin()) { warn!( - "Server {origin} tried joining room ID {} who has a server name that is globally forbidden. Rejecting.", + "Server {} tried joining room ID {} who has a server name that is globally forbidden. Rejecting.", + body.origin(), &body.room_id, ); return Err(Error::BadRequest( @@ -243,8 +242,8 @@ pub(crate) async fn create_join_event_v1_route( .contains(&server.to_owned()) { warn!( - "Server {origin} tried joining room ID {} which has a server name that is globally forbidden. \ - Rejecting.", + "Server {} tried joining room ID {} which has a server name that is globally forbidden. 
Rejecting.", + body.origin(), &body.room_id, ); return Err(Error::BadRequest( @@ -254,7 +253,7 @@ pub(crate) async fn create_join_event_v1_route( } } - let room_state = create_join_event(&services, origin, &body.room_id, &body.pdu).await?; + let room_state = create_join_event(&services, body.origin(), &body.room_id, &body.pdu).await?; Ok(create_join_event::v1::Response { room_state, @@ -267,13 +266,11 @@ pub(crate) async fn create_join_event_v1_route( pub(crate) async fn create_join_event_v2_route( State(services): State, body: Ruma, ) -> Result { - let origin = body.origin.as_ref().expect("server is authenticated"); - if services .globals .config .forbidden_remote_server_names - .contains(origin) + .contains(body.origin()) { return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -299,7 +296,7 @@ pub(crate) async fn create_join_event_v2_route( auth_chain, state, event, - } = create_join_event(&services, origin, &body.room_id, &body.pdu).await?; + } = create_join_event(&services, body.origin(), &body.room_id, &body.pdu).await?; let room_state = create_join_event::v2::RoomState { members_omitted: false, auth_chain, diff --git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index 0530f9dd5..7b4a8aeef 100644 --- a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -8,7 +8,7 @@ use ruma::{ room::member::{MembershipState, RoomMemberEventContent}, StateEventType, }, - OwnedServerName, OwnedUserId, RoomId, ServerName, + OwnedUserId, RoomId, ServerName, }; use serde_json::value::RawValue as RawJsonValue; @@ -23,9 +23,7 @@ use crate::{ pub(crate) async fn create_leave_event_v1_route( State(services): State, body: Ruma, ) -> Result { - let origin = body.origin.as_ref().expect("server is authenticated"); - - create_leave_event(&services, origin, &body.room_id, &body.pdu).await?; + create_leave_event(&services, body.origin(), &body.room_id, &body.pdu).await?; Ok(create_leave_event::v1::Response::new()) } @@ -36,9 +34,7 @@ pub(crate) async fn 
create_leave_event_v1_route( pub(crate) async fn create_leave_event_v2_route( State(services): State, body: Ruma, ) -> Result { - let origin = body.origin.as_ref().expect("server is authenticated"); - - create_leave_event(&services, origin, &body.room_id, &body.pdu).await?; + create_leave_event(&services, body.origin(), &body.room_id, &body.pdu).await?; Ok(create_leave_event::v2::Response::new()) } @@ -139,16 +135,6 @@ async fn create_leave_event( )); } - let origin: OwnedServerName = serde_json::from_value( - serde_json::to_value( - value - .get("origin") - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing origin property."))?, - ) - .expect("CanonicalJson is valid json value"), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "origin is not a server name."))?; - let mutex_lock = services .rooms .event_handler @@ -159,7 +145,7 @@ async fn create_leave_event( let pdu_id: Vec = services .rooms .event_handler - .handle_incoming_pdu(&origin, room_id, &event_id, value, true) + .handle_incoming_pdu(origin, room_id, &event_id, value, true) .await? 
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?; diff --git a/src/api/server/state.rs b/src/api/server/state.rs index 3a27cd0a3..59bb6c7b1 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -13,12 +13,10 @@ use crate::Ruma; pub(crate) async fn get_room_state_route( State(services): State, body: Ruma, ) -> Result { - let origin = body.origin.as_ref().expect("server is authenticated"); - services .rooms .event_handler - .acl_check(origin, &body.room_id) + .acl_check(body.origin(), &body.room_id) .await?; if !services @@ -28,7 +26,7 @@ pub(crate) async fn get_room_state_route( .await && !services .rooms .state_cache - .server_in_room(origin, &body.room_id) + .server_in_room(body.origin(), &body.room_id) .await { return Err!(Request(Forbidden("Server is not in room."))); diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index b026abf1d..957a2a86e 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -14,12 +14,10 @@ use crate::{Result, Ruma}; pub(crate) async fn get_room_state_ids_route( State(services): State, body: Ruma, ) -> Result { - let origin = body.origin.as_ref().expect("server is authenticated"); - services .rooms .event_handler - .acl_check(origin, &body.room_id) + .acl_check(body.origin(), &body.room_id) .await?; if !services @@ -29,7 +27,7 @@ pub(crate) async fn get_room_state_ids_route( .await && !services .rooms .state_cache - .server_in_room(origin, &body.room_id) + .server_in_room(body.origin(), &body.room_id) .await { return Err!(Request(Forbidden("Server is not in room."))); diff --git a/src/api/server/user.rs b/src/api/server/user.rs index 0718da580..40f330a12 100644 --- a/src/api/server/user.rs +++ b/src/api/server/user.rs @@ -27,8 +27,6 @@ pub(crate) async fn get_devices_route( )); } - let origin = body.origin.as_ref().expect("server is authenticated"); - let user_id = &body.user_id; Ok(get_devices::v1::Response { user_id: 
user_id.clone(), @@ -66,12 +64,12 @@ pub(crate) async fn get_devices_route( .await, master_key: services .users - .get_master_key(None, &body.user_id, &|u| u.server_name() == origin) + .get_master_key(None, &body.user_id, &|u| u.server_name() == body.origin()) .await .ok(), self_signing_key: services .users - .get_self_signing_key(None, &body.user_id, &|u| u.server_name() == origin) + .get_self_signing_key(None, &body.user_id, &|u| u.server_name() == body.origin()) .await .ok(), }) From 8742266ff0422fb678c86306d2d7384ff7081fe4 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 25 Oct 2024 01:16:01 +0000 Subject: [PATCH 126/245] split up core/pdu Signed-off-by: Jason Volk --- src/core/pdu/content.rs | 20 ++ src/core/pdu/id.rs | 27 +++ src/core/pdu/mod.rs | 456 +++----------------------------------- src/core/pdu/redact.rs | 93 ++++++++ src/core/pdu/state_res.rs | 30 +++ src/core/pdu/strip.rs | 208 +++++++++++++++++ src/core/pdu/unsigned.rs | 83 +++++++ 7 files changed, 492 insertions(+), 425 deletions(-) create mode 100644 src/core/pdu/content.rs create mode 100644 src/core/pdu/id.rs create mode 100644 src/core/pdu/redact.rs create mode 100644 src/core/pdu/state_res.rs create mode 100644 src/core/pdu/strip.rs create mode 100644 src/core/pdu/unsigned.rs diff --git a/src/core/pdu/content.rs b/src/core/pdu/content.rs new file mode 100644 index 000000000..a6d86554b --- /dev/null +++ b/src/core/pdu/content.rs @@ -0,0 +1,20 @@ +use serde::Deserialize; +use serde_json::value::Value as JsonValue; + +use crate::{err, implement, Result}; + +#[must_use] +#[implement(super::PduEvent)] +pub fn get_content_as_value(&self) -> JsonValue { + self.get_content() + .expect("pdu content must be a valid JSON value") +} + +#[implement(super::PduEvent)] +pub fn get_content(&self) -> Result +where + T: for<'de> Deserialize<'de>, +{ + serde_json::from_str(self.content.get()) + .map_err(|e| err!(Database("Failed to deserialize pdu content into type: {e}"))) +} diff --git a/src/core/pdu/id.rs 
b/src/core/pdu/id.rs new file mode 100644 index 000000000..ae5b85f9a --- /dev/null +++ b/src/core/pdu/id.rs @@ -0,0 +1,27 @@ +use ruma::{CanonicalJsonObject, OwnedEventId, RoomVersionId}; +use serde_json::value::RawValue as RawJsonValue; + +use crate::{err, Result}; + +/// Generates a correct eventId for the incoming pdu. +/// +/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap`. +pub fn gen_event_id_canonical_json( + pdu: &RawJsonValue, room_version_id: &RoomVersionId, +) -> Result<(OwnedEventId, CanonicalJsonObject)> { + let value: CanonicalJsonObject = serde_json::from_str(pdu.get()) + .map_err(|e| err!(BadServerResponse(warn!("Error parsing incoming event: {e:?}"))))?; + + let event_id = gen_event_id(&value, room_version_id)?; + + Ok((event_id, value)) +} + +/// Generates a correct eventId for the incoming pdu. +pub fn gen_event_id(value: &CanonicalJsonObject, room_version_id: &RoomVersionId) -> Result { + let reference_hash = ruma::signatures::reference_hash(value, room_version_id)?; + let event_id: OwnedEventId = format!("${reference_hash}").try_into()?; + + Ok(event_id) +} diff --git a/src/core/pdu/mod.rs b/src/core/pdu/mod.rs index 274b96bd2..9970c39e2 100644 --- a/src/core/pdu/mod.rs +++ b/src/core/pdu/mod.rs @@ -1,44 +1,28 @@ mod builder; +mod content; mod count; +mod id; +mod redact; +mod state_res; +mod strip; +mod unsigned; -use std::{cmp::Ordering, collections::BTreeMap, sync::Arc}; +use std::{cmp::Ordering, sync::Arc}; use ruma::{ - canonical_json::redact_content_in_place, - events::{ - room::{member::RoomMemberEventContent, redaction::RoomRedactionEventContent}, - space::child::HierarchySpaceChildEvent, - AnyEphemeralRoomEvent, AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, AnySyncStateEvent, - AnySyncTimelineEvent, AnyTimelineEvent, StateEvent, TimelineEventType, - }, - serde::Raw, - state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, - OwnedUserId, RoomId, 
RoomVersionId, UInt, UserId, + events::TimelineEventType, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedRoomId, OwnedUserId, UInt, }; use serde::{Deserialize, Serialize}; -use serde_json::{ - json, - value::{to_raw_value, RawValue as RawJsonValue, Value as JsonValue}, -}; +use serde_json::value::RawValue as RawJsonValue; pub use self::{ builder::{Builder, Builder as PduBuilder}, count::PduCount, + id::*, }; -use crate::{err, is_true, warn, Error, Result}; - -#[derive(Deserialize)] -struct ExtractRedactedBecause { - redacted_because: Option, -} - -/// Content hashes of a PDU. -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct EventHash { - /// The SHA-256 hash. - pub sha256: String, -} +use crate::Result; +/// Persistent Data Unit (Event) #[derive(Clone, Deserialize, Serialize, Debug)] pub struct PduEvent { pub event_id: Arc, @@ -65,415 +49,37 @@ pub struct PduEvent { pub signatures: Option>, } -impl PduEvent { - #[tracing::instrument(skip(self), level = "debug")] - pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> Result<()> { - self.unsigned = None; - - let mut content = serde_json::from_str(self.content.get()) - .map_err(|_| Error::bad_database("PDU in db has invalid content."))?; - - redact_content_in_place(&mut content, &room_version_id, self.kind.to_string()) - .map_err(|e| Error::Redaction(self.sender.server_name().to_owned(), e))?; - - self.unsigned = Some( - to_raw_value(&json!({ - "redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works") - })) - .expect("to string always works"), - ); - - self.content = to_raw_value(&content).expect("to string always works"); - - Ok(()) - } - - #[must_use] - pub fn is_redacted(&self) -> bool { - let Some(unsigned) = &self.unsigned else { - return false; - }; - - let Ok(unsigned) = ExtractRedactedBecause::deserialize(&**unsigned) else { - return false; - }; - - unsigned.redacted_because.is_some() - } - - pub fn remove_transaction_id(&mut self) -> 
Result<()> { - let Some(unsigned) = &self.unsigned else { - return Ok(()); - }; - - let mut unsigned: BTreeMap> = - serde_json::from_str(unsigned.get()).map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; - - unsigned.remove("transaction_id"); - self.unsigned = to_raw_value(&unsigned) - .map(Some) - .expect("unsigned is valid"); - - Ok(()) - } - - pub fn add_age(&mut self) -> Result<()> { - let mut unsigned: BTreeMap> = self - .unsigned - .as_ref() - .map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get())) - .map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; - - // deliberately allowing for the possibility of negative age - let now: i128 = MilliSecondsSinceUnixEpoch::now().get().into(); - let then: i128 = self.origin_server_ts.into(); - let this_age = now.saturating_sub(then); - - unsigned.insert("age".to_owned(), to_raw_value(&this_age).expect("age is valid")); - self.unsigned = to_raw_value(&unsigned) - .map(Some) - .expect("unsigned is valid"); - - Ok(()) - } - - /// Copies the `redacts` property of the event to the `content` dict and - /// vice-versa. - /// - /// This follows the specification's - /// [recommendation](https://spec.matrix.org/v1.10/rooms/v11/#moving-the-redacts-property-of-mroomredaction-events-to-a-content-property): - /// - /// > For backwards-compatibility with older clients, servers should add a - /// > redacts - /// > property to the top level of m.room.redaction events in when serving - /// > such events - /// > over the Client-Server API. - /// - /// > For improved compatibility with newer clients, servers should add a - /// > redacts property - /// > to the content of m.room.redaction events in older room versions when - /// > serving - /// > such events over the Client-Server API. 
- #[must_use] - pub fn copy_redacts(&self) -> (Option>, Box) { - if self.kind == TimelineEventType::RoomRedaction { - if let Ok(mut content) = serde_json::from_str::(self.content.get()) { - if let Some(redacts) = content.redacts { - return (Some(redacts.into()), self.content.clone()); - } else if let Some(redacts) = self.redacts.clone() { - content.redacts = Some(redacts.into()); - return ( - self.redacts.clone(), - to_raw_value(&content).expect("Must be valid, we only added redacts field"), - ); - } - } - } - - (self.redacts.clone(), self.content.clone()) - } - - #[must_use] - pub fn get_content_as_value(&self) -> JsonValue { - self.get_content() - .expect("pdu content must be a valid JSON value") - } - - pub fn get_content(&self) -> Result - where - T: for<'de> Deserialize<'de>, - { - serde_json::from_str(self.content.get()) - .map_err(|e| err!(Database("Failed to deserialize pdu content into type: {e}"))) - } - - pub fn contains_unsigned_property(&self, property: &str, is_type: F) -> bool - where - F: FnOnce(&JsonValue) -> bool, - { - self.get_unsigned_as_value() - .get(property) - .map(is_type) - .is_some_and(is_true!()) - } - - pub fn get_unsigned_property(&self, property: &str) -> Result - where - T: for<'de> Deserialize<'de>, - { - self.get_unsigned_as_value() - .get_mut(property) - .map(JsonValue::take) - .map(serde_json::from_value) - .ok_or(err!(Request(NotFound("property not found in unsigned object"))))? - .map_err(|e| err!(Database("Failed to deserialize unsigned.{property} into type: {e}"))) - } - - #[must_use] - pub fn get_unsigned_as_value(&self) -> JsonValue { self.get_unsigned::().unwrap_or_default() } - - pub fn get_unsigned(&self) -> Result { - self.unsigned - .as_ref() - .map(|raw| raw.get()) - .map(serde_json::from_str) - .ok_or(err!(Request(NotFound("\"unsigned\" property not found in pdu"))))? 
- .map_err(|e| err!(Database("Failed to deserialize \"unsigned\" into value: {e}"))) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_sync_room_event(&self) -> Raw { - let (redacts, content) = self.copy_redacts(); - let mut json = json!({ - "content": content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - }); - - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - if let Some(state_key) = &self.state_key { - json["state_key"] = json!(state_key); - } - if let Some(redacts) = &redacts { - json["redacts"] = json!(redacts); - } - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - /// This only works for events that are also AnyRoomEvents. - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_any_event(&self) -> Raw { - let (redacts, content) = self.copy_redacts(); - let mut json = json!({ - "content": content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "room_id": self.room_id, - }); - - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - if let Some(state_key) = &self.state_key { - json["state_key"] = json!(state_key); - } - if let Some(redacts) = &redacts { - json["redacts"] = json!(redacts); - } - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_room_event(&self) -> Raw { - let (redacts, content) = self.copy_redacts(); - let mut json = json!({ - "content": content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "room_id": self.room_id, - }); - - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - if let Some(state_key) = &self.state_key { - json["state_key"] = json!(state_key); - } - if let 
Some(redacts) = &redacts { - json["redacts"] = json!(redacts); - } - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_message_like_event(&self) -> Raw { - let (redacts, content) = self.copy_redacts(); - let mut json = json!({ - "content": content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "room_id": self.room_id, - }); - - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - if let Some(state_key) = &self.state_key { - json["state_key"] = json!(state_key); - } - if let Some(redacts) = &redacts { - json["redacts"] = json!(redacts); - } - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - #[must_use] - pub fn to_state_event_value(&self) -> JsonValue { - let mut json = json!({ - "content": self.content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "room_id": self.room_id, - "state_key": self.state_key, - }); - - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - - json - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_state_event(&self) -> Raw { - serde_json::from_value(self.to_state_event_value()).expect("Raw::from_value always works") - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_sync_state_event(&self) -> Raw { - let mut json = json!({ - "content": self.content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "state_key": self.state_key, - }); - - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_stripped_state_event(&self) -> Raw { - 
let json = json!({ - "content": self.content, - "type": self.kind, - "sender": self.sender, - "state_key": self.state_key, - }); - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_stripped_spacechild_state_event(&self) -> Raw { - let json = json!({ - "content": self.content, - "type": self.kind, - "sender": self.sender, - "state_key": self.state_key, - "origin_server_ts": self.origin_server_ts, - }); - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_member_event(&self) -> Raw> { - let mut json = json!({ - "content": self.content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "redacts": self.redacts, - "room_id": self.room_id, - "state_key": self.state_key, - }); - - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - - serde_json::from_value(json).expect("Raw::from_value always works") - } +/// Content hashes of a PDU. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct EventHash { + /// The SHA-256 hash. 
+ pub sha256: String, +} +impl PduEvent { pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result { - json.insert("event_id".into(), CanonicalJsonValue::String(event_id.into())); - - let value = serde_json::to_value(json)?; - let pdu = serde_json::from_value(value)?; - - Ok(pdu) + let event_id = CanonicalJsonValue::String(event_id.into()); + json.insert("event_id".into(), event_id); + serde_json::to_value(json) + .and_then(serde_json::from_value) + .map_err(Into::into) } } -impl state_res::Event for PduEvent { - type Id = Arc; - - fn event_id(&self) -> &Self::Id { &self.event_id } - - fn room_id(&self) -> &RoomId { &self.room_id } - - fn sender(&self) -> &UserId { &self.sender } - - fn event_type(&self) -> &TimelineEventType { &self.kind } - - fn content(&self) -> &RawJsonValue { &self.content } - - fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { MilliSecondsSinceUnixEpoch(self.origin_server_ts) } - - fn state_key(&self) -> Option<&str> { self.state_key.as_deref() } - - fn prev_events(&self) -> impl DoubleEndedIterator + Send + '_ { self.prev_events.iter() } - - fn auth_events(&self) -> impl DoubleEndedIterator + Send + '_ { self.auth_events.iter() } - - fn redacts(&self) -> Option<&Self::Id> { self.redacts.as_ref() } -} - -// These impl's allow us to dedup state snapshots when resolving state -// for incoming events (federation/send/{txn}). +/// Prevent derived equality which wouldn't limit itself to event_id impl Eq for PduEvent {} + +/// Equality determined by the Pdu's ID, not the memory representations. impl PartialEq for PduEvent { fn eq(&self, other: &Self) -> bool { self.event_id == other.event_id } } + +/// Ordering determined by the Pdu's ID, not the memory representations. impl PartialOrd for PduEvent { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } + +/// Ordering determined by the Pdu's ID, not the memory representations. 
impl Ord for PduEvent { fn cmp(&self, other: &Self) -> Ordering { self.event_id.cmp(&other.event_id) } } - -/// Generates a correct eventId for the incoming pdu. -/// -/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap`. -pub fn gen_event_id_canonical_json( - pdu: &RawJsonValue, room_version_id: &RoomVersionId, -) -> Result<(OwnedEventId, CanonicalJsonObject)> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()) - .map_err(|e| err!(BadServerResponse(warn!("Error parsing incoming event: {e:?}"))))?; - - let event_id = gen_event_id(&value, room_version_id)?; - - Ok((event_id, value)) -} - -/// Generates a correct eventId for the incoming pdu. -pub fn gen_event_id(value: &CanonicalJsonObject, room_version_id: &RoomVersionId) -> Result { - let reference_hash = ruma::signatures::reference_hash(value, room_version_id)?; - let event_id: OwnedEventId = format!("${reference_hash}").try_into()?; - - Ok(event_id) -} diff --git a/src/core/pdu/redact.rs b/src/core/pdu/redact.rs new file mode 100644 index 000000000..647f54c0f --- /dev/null +++ b/src/core/pdu/redact.rs @@ -0,0 +1,93 @@ +use std::sync::Arc; + +use ruma::{ + canonical_json::redact_content_in_place, + events::{room::redaction::RoomRedactionEventContent, TimelineEventType}, + EventId, RoomVersionId, +}; +use serde::Deserialize; +use serde_json::{ + json, + value::{to_raw_value, RawValue as RawJsonValue}, +}; + +use crate::{implement, warn, Error, Result}; + +#[derive(Deserialize)] +struct ExtractRedactedBecause { + redacted_because: Option, +} + +#[implement(super::PduEvent)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> Result<()> { + self.unsigned = None; + + let mut content = + serde_json::from_str(self.content.get()).map_err(|_| Error::bad_database("PDU in db has invalid content."))?; + + redact_content_in_place(&mut content, &room_version_id, self.kind.to_string()) + .map_err(|e| 
Error::Redaction(self.sender.server_name().to_owned(), e))?; + + self.unsigned = Some( + to_raw_value(&json!({ + "redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works") + })) + .expect("to string always works"), + ); + + self.content = to_raw_value(&content).expect("to string always works"); + + Ok(()) +} + +#[implement(super::PduEvent)] +#[must_use] +pub fn is_redacted(&self) -> bool { + let Some(unsigned) = &self.unsigned else { + return false; + }; + + let Ok(unsigned) = ExtractRedactedBecause::deserialize(&**unsigned) else { + return false; + }; + + unsigned.redacted_because.is_some() +} + +/// Copies the `redacts` property of the event to the `content` dict and +/// vice-versa. +/// +/// This follows the specification's +/// [recommendation](https://spec.matrix.org/v1.10/rooms/v11/#moving-the-redacts-property-of-mroomredaction-events-to-a-content-property): +/// +/// > For backwards-compatibility with older clients, servers should add a +/// > redacts +/// > property to the top level of m.room.redaction events in when serving +/// > such events +/// > over the Client-Server API. +/// +/// > For improved compatibility with newer clients, servers should add a +/// > redacts property +/// > to the content of m.room.redaction events in older room versions when +/// > serving +/// > such events over the Client-Server API. 
+#[implement(super::PduEvent)] +#[must_use] +pub fn copy_redacts(&self) -> (Option>, Box) { + if self.kind == TimelineEventType::RoomRedaction { + if let Ok(mut content) = serde_json::from_str::(self.content.get()) { + if let Some(redacts) = content.redacts { + return (Some(redacts.into()), self.content.clone()); + } else if let Some(redacts) = self.redacts.clone() { + content.redacts = Some(redacts.into()); + return ( + self.redacts.clone(), + to_raw_value(&content).expect("Must be valid, we only added redacts field"), + ); + } + } + } + + (self.redacts.clone(), self.content.clone()) +} diff --git a/src/core/pdu/state_res.rs b/src/core/pdu/state_res.rs new file mode 100644 index 000000000..a27c98229 --- /dev/null +++ b/src/core/pdu/state_res.rs @@ -0,0 +1,30 @@ +use std::sync::Arc; + +use ruma::{events::TimelineEventType, state_res, EventId, MilliSecondsSinceUnixEpoch, RoomId, UserId}; +use serde_json::value::RawValue as RawJsonValue; + +use super::PduEvent; + +impl state_res::Event for PduEvent { + type Id = Arc; + + fn event_id(&self) -> &Self::Id { &self.event_id } + + fn room_id(&self) -> &RoomId { &self.room_id } + + fn sender(&self) -> &UserId { &self.sender } + + fn event_type(&self) -> &TimelineEventType { &self.kind } + + fn content(&self) -> &RawJsonValue { &self.content } + + fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { MilliSecondsSinceUnixEpoch(self.origin_server_ts) } + + fn state_key(&self) -> Option<&str> { self.state_key.as_deref() } + + fn prev_events(&self) -> impl DoubleEndedIterator + Send + '_ { self.prev_events.iter() } + + fn auth_events(&self) -> impl DoubleEndedIterator + Send + '_ { self.auth_events.iter() } + + fn redacts(&self) -> Option<&Self::Id> { self.redacts.as_ref() } +} diff --git a/src/core/pdu/strip.rs b/src/core/pdu/strip.rs new file mode 100644 index 000000000..8d20d9828 --- /dev/null +++ b/src/core/pdu/strip.rs @@ -0,0 +1,208 @@ +use ruma::{ + events::{ + room::member::RoomMemberEventContent, 
space::child::HierarchySpaceChildEvent, AnyEphemeralRoomEvent, + AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, AnySyncStateEvent, AnySyncTimelineEvent, + AnyTimelineEvent, StateEvent, + }, + serde::Raw, +}; +use serde_json::{json, value::Value as JsonValue}; + +use crate::{implement, warn}; + +#[implement(super::PduEvent)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_sync_room_event(&self) -> Raw { + let (redacts, content) = self.copy_redacts(); + let mut json = json!({ + "content": content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + if let Some(state_key) = &self.state_key { + json["state_key"] = json!(state_key); + } + if let Some(redacts) = &redacts { + json["redacts"] = json!(redacts); + } + + serde_json::from_value(json).expect("Raw::from_value always works") +} + +/// This only works for events that are also AnyRoomEvents. 
+#[implement(super::PduEvent)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_any_event(&self) -> Raw { + let (redacts, content) = self.copy_redacts(); + let mut json = json!({ + "content": content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "room_id": self.room_id, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + if let Some(state_key) = &self.state_key { + json["state_key"] = json!(state_key); + } + if let Some(redacts) = &redacts { + json["redacts"] = json!(redacts); + } + + serde_json::from_value(json).expect("Raw::from_value always works") +} + +#[implement(super::PduEvent)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_room_event(&self) -> Raw { + let (redacts, content) = self.copy_redacts(); + let mut json = json!({ + "content": content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "room_id": self.room_id, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + if let Some(state_key) = &self.state_key { + json["state_key"] = json!(state_key); + } + if let Some(redacts) = &redacts { + json["redacts"] = json!(redacts); + } + + serde_json::from_value(json).expect("Raw::from_value always works") +} + +#[implement(super::PduEvent)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_message_like_event(&self) -> Raw { + let (redacts, content) = self.copy_redacts(); + let mut json = json!({ + "content": content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "room_id": self.room_id, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + if let Some(state_key) = &self.state_key { + json["state_key"] = json!(state_key); + } + if let Some(redacts) = &redacts { + 
json["redacts"] = json!(redacts); + } + + serde_json::from_value(json).expect("Raw::from_value always works") +} + +#[implement(super::PduEvent)] +#[must_use] +pub fn to_state_event_value(&self) -> JsonValue { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "room_id": self.room_id, + "state_key": self.state_key, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + + json +} + +#[implement(super::PduEvent)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_state_event(&self) -> Raw { + serde_json::from_value(self.to_state_event_value()).expect("Raw::from_value always works") +} + +#[implement(super::PduEvent)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_sync_state_event(&self) -> Raw { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "state_key": self.state_key, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + + serde_json::from_value(json).expect("Raw::from_value always works") +} + +#[implement(super::PduEvent)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_stripped_state_event(&self) -> Raw { + let json = json!({ + "content": self.content, + "type": self.kind, + "sender": self.sender, + "state_key": self.state_key, + }); + + serde_json::from_value(json).expect("Raw::from_value always works") +} + +#[implement(super::PduEvent)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_stripped_spacechild_state_event(&self) -> Raw { + let json = json!({ + "content": self.content, + "type": self.kind, + "sender": self.sender, + "state_key": self.state_key, + "origin_server_ts": self.origin_server_ts, + }); + + serde_json::from_value(json).expect("Raw::from_value always works") 
+} + +#[implement(super::PduEvent)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_member_event(&self) -> Raw> { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "redacts": self.redacts, + "room_id": self.room_id, + "state_key": self.state_key, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + + serde_json::from_value(json).expect("Raw::from_value always works") +} diff --git a/src/core/pdu/unsigned.rs b/src/core/pdu/unsigned.rs new file mode 100644 index 000000000..1c47e8263 --- /dev/null +++ b/src/core/pdu/unsigned.rs @@ -0,0 +1,83 @@ +use std::collections::BTreeMap; + +use ruma::MilliSecondsSinceUnixEpoch; +use serde::Deserialize; +use serde_json::value::{to_raw_value, RawValue as RawJsonValue, Value as JsonValue}; + +use crate::{err, implement, is_true, Result}; + +#[implement(super::PduEvent)] +pub fn remove_transaction_id(&mut self) -> Result<()> { + let Some(unsigned) = &self.unsigned else { + return Ok(()); + }; + + let mut unsigned: BTreeMap> = + serde_json::from_str(unsigned.get()).map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; + + unsigned.remove("transaction_id"); + self.unsigned = to_raw_value(&unsigned) + .map(Some) + .expect("unsigned is valid"); + + Ok(()) +} + +#[implement(super::PduEvent)] +pub fn add_age(&mut self) -> Result<()> { + let mut unsigned: BTreeMap> = self + .unsigned + .as_ref() + .map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get())) + .map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; + + // deliberately allowing for the possibility of negative age + let now: i128 = MilliSecondsSinceUnixEpoch::now().get().into(); + let then: i128 = self.origin_server_ts.into(); + let this_age = now.saturating_sub(then); + + unsigned.insert("age".to_owned(), to_raw_value(&this_age).expect("age is valid")); + 
self.unsigned = to_raw_value(&unsigned) + .map(Some) + .expect("unsigned is valid"); + + Ok(()) +} + +#[implement(super::PduEvent)] +pub fn contains_unsigned_property(&self, property: &str, is_type: F) -> bool +where + F: FnOnce(&JsonValue) -> bool, +{ + self.get_unsigned_as_value() + .get(property) + .map(is_type) + .is_some_and(is_true!()) +} + +#[implement(super::PduEvent)] +pub fn get_unsigned_property(&self, property: &str) -> Result +where + T: for<'de> Deserialize<'de>, +{ + self.get_unsigned_as_value() + .get_mut(property) + .map(JsonValue::take) + .map(serde_json::from_value) + .ok_or(err!(Request(NotFound("property not found in unsigned object"))))? + .map_err(|e| err!(Database("Failed to deserialize unsigned.{property} into type: {e}"))) +} + +#[implement(super::PduEvent)] +#[must_use] +pub fn get_unsigned_as_value(&self) -> JsonValue { self.get_unsigned::().unwrap_or_default() } + +#[implement(super::PduEvent)] +pub fn get_unsigned(&self) -> Result { + self.unsigned + .as_ref() + .map(|raw| raw.get()) + .map(serde_json::from_str) + .ok_or(err!(Request(NotFound("\"unsigned\" property not found in pdu"))))? + .map_err(|e| err!(Database("Failed to deserialize \"unsigned\" into value: {e}"))) +} From cf59f738b9f687aa0902bf2bd011219567e4fea5 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 25 Oct 2024 02:01:29 +0000 Subject: [PATCH 127/245] move macros incorrectly moved out of utils to top level Signed-off-by: Jason Volk --- src/core/mod.rs | 56 ----------------------------------- src/core/utils/mod.rs | 68 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 56 deletions(-) diff --git a/src/core/mod.rs b/src/core/mod.rs index 790525549..d201709bd 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -36,59 +36,3 @@ pub mod mods { () => {}; } } - -/// Functor for falsy -#[macro_export] -macro_rules! is_false { - () => { - |x| !x - }; -} - -/// Functor for truthy -#[macro_export] -macro_rules! 
is_true { - () => { - |x| !!x - }; -} - -/// Functor for equality to zero -#[macro_export] -macro_rules! is_zero { - () => { - $crate::is_matching!(0) - }; -} - -/// Functor for equality i.e. .is_some_and(is_equal!(2)) -#[macro_export] -macro_rules! is_equal_to { - ($val:expr) => { - |x| x == $val - }; -} - -/// Functor for less i.e. .is_some_and(is_less_than!(2)) -#[macro_export] -macro_rules! is_less_than { - ($val:expr) => { - |x| x < $val - }; -} - -/// Functor for matches! i.e. .is_some_and(is_matching!('A'..='Z')) -#[macro_export] -macro_rules! is_matching { - ($val:expr) => { - |x| matches!(x, $val) - }; -} - -/// Functor for !is_empty() -#[macro_export] -macro_rules! is_not_empty { - () => { - |x| !x.is_empty() - }; -} diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index 3943a8daa..8e29c608f 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -57,3 +57,71 @@ macro_rules! at { |t| t.$idx }; } + +/// Functor for equality i.e. .is_some_and(is_equal!(2)) +#[macro_export] +macro_rules! is_equal_to { + ($val:ident) => { + |x| x == $val + }; + + ($val:expr) => { + |x| x == $val + }; +} + +/// Functor for less i.e. .is_some_and(is_less_than!(2)) +#[macro_export] +macro_rules! is_less_than { + ($val:ident) => { + |x| x < $val + }; + + ($val:expr) => { + |x| x < $val + }; +} + +/// Functor for equality to zero +#[macro_export] +macro_rules! is_zero { + () => { + $crate::is_matching!(0) + }; +} + +/// Functor for matches! i.e. .is_some_and(is_matching!('A'..='Z')) +#[macro_export] +macro_rules! is_matching { + ($val:ident) => { + |x| matches!(x, $val) + }; + + ($val:expr) => { + |x| matches!(x, $val) + }; +} + +/// Functor for !is_empty() +#[macro_export] +macro_rules! is_not_empty { + () => { + |x| !x.is_empty() + }; +} + +/// Functor for truthy +#[macro_export] +macro_rules! is_true { + () => { + |x| !!x + }; +} + +/// Functor for falsy +#[macro_export] +macro_rules! 
is_false { + () => { + |x| !x + }; +} From b7369074d4d9e235c4bb9a7529e98c1aa5a662b1 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 25 Oct 2024 02:56:24 +0000 Subject: [PATCH 128/245] add RoomEventFilter matcher for PduEvent Signed-off-by: Jason Volk --- src/core/pdu/filter.rs | 90 ++++++++++++++++++++++++++++++++++++++++++ src/core/pdu/mod.rs | 1 + 2 files changed, 91 insertions(+) create mode 100644 src/core/pdu/filter.rs diff --git a/src/core/pdu/filter.rs b/src/core/pdu/filter.rs new file mode 100644 index 000000000..bd232ebd8 --- /dev/null +++ b/src/core/pdu/filter.rs @@ -0,0 +1,90 @@ +use ruma::api::client::filter::{RoomEventFilter, UrlFilter}; +use serde_json::Value; + +use crate::{implement, is_equal_to}; + +#[implement(super::PduEvent)] +#[must_use] +pub fn matches(&self, filter: &RoomEventFilter) -> bool { + if !self.matches_sender(filter) { + return false; + } + + if !self.matches_room(filter) { + return false; + } + + if !self.matches_type(filter) { + return false; + } + + if !self.matches_url(filter) { + return false; + } + + true +} + +#[implement(super::PduEvent)] +fn matches_room(&self, filter: &RoomEventFilter) -> bool { + if filter.not_rooms.contains(&self.room_id) { + return false; + } + + if let Some(rooms) = filter.rooms.as_ref() { + if !rooms.contains(&self.room_id) { + return false; + } + } + + true +} + +#[implement(super::PduEvent)] +fn matches_sender(&self, filter: &RoomEventFilter) -> bool { + if filter.not_senders.contains(&self.sender) { + return false; + } + + if let Some(senders) = filter.senders.as_ref() { + if !senders.contains(&self.sender) { + return false; + } + } + + true +} + +#[implement(super::PduEvent)] +fn matches_type(&self, filter: &RoomEventFilter) -> bool { + let event_type = &self.kind.to_cow_str(); + if filter.not_types.iter().any(is_equal_to!(event_type)) { + return false; + } + + if let Some(types) = filter.types.as_ref() { + if !types.iter().any(is_equal_to!(event_type)) { + return false; + } + } + + true +} + 
+#[implement(super::PduEvent)] +fn matches_url(&self, filter: &RoomEventFilter) -> bool { + let Some(url_filter) = filter.url_filter.as_ref() else { + return true; + }; + + //TODO: might be better to use Ruma's Raw rather than serde here + let url = serde_json::from_str::(self.content.get()) + .expect("parsing content JSON failed") + .get("url") + .is_some_and(Value::is_string); + + match url_filter { + UrlFilter::EventsWithUrl => url, + UrlFilter::EventsWithoutUrl => !url, + } +} diff --git a/src/core/pdu/mod.rs b/src/core/pdu/mod.rs index 9970c39e2..ed11adbb2 100644 --- a/src/core/pdu/mod.rs +++ b/src/core/pdu/mod.rs @@ -1,6 +1,7 @@ mod builder; mod content; mod count; +mod filter; mod id; mod redact; mod state_res; From 68086717516225af96fe6c7cff743836103188eb Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 25 Oct 2024 05:22:50 +0000 Subject: [PATCH 129/245] merge search service w/ data Signed-off-by: Jason Volk --- src/service/rooms/search/data.rs | 113 ---------------------------- src/service/rooms/search/mod.rs | 123 ++++++++++++++++++++++++++----- 2 files changed, 103 insertions(+), 133 deletions(-) delete mode 100644 src/service/rooms/search/data.rs diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs deleted file mode 100644 index de98beeeb..000000000 --- a/src/service/rooms/search/data.rs +++ /dev/null @@ -1,113 +0,0 @@ -use std::sync::Arc; - -use conduit::utils::{set, stream::TryIgnore, IterStream, ReadyExt}; -use database::Map; -use futures::StreamExt; -use ruma::RoomId; - -use crate::{rooms, Dep}; - -pub(super) struct Data { - tokenids: Arc, - services: Services, -} - -struct Services { - short: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - tokenids: db["tokenids"].clone(), - services: Services { - short: args.depend::("rooms::short"), - }, - } - } - - pub(super) fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { - let batch = 
tokenize(message_body) - .map(|word| { - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(word.as_bytes()); - key.push(0xFF); - key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here - (key, Vec::::new()) - }) - .collect::>(); - - self.tokenids.insert_batch(batch.iter()); - } - - pub(super) fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { - let batch = tokenize(message_body).map(|word| { - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(word.as_bytes()); - key.push(0xFF); - key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here - key - }); - - for token in batch { - self.tokenids.remove(&token); - } - } - - pub(super) async fn search_pdus( - &self, room_id: &RoomId, search_string: &str, - ) -> Option<(Vec>, Vec)> { - let prefix = self - .services - .short - .get_shortroomid(room_id) - .await - .ok()? - .to_be_bytes() - .to_vec(); - - let words: Vec<_> = tokenize(search_string).collect(); - - let bufs: Vec<_> = words - .clone() - .into_iter() - .stream() - .then(move |word| { - let mut prefix2 = prefix.clone(); - prefix2.extend_from_slice(word.as_bytes()); - prefix2.push(0xFF); - let prefix3 = prefix2.clone(); - - let mut last_possible_id = prefix2.clone(); - last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.tokenids - .rev_raw_keys_from(&last_possible_id) // Newest pdus first - .ignore_err() - .ready_take_while(move |key| key.starts_with(&prefix2)) - .map(move |key| key[prefix3.len()..].to_vec()) - .collect::>() - }) - .collect() - .await; - - Some(( - set::intersection(bufs.iter().map(|buf| buf.iter())) - .cloned() - .collect(), - words, - )) - } -} - -/// Splits a string into tokens used as keys in the search inverted index -/// -/// This may be used to tokenize both message bodies (for indexing) or search -/// queries (for querying). 
-fn tokenize(body: &str) -> impl Iterator + Send + '_ { - body.split_terminator(|c: char| !c.is_alphanumeric()) - .filter(|s| !s.is_empty()) - .filter(|word| word.len() <= 50) - .map(str::to_lowercase) -} diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 80b588044..032ad55cc 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -1,41 +1,124 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use data::Data; +use conduit::{ + implement, + utils::{set, stream::TryIgnore, IterStream, ReadyExt}, + Result, +}; +use database::Map; +use futures::StreamExt; use ruma::RoomId; +use crate::{rooms, Dep}; + pub struct Service { db: Data, + services: Services, +} + +struct Data { + tokenids: Arc, +} + +struct Services { + short: Dep, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + tokenids: args.db["tokenids"].clone(), + }, + services: Services { + short: args.depend::("rooms::short"), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - #[inline] - #[tracing::instrument(skip(self), level = "debug")] - pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { - self.db.index_pdu(shortroomid, pdu_id, message_body); - } +#[implement(Service)] +pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { + let batch = tokenize(message_body) + .map(|word| { + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(word.as_bytes()); + key.push(0xFF); + key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here + (key, Vec::::new()) + }) + .collect::>(); - #[inline] - #[tracing::instrument(skip(self), level = "debug")] - pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { - self.db.deindex_pdu(shortroomid, pdu_id, message_body); - } + 
self.db.tokenids.insert_batch(batch.iter()); +} - #[inline] - #[tracing::instrument(skip(self), level = "debug")] - pub async fn search_pdus(&self, room_id: &RoomId, search_string: &str) -> Option<(Vec>, Vec)> { - self.db.search_pdus(room_id, search_string).await +#[implement(Service)] +pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { + let batch = tokenize(message_body).map(|word| { + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(word.as_bytes()); + key.push(0xFF); + key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here + key + }); + + for token in batch { + self.db.tokenids.remove(&token); } } + +#[implement(Service)] +pub async fn search_pdus(&self, room_id: &RoomId, search_string: &str) -> Option<(Vec>, Vec)> { + let prefix = self + .services + .short + .get_shortroomid(room_id) + .await + .ok()? + .to_be_bytes() + .to_vec(); + + let words: Vec<_> = tokenize(search_string).collect(); + + let bufs: Vec<_> = words + .clone() + .into_iter() + .stream() + .then(move |word| { + let mut prefix2 = prefix.clone(); + prefix2.extend_from_slice(word.as_bytes()); + prefix2.push(0xFF); + let prefix3 = prefix2.clone(); + + let mut last_possible_id = prefix2.clone(); + last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); + + self.db.tokenids + .rev_raw_keys_from(&last_possible_id) // Newest pdus first + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix2)) + .map(move |key| key[prefix3.len()..].to_vec()) + .collect::>() + }) + .collect() + .await; + + let bufs = bufs.iter().map(|buf| buf.iter()); + + let results = set::intersection(bufs).cloned().collect(); + + Some((results, words)) +} + +/// Splits a string into tokens used as keys in the search inverted index +/// +/// This may be used to tokenize both message bodies (for indexing) or search +/// queries (for querying). 
+fn tokenize(body: &str) -> impl Iterator + Send + '_ { + body.split_terminator(|c: char| !c.is_alphanumeric()) + .filter(|s| !s.is_empty()) + .filter(|word| word.len() <= 50) + .map(str::to_lowercase) +} From 0426f92ac032f03b8a4c86acec00b53c093a82d5 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 25 Oct 2024 18:25:06 +0000 Subject: [PATCH 130/245] unify database record separator constants Signed-off-by: Jason Volk --- src/database/de.rs | 3 +-- src/database/mod.rs | 2 +- src/database/ser.rs | 5 ++++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/database/de.rs b/src/database/de.rs index e5fdf7cb2..0e074fdba 100644 --- a/src/database/de.rs +++ b/src/database/de.rs @@ -41,8 +41,7 @@ pub struct Ignore; pub struct IgnoreAll; impl<'de> Deserializer<'de> { - /// Record separator; an intentionally invalid-utf8 byte. - const SEP: u8 = b'\xFF'; + const SEP: u8 = crate::ser::SEP; /// Determine if the input was fully consumed and error if bytes remaining. /// This is intended for debug assertions; not optimized for parsing logic. diff --git a/src/database/mod.rs b/src/database/mod.rs index 6d3b2079b..dcd66a1ee 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -29,7 +29,7 @@ pub use self::{ handle::Handle, keyval::{KeyVal, Slice}, map::Map, - ser::{serialize, serialize_to_array, serialize_to_vec, Interfix, Json, Separator}, + ser::{serialize, serialize_to_array, serialize_to_vec, Interfix, Json, Separator, SEP}, }; conduit::mod_ctor! {} diff --git a/src/database/ser.rs b/src/database/ser.rs index 742f1e345..0cc5c886c 100644 --- a/src/database/ser.rs +++ b/src/database/ser.rs @@ -69,8 +69,11 @@ pub struct Interfix; #[derive(Debug, Serialize)] pub struct Separator; +/// Record separator; an intentionally invalid-utf8 byte. 
+pub const SEP: u8 = b'\xFF'; + impl Serializer<'_, W> { - const SEP: &'static [u8] = b"\xFF"; + const SEP: &'static [u8] = &[SEP]; fn tuple_start(&mut self) { debug_assert!(!self.sep, "Tuple start with separator set"); From 1e7207c23015f5bb8d9c22db30ccbb3669a9540a Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 25 Oct 2024 19:53:08 +0000 Subject: [PATCH 131/245] start an ArrayVec extension trait Signed-off-by: Jason Volk --- src/core/utils/arrayvec.rs | 15 +++++++++++++++ src/core/utils/mod.rs | 2 ++ 2 files changed, 17 insertions(+) create mode 100644 src/core/utils/arrayvec.rs diff --git a/src/core/utils/arrayvec.rs b/src/core/utils/arrayvec.rs new file mode 100644 index 000000000..685aaf18c --- /dev/null +++ b/src/core/utils/arrayvec.rs @@ -0,0 +1,15 @@ +use ::arrayvec::ArrayVec; + +pub trait ArrayVecExt { + fn extend_from_slice(&mut self, other: &[T]) -> &mut Self; +} + +impl ArrayVecExt for ArrayVec { + #[inline] + fn extend_from_slice(&mut self, other: &[T]) -> &mut Self { + self.try_extend_from_slice(other) + .expect("Insufficient buffer capacity to extend from slice"); + + self + } +} diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index 8e29c608f..26b0484e0 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -1,3 +1,4 @@ +pub mod arrayvec; pub mod bool; pub mod bytes; pub mod content_disposition; @@ -22,6 +23,7 @@ pub use ::conduit_macros::implement; pub use ::ctor::{ctor, dtor}; pub use self::{ + arrayvec::ArrayVecExt, bool::BoolExt, bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8}, debug::slice_truncated as debug_slice_truncated, From f245389c0223ed96542969dafc90f0aeab1da9f5 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 26 Oct 2024 22:20:16 +0000 Subject: [PATCH 132/245] add typedef for pdu_ids Signed-off-by: Jason Volk --- src/service/rooms/short/mod.rs | 4 ++++ src/service/rooms/timeline/mod.rs | 6 +++++- src/service/rooms/timeline/pduid.rs | 13 +++++++++++++ 3 files changed, 22 insertions(+), 1 
deletion(-) create mode 100644 src/service/rooms/timeline/pduid.rs diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 609c0e07e..02c449cc3 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -24,6 +24,10 @@ struct Services { globals: Dep, } +pub type ShortEventId = ShortId; +pub type ShortRoomId = ShortId; +pub type ShortId = u64; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 902e50fff..e45bf7e52 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,4 +1,5 @@ mod data; +mod pduid; use std::{ cmp, @@ -38,7 +39,10 @@ use serde::Deserialize; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use self::data::Data; -pub use self::data::PdusIterItem; +pub use self::{ + data::PdusIterItem, + pduid::{PduId, RawPduId}, +}; use crate::{ account_data, admin, appservice, appservice::NamespaceRegex, globals, pusher, rooms, rooms::state_compressor::CompressedStateEvent, sending, server_keys, users, Dep, diff --git a/src/service/rooms/timeline/pduid.rs b/src/service/rooms/timeline/pduid.rs new file mode 100644 index 000000000..b43c382cf --- /dev/null +++ b/src/service/rooms/timeline/pduid.rs @@ -0,0 +1,13 @@ +use crate::rooms::short::{ShortEventId, ShortRoomId}; + +#[derive(Clone, Copy)] +pub struct PduId { + _room_id: ShortRoomId, + _event_id: ShortEventId, +} + +pub type RawPduId = [u8; PduId::LEN]; + +impl PduId { + pub const LEN: usize = size_of::() + size_of::(); +} From 21a67513f2480e6cb1cb0322e15016ba8d919dac Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 26 Oct 2024 22:21:23 +0000 Subject: [PATCH 133/245] refactor search system Signed-off-by: Jason Volk --- Cargo.lock | 1 + src/api/client/search.rs | 333 +++++++++++++++++--------------- src/service/Cargo.toml | 1 + src/service/rooms/search/mod.rs | 178 
+++++++++++++---- 4 files changed, 310 insertions(+), 203 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c64d3cc67..a8acce7d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -786,6 +786,7 @@ dependencies = [ name = "conduit_service" version = "0.5.0" dependencies = [ + "arrayvec", "async-trait", "base64 0.22.1", "bytes", diff --git a/src/api/client/search.rs b/src/api/client/search.rs index b073640e8..1e5384fe2 100644 --- a/src/api/client/search.rs +++ b/src/api/client/search.rs @@ -2,25 +2,32 @@ use std::collections::BTreeMap; use axum::extract::State; use conduit::{ - debug, - utils::{IterStream, ReadyExt}, - Err, + at, is_true, + result::FlatOk, + utils::{stream::ReadyExt, IterStream}, + Err, PduEvent, Result, }; -use futures::{FutureExt, StreamExt}; +use futures::{future::OptionFuture, FutureExt, StreamExt, TryFutureExt}; use ruma::{ - api::client::{ - error::ErrorKind, - search::search_events::{ - self, - v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult}, - }, + api::client::search::search_events::{ + self, + v3::{Criteria, EventContextResult, ResultCategories, ResultRoomEvents, SearchResult}, }, events::AnyStateEvent, serde::Raw, - uint, OwnedRoomId, + OwnedRoomId, RoomId, UInt, UserId, }; +use search_events::v3::{Request, Response}; +use service::{rooms::search::RoomQuery, Services}; + +use crate::Ruma; -use crate::{Error, Result, Ruma}; +type RoomStates = BTreeMap; +type RoomState = Vec>; + +const LIMIT_DEFAULT: usize = 10; +const LIMIT_MAX: usize = 100; +const BATCH_MAX: usize = 20; /// # `POST /_matrix/client/r0/search` /// @@ -28,173 +35,177 @@ use crate::{Error, Result, Ruma}; /// /// - Only works if the user is currently joined to the room (TODO: Respect /// history visibility) -pub(crate) async fn search_events_route( - State(services): State, body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - let search_criteria = 
body.search_categories.room_events.as_ref().unwrap(); - let filter = &search_criteria.filter; - let include_state = &search_criteria.include_state; - - let room_ids = if let Some(room_ids) = &filter.rooms { - room_ids.clone() - } else { - services - .rooms - .state_cache - .rooms_joined(sender_user) - .map(ToOwned::to_owned) - .collect() - .await - }; - - // Use limit or else 10, with maximum 100 - let limit: usize = filter - .limit - .unwrap_or_else(|| uint!(10)) - .try_into() - .unwrap_or(10) - .min(100); +pub(crate) async fn search_events_route(State(services): State, body: Ruma) -> Result { + let sender_user = body.sender_user(); + let next_batch = body.next_batch.as_deref(); + let room_events_result: OptionFuture<_> = body + .search_categories + .room_events + .as_ref() + .map(|criteria| category_room_events(&services, sender_user, next_batch, criteria)) + .into(); + + Ok(Response { + search_categories: ResultCategories { + room_events: room_events_result + .await + .unwrap_or_else(|| Ok(ResultRoomEvents::default()))?, + }, + }) +} - let mut room_states: BTreeMap>> = BTreeMap::new(); +#[allow(clippy::map_unwrap_or)] +async fn category_room_events( + services: &Services, sender_user: &UserId, next_batch: Option<&str>, criteria: &Criteria, +) -> Result { + let filter = &criteria.filter; - if include_state.is_some_and(|include_state| include_state) { - for room_id in &room_ids { - if !services + let limit: usize = filter + .limit + .map(TryInto::try_into) + .flat_ok() + .unwrap_or(LIMIT_DEFAULT) + .min(LIMIT_MAX); + + let next_batch: usize = next_batch + .map(str::parse) + .transpose()? 
+ .unwrap_or(0) + .min(limit.saturating_mul(BATCH_MAX)); + + let rooms = filter + .rooms + .clone() + .map(IntoIterator::into_iter) + .map(IterStream::stream) + .map(StreamExt::boxed) + .unwrap_or_else(|| { + services .rooms .state_cache - .is_joined(sender_user, room_id) + .rooms_joined(sender_user) + .map(ToOwned::to_owned) + .boxed() + }); + + let results: Vec<_> = rooms + .filter_map(|room_id| async move { + check_room_visible(services, sender_user, &room_id, criteria) .await - { - return Err!(Request(Forbidden("You don't have permission to view this room."))); - } - - // check if sender_user can see state events - if services - .rooms - .state_accessor - .user_can_see_state_events(sender_user, room_id) + .is_ok() + .then_some(room_id) + }) + .filter_map(|room_id| async move { + let query = RoomQuery { + room_id: &room_id, + user_id: Some(sender_user), + criteria, + skip: next_batch, + limit, + }; + + let (count, results) = services.rooms.search.search_pdus(&query).await.ok()?; + + results + .collect::>() + .map(|results| (room_id.clone(), count, results)) + .map(Some) .await - { - let room_state: Vec<_> = services - .rooms - .state_accessor - .room_state_full(room_id) - .await? 
- .values() - .map(|pdu| pdu.to_state_event()) - .collect(); - - debug!("Room state: {:?}", room_state); - - room_states.insert(room_id.clone(), room_state); - } else { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); - } - } - } - - let mut search_vecs = Vec::new(); - - for room_id in &room_ids { - if !services - .rooms - .state_cache - .is_joined(sender_user, room_id) - .await - { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); - } - - if let Some(search) = services - .rooms - .search - .search_pdus(room_id, &search_criteria.search_term) - .await - { - search_vecs.push(search.0); - } - } + }) + .collect() + .await; - let mut searches: Vec<_> = search_vecs + let total: UInt = results .iter() - .map(|vec| vec.iter().peekable()) - .collect(); + .fold(0, |a: usize, (_, count, _)| a.saturating_add(*count)) + .try_into()?; - let skip: usize = match body.next_batch.as_ref().map(|s| s.parse()) { - Some(Ok(s)) => s, - Some(Err(_)) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid next_batch token.")), - None => 0, // Default to the start - }; - - let mut results = Vec::new(); - let next_batch = skip.saturating_add(limit); - - for _ in 0..next_batch { - if let Some(s) = searches - .iter_mut() - .map(|s| (s.peek().copied(), s)) - .max_by_key(|(peek, _)| *peek) - .and_then(|(_, i)| i.next()) - { - results.push(s); - } - } - - let results: Vec<_> = results - .into_iter() - .skip(skip) + let state: RoomStates = results + .iter() .stream() - .filter_map(|id| services.rooms.timeline.get_pdu_from_id(id).map(Result::ok)) - .ready_filter(|pdu| !pdu.is_redacted()) - .filter_map(|pdu| async move { - services - .rooms - .state_accessor - .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) + .ready_filter(|_| criteria.include_state.is_some_and(is_true!())) + .filter_map(|(room_id, ..)| async move { + procure_room_state(services, 
room_id) + .map_ok(|state| (room_id.clone(), state)) .await - .then_some(pdu) + .ok() }) - .take(limit) + .collect() + .await; + + let results: Vec = results + .into_iter() + .map(at!(2)) + .flatten() + .stream() .map(|pdu| pdu.to_room_event()) .map(|result| SearchResult { - context: EventContextResult { - end: None, - events_after: Vec::new(), - events_before: Vec::new(), - profile_info: BTreeMap::new(), - start: None, - }, rank: None, result: Some(result), + context: EventContextResult { + profile_info: BTreeMap::new(), //TODO + events_after: Vec::new(), //TODO + events_before: Vec::new(), //TODO + start: None, //TODO + end: None, //TODO + }, }) .collect() - .boxed() .await; - let more_unloaded_results = searches.iter_mut().any(|s| s.peek().is_some()); - - let next_batch = more_unloaded_results.then(|| next_batch.to_string()); - - Ok(search_events::v3::Response::new(ResultCategories { - room_events: ResultRoomEvents { - count: Some(results.len().try_into().unwrap_or_else(|_| uint!(0))), - groups: BTreeMap::new(), // TODO - next_batch, - results, - state: room_states, - highlights: search_criteria - .search_term - .split_terminator(|c: char| !c.is_alphanumeric()) - .map(str::to_lowercase) - .collect(), - }, - })) + let highlights = criteria + .search_term + .split_terminator(|c: char| !c.is_alphanumeric()) + .map(str::to_lowercase) + .collect(); + + let next_batch = (results.len() >= limit) + .then_some(next_batch.saturating_add(results.len())) + .as_ref() + .map(ToString::to_string); + + Ok(ResultRoomEvents { + count: Some(total), + next_batch, + results, + state, + highlights, + groups: BTreeMap::new(), // TODO + }) +} + +async fn procure_room_state(services: &Services, room_id: &RoomId) -> Result { + let state_map = services + .rooms + .state_accessor + .room_state_full(room_id) + .await?; + + let state_events = state_map + .values() + .map(AsRef::as_ref) + .map(PduEvent::to_state_event) + .collect(); + + Ok(state_events) +} + +async fn 
check_room_visible(services: &Services, user_id: &UserId, room_id: &RoomId, search: &Criteria) -> Result { + let check_visible = search.filter.rooms.is_some(); + let check_state = check_visible && search.include_state.is_some_and(is_true!()); + + let is_joined = !check_visible || services.rooms.state_cache.is_joined(user_id, room_id).await; + + let state_visible = !check_state + || services + .rooms + .state_accessor + .user_can_see_state_events(user_id, room_id) + .await; + + if !is_joined || !state_visible { + return Err!(Request(Forbidden("You don't have permission to view {room_id:?}"))); + } + + Ok(()) } diff --git a/src/service/Cargo.toml b/src/service/Cargo.toml index 737a70399..7578ef64f 100644 --- a/src/service/Cargo.toml +++ b/src/service/Cargo.toml @@ -40,6 +40,7 @@ release_max_log_level = [ ] [dependencies] +arrayvec.workspace = true async-trait.workspace = true base64.workspace = true bytes.workspace = true diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 032ad55cc..8882ec994 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -1,15 +1,23 @@ -use std::sync::Arc; +use std::{iter, sync::Arc}; +use arrayvec::ArrayVec; use conduit::{ implement, - utils::{set, stream::TryIgnore, IterStream, ReadyExt}, - Result, + utils::{set, stream::TryIgnore, ArrayVecExt, IterStream, ReadyExt}, + PduEvent, Result, +}; +use database::{keyval::Val, Map}; +use futures::{Stream, StreamExt}; +use ruma::{api::client::search::search_events::v3::Criteria, RoomId, UserId}; + +use crate::{ + rooms, + rooms::{ + short::ShortRoomId, + timeline::{PduId, RawPduId}, + }, + Dep, }; -use database::Map; -use futures::StreamExt; -use ruma::RoomId; - -use crate::{rooms, Dep}; pub struct Service { db: Data, @@ -22,8 +30,24 @@ struct Data { struct Services { short: Dep, + state_accessor: Dep, + timeline: Dep, } +#[derive(Clone, Debug)] +pub struct RoomQuery<'a> { + pub room_id: &'a RoomId, + pub user_id: Option<&'a UserId>, 
+ pub criteria: &'a Criteria, + pub limit: usize, + pub skip: usize, +} + +type TokenId = ArrayVec; + +const TOKEN_ID_MAX_LEN: usize = size_of::() + WORD_MAX_LEN + 1 + size_of::(); +const WORD_MAX_LEN: usize = 50; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { @@ -32,6 +56,8 @@ impl crate::Service for Service { }, services: Services { short: args.depend::("rooms::short"), + state_accessor: args.depend::("rooms::state_accessor"), + timeline: args.depend::("rooms::timeline"), }, })) } @@ -70,46 +96,92 @@ pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { } #[implement(Service)] -pub async fn search_pdus(&self, room_id: &RoomId, search_string: &str) -> Option<(Vec>, Vec)> { - let prefix = self - .services - .short - .get_shortroomid(room_id) - .await - .ok()? - .to_be_bytes() - .to_vec(); +pub async fn search_pdus<'a>( + &'a self, query: &'a RoomQuery<'a>, +) -> Result<(usize, impl Stream + Send + 'a)> { + let pdu_ids: Vec<_> = self.search_pdu_ids(query).await?.collect().await; - let words: Vec<_> = tokenize(search_string).collect(); - - let bufs: Vec<_> = words - .clone() + let count = pdu_ids.len(); + let pdus = pdu_ids .into_iter() .stream() - .then(move |word| { - let mut prefix2 = prefix.clone(); - prefix2.extend_from_slice(word.as_bytes()); - prefix2.push(0xFF); - let prefix3 = prefix2.clone(); - - let mut last_possible_id = prefix2.clone(); - last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.db.tokenids - .rev_raw_keys_from(&last_possible_id) // Newest pdus first - .ignore_err() - .ready_take_while(move |key| key.starts_with(&prefix2)) - .map(move |key| key[prefix3.len()..].to_vec()) - .collect::>() + .filter_map(move |result_pdu_id: RawPduId| async move { + self.services + .timeline + .get_pdu_from_id(&result_pdu_id) + .await + .ok() }) - .collect() - .await; + .ready_filter(|pdu| !pdu.is_redacted()) + .filter_map(move |pdu| async move { + self.services + 
.state_accessor + .user_can_see_event(query.user_id?, &pdu.room_id, &pdu.event_id) + .await + .then_some(pdu) + }) + .skip(query.skip) + .take(query.limit); + + Ok((count, pdus)) +} + +// result is modeled as a stream such that callers don't have to be refactored +// though an additional async/wrap still exists for now +#[implement(Service)] +pub async fn search_pdu_ids(&self, query: &RoomQuery<'_>) -> Result + Send + '_> { + let shortroomid = self.services.short.get_shortroomid(query.room_id).await?; - let bufs = bufs.iter().map(|buf| buf.iter()); + let pdu_ids = self.search_pdu_ids_query_room(query, shortroomid).await; - let results = set::intersection(bufs).cloned().collect(); + let iters = pdu_ids.into_iter().map(IntoIterator::into_iter); - Some((results, words)) + Ok(set::intersection(iters).stream()) +} + +#[implement(Service)] +async fn search_pdu_ids_query_room(&self, query: &RoomQuery<'_>, shortroomid: ShortRoomId) -> Vec> { + tokenize(&query.criteria.search_term) + .stream() + .then(|word| async move { + self.search_pdu_ids_query_words(shortroomid, &word) + .collect::>() + .await + }) + .collect::>() + .await +} + +/// Iterate over PduId's containing a word +#[implement(Service)] +fn search_pdu_ids_query_words<'a>( + &'a self, shortroomid: ShortRoomId, word: &'a str, +) -> impl Stream + Send + '_ { + self.search_pdu_ids_query_word(shortroomid, word) + .ready_filter_map(move |key| { + key[prefix_len(word)..] 
+ .chunks_exact(PduId::LEN) + .next() + .map(RawPduId::try_from) + .and_then(Result::ok) + }) +} + +/// Iterate over raw database results for a word +#[implement(Service)] +fn search_pdu_ids_query_word(&self, shortroomid: ShortRoomId, word: &str) -> impl Stream> + Send + '_ { + const PDUID_LEN: usize = PduId::LEN; + // rustc says const'ing this not yet stable + let end_id: ArrayVec = iter::repeat(u8::MAX).take(PduId::LEN).collect(); + + // Newest pdus first + let end = make_tokenid(shortroomid, word, end_id.as_slice()); + let prefix = make_prefix(shortroomid, word); + self.db + .tokenids + .rev_raw_keys_from(&end) + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix)) } /// Splits a string into tokens used as keys in the search inverted index @@ -119,6 +191,28 @@ pub async fn search_pdus(&self, room_id: &RoomId, search_string: &str) -> Option fn tokenize(body: &str) -> impl Iterator + Send + '_ { body.split_terminator(|c: char| !c.is_alphanumeric()) .filter(|s| !s.is_empty()) - .filter(|word| word.len() <= 50) + .filter(|word| word.len() <= WORD_MAX_LEN) .map(str::to_lowercase) } + +fn make_tokenid(shortroomid: ShortRoomId, word: &str, pdu_id: &[u8]) -> TokenId { + debug_assert!(pdu_id.len() == PduId::LEN, "pdu_id size mismatch"); + + let mut key = make_prefix(shortroomid, word); + key.extend_from_slice(pdu_id); + key +} + +fn make_prefix(shortroomid: ShortRoomId, word: &str) -> TokenId { + let mut key = TokenId::new(); + key.extend_from_slice(&shortroomid.to_be_bytes()); + key.extend_from_slice(word.as_bytes()); + key.push(database::SEP); + key +} + +fn prefix_len(word: &str) -> usize { + size_of::() + .saturating_add(word.len()) + .saturating_add(1) +} From d281b8d3ae1818ea84be11ba38ac0325aaa84ffc Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 26 Oct 2024 22:22:10 +0000 Subject: [PATCH 134/245] implement filters for search (#596) closes #596 Signed-off-by: Jason Volk --- src/service/rooms/search/mod.rs | 1 + 1 file changed, 1 insertion(+) 
diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 8882ec994..70daded1e 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -113,6 +113,7 @@ pub async fn search_pdus<'a>( .ok() }) .ready_filter(|pdu| !pdu.is_redacted()) + .ready_filter(|pdu| pdu.matches(&query.criteria.filter)) .filter_map(move |pdu| async move { self.services .state_accessor From 5e6dbaa27f5e08556422ee6b756efdc318654fd7 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 27 Oct 2024 01:48:57 +0000 Subject: [PATCH 135/245] apply room event filter to messages endpoint (#596) Signed-off-by: Jason Volk --- src/api/client/message.rs | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/src/api/client/message.rs b/src/api/client/message.rs index 578b675b5..094daa306 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -9,13 +9,13 @@ use conduit::{ use futures::{FutureExt, StreamExt}; use ruma::{ api::client::{ - filter::{RoomEventFilter, UrlFilter}, + filter::RoomEventFilter, message::{get_message_events, send_message_event}, }, events::{MessageLikeEventType, StateEventType, TimelineEventType::*}, UserId, }; -use serde_json::{from_str, Value}; +use serde_json::from_str; use service::rooms::timeline::PdusIterItem; use crate::{ @@ -151,7 +151,7 @@ pub(crate) async fn get_message_events_route( .timeline .pdus_after(sender_user, room_id, from) .await? - .ready_filter_map(|item| contains_url_filter(item, filter)) + .ready_filter_map(|item| event_filter(item, filter)) .filter_map(|item| visibility_filter(&services, item, sender_user)) .ready_take_while(|(count, _)| Some(*count) != to) // Stop at `to` .take(limit) @@ -225,7 +225,7 @@ pub(crate) async fn get_message_events_route( .timeline .pdus_until(sender_user, room_id, from) .await? 
- .ready_filter_map(|item| contains_url_filter(item, filter)) + .ready_filter_map(|item| event_filter(item, filter)) .filter_map(|(count, pdu)| async move { // list of safe and common non-state events to ignore if matches!( @@ -329,19 +329,7 @@ async fn visibility_filter(services: &Services, item: PdusIterItem, user_id: &Us .then_some(item) } -fn contains_url_filter(item: PdusIterItem, filter: &RoomEventFilter) -> Option { +fn event_filter(item: PdusIterItem, filter: &RoomEventFilter) -> Option { let (_, pdu) = &item; - - if filter.url_filter.is_none() { - return Some(item); - } - - let content: Value = from_str(pdu.content.get()).unwrap(); - let res = match filter.url_filter { - Some(UrlFilter::EventsWithoutUrl) => !content["url"].is_string(), - Some(UrlFilter::EventsWithUrl) => content["url"].is_string(), - None => true, - }; - - res.then_some(item) + pdu.matches(filter).then_some(item) } From 9787dfe77c7de1ef186c8bb934ba242f856ccc12 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 27 Oct 2024 00:30:30 +0000 Subject: [PATCH 136/245] fix clippy::ref_option fix needless borrow fix clippy::nonminimal_bool --- Cargo.toml | 1 + src/api/client/config.rs | 6 +++--- src/api/client/report.rs | 4 ++-- src/api/client/room.rs | 12 ++++++------ src/api/client/sync/v4.rs | 2 +- src/core/utils/future/try_ext_ext.rs | 1 + src/core/utils/stream/ready.rs | 1 + src/core/utils/stream/try_ready.rs | 1 + src/macros/config.rs | 18 ++++-------------- src/service/admin/mod.rs | 4 ++-- src/service/services.rs | 4 ++-- 11 files changed, 24 insertions(+), 30 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 73f16daf0..2f9f196b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -772,6 +772,7 @@ unused-qualifications = "warn" #unused-results = "warn" # TODO ## some sadness +elided_named_lifetimes = "allow" # TODO! let_underscore_drop = "allow" missing_docs = "allow" # cfgs cannot be limited to expected cfgs or their de facto non-transitive/opt-in use-case e.g. 
diff --git a/src/api/client/config.rs b/src/api/client/config.rs index d06cc0729..3cf711353 100644 --- a/src/api/client/config.rs +++ b/src/api/client/config.rs @@ -23,7 +23,7 @@ pub(crate) async fn set_global_account_data_route( set_account_data( &services, None, - &body.sender_user, + body.sender_user.as_ref(), &body.event_type.to_string(), body.data.json(), ) @@ -41,7 +41,7 @@ pub(crate) async fn set_room_account_data_route( set_account_data( &services, Some(&body.room_id), - &body.sender_user, + body.sender_user.as_ref(), &body.event_type.to_string(), body.data.json(), ) @@ -89,7 +89,7 @@ pub(crate) async fn get_room_account_data_route( } async fn set_account_data( - services: &Services, room_id: Option<&RoomId>, sender_user: &Option, event_type: &str, + services: &Services, room_id: Option<&RoomId>, sender_user: Option<&OwnedUserId>, event_type: &str, data: &RawJsonValue, ) -> Result<()> { let sender_user = sender_user.as_ref().expect("user is authenticated"); diff --git a/src/api/client/report.rs b/src/api/client/report.rs index cf789246a..143c13e56 100644 --- a/src/api/client/report.rs +++ b/src/api/client/report.rs @@ -101,7 +101,7 @@ pub(crate) async fn report_event_route( &pdu.event_id, &body.room_id, sender_user, - &body.reason, + body.reason.as_ref(), body.score, &pdu, ) @@ -134,7 +134,7 @@ pub(crate) async fn report_event_route( /// check if report reasoning is less than or equal to 750 characters /// check if reporting user is in the reporting room async fn is_event_report_valid( - services: &Services, event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: &Option, + services: &Services, event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: Option<&String>, score: Option, pdu: &std::sync::Arc, ) -> Result<()> { debug_info!("Checking if report from user {sender_user} for event {event_id} in room {room_id} is valid"); diff --git a/src/api/client/room.rs b/src/api/client/room.rs index daadb7242..4224d3fa7 100644 --- 
a/src/api/client/room.rs +++ b/src/api/client/room.rs @@ -126,8 +126,8 @@ pub(crate) async fn create_room_route( .await; let state_lock = services.rooms.state.mutex.lock(&room_id).await; - let alias: Option = if let Some(alias) = &body.room_alias_name { - Some(room_alias_check(&services, alias, &body.appservice_info).await?) + let alias: Option = if let Some(alias) = body.room_alias_name.as_ref() { + Some(room_alias_check(&services, alias, body.appservice_info.as_ref()).await?) } else { None }; @@ -270,7 +270,7 @@ pub(crate) async fn create_room_route( } let power_levels_content = - default_power_levels_content(&body.power_level_content_override, &body.visibility, users)?; + default_power_levels_content(body.power_level_content_override.as_ref(), &body.visibility, users)?; services .rooms @@ -814,7 +814,7 @@ pub(crate) async fn upgrade_room_route( /// creates the power_levels_content for the PDU builder fn default_power_levels_content( - power_level_content_override: &Option>, visibility: &room::Visibility, + power_level_content_override: Option<&Raw>, visibility: &room::Visibility, users: BTreeMap, ) -> Result { let mut power_levels_content = serde_json::to_value(RoomPowerLevelsEventContent { @@ -864,7 +864,7 @@ fn default_power_levels_content( /// if a room is being created with a room alias, run our checks async fn room_alias_check( - services: &Services, room_alias_name: &str, appservice_info: &Option, + services: &Services, room_alias_name: &str, appservice_info: Option<&RegistrationInfo>, ) -> Result { // Basic checks on the room alias validity if room_alias_name.contains(':') { @@ -905,7 +905,7 @@ async fn room_alias_check( return Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists.")); } - if let Some(ref info) = appservice_info { + if let Some(info) = appservice_info { if !info.aliases.is_match(full_room_alias.as_str()) { return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias is not in namespace.")); } diff --git 
a/src/api/client/sync/v4.rs b/src/api/client/sync/v4.rs index 4f8323e66..f8ada81c9 100644 --- a/src/api/client/sync/v4.rs +++ b/src/api/client/sync/v4.rs @@ -560,7 +560,7 @@ pub(crate) async fn sync_events_v4_route( for (_, pdu) in timeline_pdus { let ts = MilliSecondsSinceUnixEpoch(pdu.origin_server_ts); - if DEFAULT_BUMP_TYPES.contains(pdu.event_type()) && !timestamp.is_some_and(|time| time > ts) { + if DEFAULT_BUMP_TYPES.contains(pdu.event_type()) && timestamp.is_none_or(|time| time <= ts) { timestamp = Some(ts); } } diff --git a/src/core/utils/future/try_ext_ext.rs b/src/core/utils/future/try_ext_ext.rs index 7c0b36a28..f97ae8852 100644 --- a/src/core/utils/future/try_ext_ext.rs +++ b/src/core/utils/future/try_ext_ext.rs @@ -1,4 +1,5 @@ //! Extended external extensions to futures::TryFutureExt +#![allow(clippy::type_complexity)] use futures::{ future::{MapOkOrElse, UnwrapOrElse}, diff --git a/src/core/utils/stream/ready.rs b/src/core/utils/stream/ready.rs index da5aec5a6..c16d12465 100644 --- a/src/core/utils/stream/ready.rs +++ b/src/core/utils/stream/ready.rs @@ -1,4 +1,5 @@ //! Synchronous combinator extensions to futures::Stream +#![allow(clippy::type_complexity)] use futures::{ future::{ready, Ready}, diff --git a/src/core/utils/stream/try_ready.rs b/src/core/utils/stream/try_ready.rs index df3564565..feb380675 100644 --- a/src/core/utils/stream/try_ready.rs +++ b/src/core/utils/stream/try_ready.rs @@ -1,4 +1,5 @@ //! 
Synchronous combinator extensions to futures::TryStream +#![allow(clippy::type_complexity)] use futures::{ future::{ready, Ready}, diff --git a/src/macros/config.rs b/src/macros/config.rs index f86163520..6ccdb73cd 100644 --- a/src/macros/config.rs +++ b/src/macros/config.rs @@ -164,11 +164,11 @@ fn get_default(field: &Field) -> Option { continue; }; - if !path + if path .segments .iter() .next() - .is_some_and(|s| s.ident == "serde") + .is_none_or(|s| s.ident == "serde") { continue; } @@ -218,12 +218,7 @@ fn get_doc_default(field: &Field) -> Option { continue; }; - if !path - .segments - .iter() - .next() - .is_some_and(|s| s.ident == "doc") - { + if path.segments.iter().next().is_none_or(|s| s.ident == "doc") { continue; } @@ -266,12 +261,7 @@ fn get_doc_comment(field: &Field) -> Option { continue; }; - if !path - .segments - .iter() - .next() - .is_some_and(|s| s.ident == "doc") - { + if path.segments.iter().next().is_none_or(|s| s.ident == "doc") { continue; } diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 58cc012c2..2860bd1bb 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -370,9 +370,9 @@ impl Service { /// Sets the self-reference to crate::Services which will provide context to /// the admin commands. 
- pub(super) fn set_services(&self, services: &Option>) { + pub(super) fn set_services(&self, services: Option<&Arc>) { let receiver = &mut *self.services.services.write().expect("locked for writing"); - let weak = services.as_ref().map(Arc::downgrade); + let weak = services.map(Arc::downgrade); *receiver = weak; } } diff --git a/src/service/services.rs b/src/service/services.rs index ea81f434f..c0af42499 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -113,7 +113,7 @@ impl Services { pub async fn start(self: &Arc) -> Result> { debug_info!("Starting services..."); - self.admin.set_services(&Some(Arc::clone(self))); + self.admin.set_services(Some(Arc::clone(self)).as_ref()); globals::migrations::migrations(self).await?; self.manager .lock() @@ -151,7 +151,7 @@ impl Services { manager.stop().await; } - self.admin.set_services(&None); + self.admin.set_services(None); debug_info!("Services shutdown complete."); } From e7e606300f33410bfb6bfdf7c9671b210e37f287 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 27 Oct 2024 19:17:41 +0000 Subject: [PATCH 137/245] slightly simplify reqwest/hickory hooks Signed-off-by: Jason Volk --- src/service/resolver/dns.rs | 59 ++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 33 deletions(-) diff --git a/src/service/resolver/dns.rs b/src/service/resolver/dns.rs index b77bbb84f..89129e03e 100644 --- a/src/service/resolver/dns.rs +++ b/src/service/resolver/dns.rs @@ -1,15 +1,11 @@ -use std::{ - future, iter, - net::{IpAddr, SocketAddr}, - sync::Arc, - time::Duration, -}; +use std::{iter, net::SocketAddr, sync::Arc, time::Duration}; use conduit::{err, Result, Server}; +use futures::FutureExt; use hickory_resolver::TokioAsyncResolver; use reqwest::dns::{Addrs, Name, Resolve, Resolving}; -use super::cache::Cache; +use super::cache::{Cache, CachedOverride}; pub struct Resolver { pub(crate) resolver: Arc, @@ -21,6 +17,8 @@ pub(crate) struct Hooked { cache: Arc, } +type ResolvingResult = Result>; + 
impl Resolver { #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] pub(super) fn build(server: &Arc, cache: Arc) -> Result> { @@ -82,12 +80,12 @@ impl Resolver { } impl Resolve for Resolver { - fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) } + fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name).boxed() } } impl Resolve for Hooked { fn resolve(&self, name: Name) -> Resolving { - let cached = self + let cached: Option = self .cache .overrides .read() @@ -95,35 +93,30 @@ impl Resolve for Hooked { .get(name.as_str()) .cloned(); - if let Some(cached) = cached { - cached_to_reqwest(&cached.ips, cached.port) - } else { - resolve_to_reqwest(self.resolver.clone(), name) - } + cached.map_or_else( + || resolve_to_reqwest(self.resolver.clone(), name).boxed(), + |cached| cached_to_reqwest(cached).boxed(), + ) } } -fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving { - override_name +async fn cached_to_reqwest(cached: CachedOverride) -> ResolvingResult { + let first_ip = cached + .ips .first() - .map(|first_name| -> Resolving { - let saddr = SocketAddr::new(*first_name, port); - let result: Box + Send> = Box::new(iter::once(saddr)); - Box::pin(future::ready(Ok(result))) - }) - .expect("must provide at least one override name") -} + .expect("must provide at least one override"); -fn resolve_to_reqwest(resolver: Arc, name: Name) -> Resolving { - Box::pin(async move { - let results = resolver - .lookup_ip(name.as_str()) - .await? - .into_iter() - .map(|ip| SocketAddr::new(ip, 0)); + let saddr = SocketAddr::new(*first_ip, cached.port); + + Ok(Box::new(iter::once(saddr))) +} - let results: Addrs = Box::new(results); +async fn resolve_to_reqwest(resolver: Arc, name: Name) -> ResolvingResult { + let results = resolver + .lookup_ip(name.as_str()) + .await? 
+ .into_iter() + .map(|ip| SocketAddr::new(ip, 0)); - Ok(results) - }) + Ok(Box::new(results)) } From 6c9ecb031a62db0c589f383f7effe01ea30f38ce Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 27 Oct 2024 20:53:22 +0000 Subject: [PATCH 138/245] re-export ruma Event trait through core pdu Signed-off-by: Jason Volk --- src/core/mod.rs | 2 +- src/core/pdu/{state_res.rs => event.rs} | 5 +++-- src/core/pdu/mod.rs | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) rename src/core/pdu/{state_res.rs => event.rs} (85%) diff --git a/src/core/mod.rs b/src/core/mod.rs index d201709bd..1b7b8fa13 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -17,7 +17,7 @@ pub use ::tracing; pub use config::Config; pub use error::Error; pub use info::{rustc_flags_capture, version, version::version}; -pub use pdu::{PduBuilder, PduCount, PduEvent}; +pub use pdu::{Event, PduBuilder, PduCount, PduEvent}; pub use server::Server; pub use utils::{ctor, dtor, implement, result, result::Result}; diff --git a/src/core/pdu/state_res.rs b/src/core/pdu/event.rs similarity index 85% rename from src/core/pdu/state_res.rs rename to src/core/pdu/event.rs index a27c98229..15117f925 100644 --- a/src/core/pdu/state_res.rs +++ b/src/core/pdu/event.rs @@ -1,11 +1,12 @@ use std::sync::Arc; -use ruma::{events::TimelineEventType, state_res, EventId, MilliSecondsSinceUnixEpoch, RoomId, UserId}; +pub use ruma::state_res::Event; +use ruma::{events::TimelineEventType, EventId, MilliSecondsSinceUnixEpoch, RoomId, UserId}; use serde_json::value::RawValue as RawJsonValue; use super::PduEvent; -impl state_res::Event for PduEvent { +impl Event for PduEvent { type Id = Arc; fn event_id(&self) -> &Self::Id { &self.event_id } diff --git a/src/core/pdu/mod.rs b/src/core/pdu/mod.rs index ed11adbb2..9c3aaf9b6 100644 --- a/src/core/pdu/mod.rs +++ b/src/core/pdu/mod.rs @@ -1,10 +1,10 @@ mod builder; mod content; mod count; +mod event; mod filter; mod id; mod redact; -mod state_res; mod strip; mod unsigned; @@ -19,6 
+19,7 @@ use serde_json::value::RawValue as RawJsonValue; pub use self::{ builder::{Builder, Builder as PduBuilder}, count::PduCount, + event::Event, id::*, }; use crate::Result; From 7a09ac81e039a5ac1dc6d7e215824599b00aed36 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 27 Oct 2024 20:13:10 +0000 Subject: [PATCH 139/245] split send from messages; refactor client/messages; add filters to client/context Signed-off-by: Jason Volk --- src/api/client/context.rs | 264 +++++++++------------- src/api/client/message.rs | 459 +++++++++++++++----------------------- src/api/client/mod.rs | 2 + src/api/client/send.rs | 92 ++++++++ src/api/router/args.rs | 26 ++- 5 files changed, 404 insertions(+), 439 deletions(-) create mode 100644 src/api/client/send.rs diff --git a/src/api/client/context.rs b/src/api/client/context.rs index 9a5c4e826..9bf0c4670 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -1,14 +1,25 @@ -use std::collections::HashSet; +use std::iter::once; use axum::extract::State; -use conduit::{err, error, Err}; -use futures::StreamExt; +use conduit::{ + err, error, + utils::{future::TryExtExt, stream::ReadyExt, IterStream}, + Err, Result, +}; +use futures::{future::try_join, StreamExt, TryFutureExt}; use ruma::{ api::client::{context::get_context, filter::LazyLoadOptions}, - events::{StateEventType, TimelineEventType::*}, + events::StateEventType, + UserId, +}; + +use crate::{ + client::message::{event_filter, ignored_filter, update_lazy, visibility_filter, LazySet}, + Ruma, }; -use crate::{Result, Ruma}; +const LIMIT_MAX: usize = 100; +const LIMIT_DEFAULT: usize = 10; /// # `GET /_matrix/client/r0/rooms/{roomId}/context/{eventId}` /// @@ -19,33 +30,43 @@ use crate::{Result, Ruma}; pub(crate) async fn get_context_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let 
filter = &body.filter; + let sender = body.sender(); + let (sender_user, _) = sender; + + // Use limit or else 10, with maximum 100 + let limit: usize = body + .limit + .try_into() + .unwrap_or(LIMIT_DEFAULT) + .min(LIMIT_MAX); // some clients, at least element, seem to require knowledge of redundant // members for "inline" profiles on the timeline to work properly - let (lazy_load_enabled, lazy_load_send_redundant) = match &body.filter.lazy_load_options { - LazyLoadOptions::Enabled { - include_redundant_members, - } => (true, *include_redundant_members), - LazyLoadOptions::Disabled => (false, cfg!(feature = "element_hacks")), - }; + let lazy_load_enabled = matches!(filter.lazy_load_options, LazyLoadOptions::Enabled { .. }); - let mut lazy_loaded = HashSet::with_capacity(100); + let lazy_load_redundant = if let LazyLoadOptions::Enabled { + include_redundant_members, + } = filter.lazy_load_options + { + include_redundant_members + } else { + false + }; let base_token = services .rooms .timeline .get_pdu_count(&body.event_id) - .await - .map_err(|_| err!(Request(NotFound("Base event id not found."))))?; + .map_err(|_| err!(Request(NotFound("Event not found.")))); let base_event = services .rooms .timeline .get_pdu(&body.event_id) - .await - .map_err(|_| err!(Request(NotFound("Base event not found."))))?; + .map_err(|_| err!(Request(NotFound("Base event not found.")))); + + let (base_token, base_event) = try_join(base_token, base_event).await?; let room_id = &base_event.room_id; @@ -58,136 +79,50 @@ pub(crate) async fn get_context_route( return Err!(Request(Forbidden("You don't have permission to view this event."))); } - if !services - .rooms - .lazy_loading - .lazy_load_was_sent_before(sender_user, sender_device, room_id, &base_event.sender) - .await || lazy_load_send_redundant - { - lazy_loaded.insert(base_event.sender.as_str().to_owned()); - } - - // Use limit or else 10, with maximum 100 - let limit = usize::try_from(body.limit).unwrap_or(10).min(100); - - let 
base_event = base_event.to_room_event(); - let events_before: Vec<_> = services .rooms .timeline .pdus_until(sender_user, room_id, base_token) .await? + .ready_filter_map(|item| event_filter(item, filter)) + .filter_map(|item| ignored_filter(&services, item, sender_user)) + .filter_map(|item| visibility_filter(&services, item, sender_user)) .take(limit / 2) - .filter_map(|(count, pdu)| async move { - // list of safe and common non-state events to ignore - if matches!( - &pdu.kind, - RoomMessage - | Sticker | CallInvite - | CallNotify | RoomEncrypted - | Image | File | Audio - | Voice | Video | UnstablePollStart - | PollStart | KeyVerificationStart - | Reaction | Emote - | Location - ) && services - .users - .user_is_ignored(&pdu.sender, sender_user) - .await - { - return None; - } - - services - .rooms - .state_accessor - .user_can_see_event(sender_user, room_id, &pdu.event_id) - .await - .then_some((count, pdu)) - }) .collect() .await; - for (_, event) in &events_before { - if !services - .rooms - .lazy_loading - .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) - .await || lazy_load_send_redundant - { - lazy_loaded.insert(event.sender.as_str().to_owned()); - } - } - - let start_token = events_before - .last() - .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()); - let events_after: Vec<_> = services .rooms .timeline .pdus_after(sender_user, room_id, base_token) .await? 
+ .ready_filter_map(|item| event_filter(item, filter)) + .filter_map(|item| ignored_filter(&services, item, sender_user)) + .filter_map(|item| visibility_filter(&services, item, sender_user)) .take(limit / 2) - .filter_map(|(count, pdu)| async move { - // list of safe and common non-state events to ignore - if matches!( - &pdu.kind, - RoomMessage - | Sticker | CallInvite - | CallNotify | RoomEncrypted - | Image | File | Audio - | Voice | Video | UnstablePollStart - | PollStart | KeyVerificationStart - | Reaction | Emote - | Location - ) && services - .users - .user_is_ignored(&pdu.sender, sender_user) - .await - { - return None; - } + .collect() + .await; - services - .rooms - .state_accessor - .user_can_see_event(sender_user, room_id, &pdu.event_id) - .await - .then_some((count, pdu)) + let lazy = once(&(base_token, (*base_event).clone())) + .chain(events_before.iter()) + .chain(events_after.iter()) + .stream() + .fold(LazySet::new(), |lazy, item| { + update_lazy(&services, room_id, sender, lazy, item, lazy_load_redundant) }) - .collect() .await; - for (_, event) in &events_after { - if !services - .rooms - .lazy_loading - .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) - .await || lazy_load_send_redundant - { - lazy_loaded.insert(event.sender.as_str().to_owned()); - } - } + let state_id = events_after + .last() + .map_or(body.event_id.as_ref(), |(_, e)| e.event_id.as_ref()); let shortstatehash = services .rooms .state_accessor - .pdu_shortstatehash( - events_after - .last() - .map_or(&*body.event_id, |(_, e)| &*e.event_id), - ) + .pdu_shortstatehash(state_id) + .or_else(|_| services.rooms.state.get_room_shortstatehash(room_id)) .await - .map_or( - services - .rooms - .state - .get_room_shortstatehash(room_id) - .await - .expect("All rooms have state"), - |hash| hash, - ); + .map_err(|e| err!(Database("State hash not found: {e}")))?; let state_ids = services .rooms @@ -196,48 +131,61 @@ pub(crate) async fn get_context_route( .await 
.map_err(|e| err!(Database("State not found: {e}")))?; - let end_token = events_after - .last() - .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()); - - let mut state = Vec::with_capacity(state_ids.len()); - - for (shortstatekey, id) in state_ids { - let (event_type, state_key) = services - .rooms - .short - .get_statekey_from_short(shortstatekey) - .await?; - - if event_type != StateEventType::RoomMember { - let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { - error!("Pdu in state not found: {id}"); - continue; - }; - - state.push(pdu.to_state_event()); - } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { - let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { - error!("Pdu in state not found: {id}"); - continue; - }; - - state.push(pdu.to_state_event()); - } - } + let lazy = &lazy; + let state: Vec<_> = state_ids + .iter() + .stream() + .filter_map(|(shortstatekey, event_id)| { + services + .rooms + .short + .get_statekey_from_short(*shortstatekey) + .map_ok(move |(event_type, state_key)| (event_type, state_key, event_id)) + .ok() + }) + .filter_map(|(event_type, state_key, event_id)| async move { + if lazy_load_enabled && event_type == StateEventType::RoomMember { + let user_id: &UserId = state_key.as_str().try_into().ok()?; + if !lazy.contains(user_id) { + return None; + } + } + + services + .rooms + .timeline + .get_pdu(event_id) + .await + .inspect_err(|_| error!("Pdu in state not found: {event_id}")) + .map(|pdu| pdu.to_state_event()) + .ok() + }) + .collect() + .await; Ok(get_context::v3::Response { - start: Some(start_token), - end: Some(end_token), + event: Some(base_event.to_room_event()), + + start: events_before + .last() + .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()) + .into(), + + end: events_after + .last() + .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()) + .into(), + events_before: events_before - .iter() + .into_iter() .map(|(_, pdu)| 
pdu.to_room_event()) .collect(), - event: Some(base_event), + events_after: events_after - .iter() + .into_iter() .map(|(_, pdu)| pdu.to_room_event()) .collect(), + state, }) } diff --git a/src/api/client/message.rs b/src/api/client/message.rs index 094daa306..4fc58d9f6 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -1,111 +1,52 @@ -use std::collections::{BTreeMap, HashSet}; +use std::collections::HashSet; use axum::extract::State; use conduit::{ - err, - utils::{IterStream, ReadyExt}, - Err, PduCount, + at, is_equal_to, + utils::{ + result::{FlatOk, LogErr}, + IterStream, ReadyExt, + }, + Event, PduCount, Result, }; use futures::{FutureExt, StreamExt}; use ruma::{ - api::client::{ - filter::RoomEventFilter, - message::{get_message_events, send_message_event}, + api::{ + client::{filter::RoomEventFilter, message::get_message_events}, + Direction, }, - events::{MessageLikeEventType, StateEventType, TimelineEventType::*}, - UserId, -}; -use serde_json::from_str; -use service::rooms::timeline::PdusIterItem; - -use crate::{ - service::{pdu::PduBuilder, Services}, - utils, Result, Ruma, + events::{AnyStateEvent, StateEventType, TimelineEventType, TimelineEventType::*}, + serde::Raw, + DeviceId, OwnedUserId, RoomId, UserId, }; - -/// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}` -/// -/// Send a message event into the room. 
-/// -/// - Is a NOOP if the txn id was already used before and returns the same event -/// id again -/// - The only requirement for the content is that it has to be valid json -/// - Tries to send the event into the room, auth rules will determine if it is -/// allowed -pub(crate) async fn send_message_event_route( - State(services): State, body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - let sender_device = body.sender_device.as_deref(); - let appservice_info = body.appservice_info.as_ref(); - - // Forbid m.room.encrypted if encryption is disabled - if MessageLikeEventType::RoomEncrypted == body.event_type && !services.globals.allow_encryption() { - return Err!(Request(Forbidden("Encryption has been disabled"))); - } - - let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; - - if body.event_type == MessageLikeEventType::CallInvite - && services.rooms.directory.is_public_room(&body.room_id).await - { - return Err!(Request(Forbidden("Room call invites are not allowed in public rooms"))); - } - - // Check if this is a new transaction id - if let Ok(response) = services - .transaction_ids - .existing_txnid(sender_user, sender_device, &body.txn_id) - .await - { - // The client might have sent a txnid of the /sendToDevice endpoint - // This txnid has no response associated with it - if response.is_empty() { - return Err!(Request(InvalidParam( - "Tried to use txn id already used for an incompatible endpoint." 
- ))); - } - - return Ok(send_message_event::v3::Response { - event_id: utils::string_from_bytes(&response) - .map(TryInto::try_into) - .map_err(|e| err!(Database("Invalid event_id in txnid data: {e:?}")))??, - }); - } - - let mut unsigned = BTreeMap::new(); - unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); - - let content = - from_str(body.body.body.json().get()).map_err(|e| err!(Request(BadJson("Invalid JSON body: {e}"))))?; - - let event_id = services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: body.event_type.clone().into(), - content, - unsigned: Some(unsigned), - timestamp: appservice_info.and(body.timestamp), - ..Default::default() - }, - sender_user, - &body.room_id, - &state_lock, - ) - .await?; - - services - .transaction_ids - .add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes()); - - drop(state_lock); - - Ok(send_message_event::v3::Response { - event_id: event_id.into(), - }) -} +use service::{rooms::timeline::PdusIterItem, Services}; + +use crate::Ruma; + +pub(crate) type LazySet = HashSet; + +/// list of safe and common non-state events to ignore +const IGNORED_MESSAGE_TYPES: &[TimelineEventType] = &[ + RoomMessage, + Sticker, + CallInvite, + CallNotify, + RoomEncrypted, + Image, + File, + Audio, + Voice, + Video, + UnstablePollStart, + PollStart, + KeyVerificationStart, + Reaction, + Emote, + Location, +]; + +const LIMIT_MAX: usize = 100; +const LIMIT_DEFAULT: usize = 10; /// # `GET /_matrix/client/r0/rooms/{roomId}/messages` /// @@ -116,209 +57,171 @@ pub(crate) async fn send_message_event_route( pub(crate) async fn get_message_events_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - + let sender = body.sender(); + let (sender_user, sender_device) = sender; let room_id = &body.room_id; let filter = 
&body.filter; - let limit = usize::try_from(body.limit).unwrap_or(10).min(100); - let from = match body.from.as_ref() { - Some(from) => PduCount::try_from_string(from)?, - None => match body.dir { - ruma::api::Direction::Forward => PduCount::min(), - ruma::api::Direction::Backward => PduCount::max(), - }, + let from_default = match body.dir { + Direction::Forward => PduCount::min(), + Direction::Backward => PduCount::max(), }; - let to = body - .to - .as_ref() - .and_then(|t| PduCount::try_from_string(t).ok()); + let from = body + .from + .as_deref() + .map(PduCount::try_from_string) + .transpose()? + .unwrap_or(from_default); + + let to = body.to.as_deref().map(PduCount::try_from_string).flat_ok(); + + let limit: usize = body + .limit + .try_into() + .unwrap_or(LIMIT_DEFAULT) + .min(LIMIT_MAX); services .rooms .lazy_loading .lazy_load_confirm_delivery(sender_user, sender_device, room_id, from); - let mut resp = get_message_events::v3::Response::new(); - let mut lazy_loaded = HashSet::new(); - let next_token; - match body.dir { - ruma::api::Direction::Forward => { - let events_after: Vec = services - .rooms - .timeline - .pdus_after(sender_user, room_id, from) - .await? 
- .ready_filter_map(|item| event_filter(item, filter)) - .filter_map(|item| visibility_filter(&services, item, sender_user)) - .ready_take_while(|(count, _)| Some(*count) != to) // Stop at `to` - .take(limit) - .collect() - .boxed() - .await; - - for (_, event) in &events_after { - /* TODO: Remove the not "element_hacks" check when these are resolved: - * https://github.com/vector-im/element-android/issues/3417 - * https://github.com/vector-im/element-web/issues/21034 - */ - if !cfg!(feature = "element_hacks") - && !services - .rooms - .lazy_loading - .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) - .await - { - lazy_loaded.insert(event.sender.clone()); - } - - if cfg!(features = "element_hacks") { - lazy_loaded.insert(event.sender.clone()); - } - } - - next_token = events_after.last().map(|(count, _)| count).copied(); - - let events_after: Vec<_> = events_after - .into_iter() - .stream() - .filter_map(|(_, pdu)| async move { - // list of safe and common non-state events to ignore - if matches!( - &pdu.kind, - RoomMessage - | Sticker | CallInvite - | CallNotify | RoomEncrypted - | Image | File | Audio - | Voice | Video | UnstablePollStart - | PollStart | KeyVerificationStart - | Reaction | Emote | Location - ) && services - .users - .user_is_ignored(&pdu.sender, sender_user) - .await - { - return None; - } - - Some(pdu.to_room_event()) - }) - .collect() - .await; - - resp.start = from.stringify(); - resp.end = next_token.map(|count| count.stringify()); - resp.chunk = events_after; - }, - ruma::api::Direction::Backward => { - services - .rooms - .timeline - .backfill_if_required(room_id, from) - .boxed() - .await?; - - let events_before: Vec = services - .rooms - .timeline - .pdus_until(sender_user, room_id, from) - .await? 
- .ready_filter_map(|item| event_filter(item, filter)) - .filter_map(|(count, pdu)| async move { - // list of safe and common non-state events to ignore - if matches!( - &pdu.kind, - RoomMessage - | Sticker | CallInvite - | CallNotify | RoomEncrypted - | Image | File | Audio - | Voice | Video | UnstablePollStart - | PollStart | KeyVerificationStart - | Reaction | Emote | Location - ) && services - .users - .user_is_ignored(&pdu.sender, sender_user) - .await - { - return None; - } - - Some((count, pdu)) - }) - .filter_map(|item| visibility_filter(&services, item, sender_user)) - .ready_take_while(|(count, _)| Some(*count) != to) // Stop at `to` - .take(limit) - .collect() - .boxed() - .await; - - for (_, event) in &events_before { - /* TODO: Remove the not "element_hacks" check when these are resolved: - * https://github.com/vector-im/element-android/issues/3417 - * https://github.com/vector-im/element-web/issues/21034 - */ - if !cfg!(feature = "element_hacks") - && !services - .rooms - .lazy_loading - .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) - .await - { - lazy_loaded.insert(event.sender.clone()); - } - - if cfg!(features = "element_hacks") { - lazy_loaded.insert(event.sender.clone()); - } - } - - next_token = events_before.last().map(|(count, _)| count).copied(); - - let events_before: Vec<_> = events_before - .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) - .collect(); - - resp.start = from.stringify(); - resp.end = next_token.map(|count| count.stringify()); - resp.chunk = events_before; - }, + if matches!(body.dir, Direction::Backward) { + services + .rooms + .timeline + .backfill_if_required(room_id, from) + .boxed() + .await + .log_err() + .ok(); } - resp.state = lazy_loaded + let it = match body.dir { + Direction::Forward => services + .rooms + .timeline + .pdus_after(sender_user, room_id, from) + .await? 
+ .boxed(), + + Direction::Backward => services + .rooms + .timeline + .pdus_until(sender_user, room_id, from) + .await? + .boxed(), + }; + + let events: Vec<_> = it + .ready_take_while(|(count, _)| Some(*count) != to) + .ready_filter_map(|item| event_filter(item, filter)) + .filter_map(|item| ignored_filter(&services, item, sender_user)) + .filter_map(|item| visibility_filter(&services, item, sender_user)) + .take(limit) + .collect() + .await; + + let lazy = events .iter() .stream() - .filter_map(|ll_user_id| async move { - services - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomMember, ll_user_id.as_str()) - .await - .map(|member_event| member_event.to_state_event()) - .ok() + .fold(LazySet::new(), |lazy, item| { + update_lazy(&services, room_id, sender, lazy, item, false) }) + .await; + + let state = lazy + .iter() + .stream() + .filter_map(|user_id| get_member_event(&services, room_id, user_id)) .collect() .await; - // remove the feature check when we are sure clients like element can handle it + let next_token = events.last().map(|(count, _)| count).copied(); + if !cfg!(feature = "element_hacks") { if let Some(next_token) = next_token { - services.rooms.lazy_loading.lazy_load_mark_sent( - sender_user, - sender_device, - room_id, - lazy_loaded, - next_token, - ); + services + .rooms + .lazy_loading + .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy, next_token); } } - Ok(resp) + let chunk = events + .into_iter() + .map(at!(1)) + .map(|pdu| pdu.to_room_event()) + .collect(); + + Ok(get_message_events::v3::Response { + start: from.stringify(), + end: next_token.as_ref().map(PduCount::stringify), + chunk, + state, + }) +} + +async fn get_member_event(services: &Services, room_id: &RoomId, user_id: &UserId) -> Option> { + services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await + .map(|member_event| member_event.to_state_event()) + .ok() +} + +pub(crate) async fn 
update_lazy( + services: &Services, room_id: &RoomId, sender: (&UserId, &DeviceId), mut lazy: LazySet, item: &PdusIterItem, + force: bool, +) -> LazySet { + let (_, event) = &item; + let (sender_user, sender_device) = sender; + + /* TODO: Remove the not "element_hacks" check when these are resolved: + * https://github.com/vector-im/element-android/issues/3417 + * https://github.com/vector-im/element-web/issues/21034 + */ + if force || cfg!(features = "element_hacks") { + lazy.insert(event.sender().into()); + return lazy; + } + + if !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, event.sender()) + .await + { + lazy.insert(event.sender().into()); + } + + lazy +} + +pub(crate) async fn ignored_filter(services: &Services, item: PdusIterItem, user_id: &UserId) -> Option { + let (_, pdu) = &item; + + if pdu.kind.to_cow_str() == "org.matrix.dummy_event" { + return None; + } + + if !IGNORED_MESSAGE_TYPES.iter().any(is_equal_to!(&pdu.kind)) { + return Some(item); + } + + if !services.users.user_is_ignored(&pdu.sender, user_id).await { + return Some(item); + } + + None } -async fn visibility_filter(services: &Services, item: PdusIterItem, user_id: &UserId) -> Option { +pub(crate) async fn visibility_filter( + services: &Services, item: PdusIterItem, user_id: &UserId, +) -> Option { let (_, pdu) = &item; services @@ -329,7 +232,7 @@ async fn visibility_filter(services: &Services, item: PdusIterItem, user_id: &Us .then_some(item) } -fn event_filter(item: PdusIterItem, filter: &RoomEventFilter) -> Option { +pub(crate) fn event_filter(item: PdusIterItem, filter: &RoomEventFilter) -> Option { let (_, pdu) = &item; pdu.matches(filter).then_some(item) } diff --git a/src/api/client/mod.rs b/src/api/client/mod.rs index 2928be87b..9ee88bec1 100644 --- a/src/api/client/mod.rs +++ b/src/api/client/mod.rs @@ -23,6 +23,7 @@ pub(super) mod relations; pub(super) mod report; pub(super) mod room; pub(super) mod search; +pub(super) mod send; 
pub(super) mod session; pub(super) mod space; pub(super) mod state; @@ -65,6 +66,7 @@ pub(super) use relations::*; pub(super) use report::*; pub(super) use room::*; pub(super) use search::*; +pub(super) use send::*; pub(super) use session::*; pub(super) use space::*; pub(super) use state::*; diff --git a/src/api/client/send.rs b/src/api/client/send.rs new file mode 100644 index 000000000..ff011efab --- /dev/null +++ b/src/api/client/send.rs @@ -0,0 +1,92 @@ +use std::collections::BTreeMap; + +use axum::extract::State; +use conduit::{err, Err}; +use ruma::{api::client::message::send_message_event, events::MessageLikeEventType}; +use serde_json::from_str; + +use crate::{service::pdu::PduBuilder, utils, Result, Ruma}; + +/// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}` +/// +/// Send a message event into the room. +/// +/// - Is a NOOP if the txn id was already used before and returns the same event +/// id again +/// - The only requirement for the content is that it has to be valid json +/// - Tries to send the event into the room, auth rules will determine if it is +/// allowed +pub(crate) async fn send_message_event_route( + State(services): State, body: Ruma, +) -> Result { + let sender_user = body.sender_user(); + let sender_device = body.sender_device.as_deref(); + let appservice_info = body.appservice_info.as_ref(); + + // Forbid m.room.encrypted if encryption is disabled + if MessageLikeEventType::RoomEncrypted == body.event_type && !services.globals.allow_encryption() { + return Err!(Request(Forbidden("Encryption has been disabled"))); + } + + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; + + if body.event_type == MessageLikeEventType::CallInvite + && services.rooms.directory.is_public_room(&body.room_id).await + { + return Err!(Request(Forbidden("Room call invites are not allowed in public rooms"))); + } + + // Check if this is a new transaction id + if let Ok(response) = services + .transaction_ids + 
.existing_txnid(sender_user, sender_device, &body.txn_id) + .await + { + // The client might have sent a txnid of the /sendToDevice endpoint + // This txnid has no response associated with it + if response.is_empty() { + return Err!(Request(InvalidParam( + "Tried to use txn id already used for an incompatible endpoint." + ))); + } + + return Ok(send_message_event::v3::Response { + event_id: utils::string_from_bytes(&response) + .map(TryInto::try_into) + .map_err(|e| err!(Database("Invalid event_id in txnid data: {e:?}")))??, + }); + } + + let mut unsigned = BTreeMap::new(); + unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); + + let content = + from_str(body.body.body.json().get()).map_err(|e| err!(Request(BadJson("Invalid JSON body: {e}"))))?; + + let event_id = services + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: body.event_type.clone().into(), + content, + unsigned: Some(unsigned), + timestamp: appservice_info.and(body.timestamp), + ..Default::default() + }, + sender_user, + &body.room_id, + &state_lock, + ) + .await?; + + services + .transaction_ids + .add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes()); + + drop(state_lock); + + Ok(send_message_event::v3::Response { + event_id: event_id.into(), + }) +} diff --git a/src/api/router/args.rs b/src/api/router/args.rs index cefacac1c..38236db34 100644 --- a/src/api/router/args.rs +++ b/src/api/router/args.rs @@ -3,7 +3,9 @@ use std::{mem, ops::Deref}; use axum::{async_trait, body::Body, extract::FromRequest}; use bytes::{BufMut, BytesMut}; use conduit::{debug, err, trace, utils::string::EMPTY, Error, Result}; -use ruma::{api::IncomingRequest, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, ServerName, UserId}; +use ruma::{ + api::IncomingRequest, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName, OwnedUserId, ServerName, UserId, +}; use service::Services; use super::{auth, auth::Auth, request, 
request::Request}; @@ -40,10 +42,28 @@ where T: IncomingRequest + Send + Sync + 'static, { #[inline] - pub(crate) fn sender_user(&self) -> &UserId { self.sender_user.as_deref().expect("user is authenticated") } + pub(crate) fn sender(&self) -> (&UserId, &DeviceId) { (self.sender_user(), self.sender_device()) } #[inline] - pub(crate) fn origin(&self) -> &ServerName { self.origin.as_deref().expect("server is authenticated") } + pub(crate) fn sender_user(&self) -> &UserId { + self.sender_user + .as_deref() + .expect("user must be authenticated for this handler") + } + + #[inline] + pub(crate) fn sender_device(&self) -> &DeviceId { + self.sender_device + .as_deref() + .expect("user must be authenticated and device identified") + } + + #[inline] + pub(crate) fn origin(&self) -> &ServerName { + self.origin + .as_deref() + .expect("server must be authenticated for this handler") + } } #[async_trait] From 52e356d7805fd25c4e0b21757076f04d271d4241 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 28 Oct 2024 06:49:25 +0000 Subject: [PATCH 140/245] generate ActualDest https string on the fly Signed-off-by: Jason Volk --- src/service/resolver/actual.rs | 29 ++++++++++++++--------------- src/service/resolver/fed.rs | 19 ++++++++----------- src/service/sending/send.rs | 4 ++-- 3 files changed, 24 insertions(+), 28 deletions(-) diff --git a/src/service/resolver/actual.rs b/src/service/resolver/actual.rs index ea4b1100f..660498f7d 100644 --- a/src/service/resolver/actual.rs +++ b/src/service/resolver/actual.rs @@ -18,34 +18,33 @@ use crate::resolver::{ pub(crate) struct ActualDest { pub(crate) dest: FedDest, pub(crate) host: String, - pub(crate) string: String, pub(crate) cached: bool, } +impl ActualDest { + #[inline] + pub(crate) fn string(&self) -> String { self.dest.https_string() } +} + impl super::Service { #[tracing::instrument(skip_all, name = "resolve")] pub(crate) async fn get_actual_dest(&self, server_name: &ServerName) -> Result { - let cached; - let cached_result = 
self.get_cached_destination(server_name); + let (result, cached) = if let Some(result) = self.get_cached_destination(server_name) { + (result, true) + } else { + self.validate_dest(server_name)?; + (self.resolve_actual_dest(server_name, true).await?, false) + }; let CachedDest { dest, host, .. - } = if let Some(result) = cached_result { - cached = true; - result - } else { - cached = false; - self.validate_dest(server_name)?; - self.resolve_actual_dest(server_name, true).await? - }; + } = result; - let string = dest.clone().into_https_string(); Ok(ActualDest { dest, host, - string, cached, }) } @@ -89,7 +88,7 @@ impl super::Service { debug!("Actual destination: {actual_dest:?} hostname: {host:?}"); Ok(CachedDest { dest: actual_dest, - host: host.into_uri_string(), + host: host.uri_string(), expire: CachedDest::default_expire(), }) } @@ -109,7 +108,7 @@ impl super::Service { async fn actual_dest_3(&self, host: &mut String, cache: bool, delegated: String) -> Result { debug!("3: A .well-known file is available"); - *host = add_port_to_hostname(&delegated).into_uri_string(); + *host = add_port_to_hostname(&delegated).uri_string(); match get_ip_with_port(&delegated) { Some(host_and_port) => Self::actual_dest_3_1(host_and_port), None => { diff --git a/src/service/resolver/fed.rs b/src/service/resolver/fed.rs index 10cbbbdd0..79f71f13a 100644 --- a/src/service/resolver/fed.rs +++ b/src/service/resolver/fed.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, fmt, net::{IpAddr, SocketAddr}, }; @@ -29,24 +30,25 @@ pub(crate) fn add_port_to_hostname(dest_str: &str) -> FedDest { } impl FedDest { - pub(crate) fn into_https_string(self) -> String { + pub(crate) fn https_string(&self) -> String { match self { Self::Literal(addr) => format!("https://{addr}"), Self::Named(host, port) => format!("https://{host}{port}"), } } - pub(crate) fn into_uri_string(self) -> String { + pub(crate) fn uri_string(&self) -> String { match self { Self::Literal(addr) => addr.to_string(), Self::Named(host, 
port) => format!("{host}{port}"), } } - pub(crate) fn hostname(&self) -> String { + #[inline] + pub(crate) fn hostname(&self) -> Cow<'_, str> { match &self { - Self::Literal(addr) => addr.ip().to_string(), - Self::Named(host, _) => host.clone(), + Self::Literal(addr) => addr.ip().to_string().into(), + Self::Named(host, _) => host.into(), } } @@ -61,10 +63,5 @@ impl FedDest { } impl fmt::Display for FedDest { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Named(host, port) => write!(f, "{host}{port}"), - Self::Literal(addr) => write!(f, "{addr}"), - } - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(self.uri_string().as_str()) } } diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 62da59ef2..2fbb39190 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -71,7 +71,7 @@ impl super::Service { trace!("Preparing request"); let mut http_request = req - .try_into_http_request::>(&actual.string, SATIR, &VERSIONS) + .try_into_http_request::>(actual.string().as_str(), SATIR, &VERSIONS) .map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))?; self.sign_request::(dest, &mut http_request); @@ -107,7 +107,7 @@ where request_url = ?url, response_url = ?response.url(), "Received response from {}", - actual.string, + actual.string(), ); let mut http_response_builder = http::Response::builder() From d92f2c121f95f8d8beadc3a727b8a02376c46d3c Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 27 Oct 2024 12:19:45 -0400 Subject: [PATCH 141/245] document nginx needing request_uri Signed-off-by: strawberry --- docs/deploying/generic.md | 12 ++++++++---- src/api/router/auth.rs | 4 ++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/docs/deploying/generic.md b/docs/deploying/generic.md index 1e44ab541..31dc18456 100644 --- a/docs/deploying/generic.md +++ b/docs/deploying/generic.md @@ -119,12 +119,16 @@ is the recommended reverse proxy for new 
users and is very trivial to use (handles TLS, reverse proxy headers, etc transparently with proper defaults). Lighttpd is not supported as it seems to mess with the `X-Matrix` Authorization -header, making federation non-functional. If using Apache, you need to use -`nocanon` in your `ProxyPass` directive to prevent this (note that Apache -isn't very good as a general reverse proxy). +header, making federation non-functional. If a workaround is found, feel free to share to get it added to the documentation here. + +If using Apache, you need to use `nocanon` in your `ProxyPass` directive to prevent this (note that Apache isn't very good as a general reverse proxy and we discourage the usage of it if you can). + +If using Nginx, you need to give conduwuit the request URI using `$request_uri`, or like so: +- `proxy_pass http://127.0.0.1:6167$request_uri;` +- `proxy_pass http://127.0.0.1:6167;` Nginx users may need to set `proxy_buffering off;` if there are issues with -uploading media like images. +uploading media like images. This is due to Nginx storing the entire POST content in-memory (`/tmp`) and running out of memory if on low memory hardware. You will need to reverse proxy everything under following routes: - `/_matrix/` - core Matrix C-S and S-S APIs diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 6b1bb1a9f..31e71f2ff 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -247,8 +247,8 @@ async fn auth_server(services: &Services, request: &mut Request, body: Option<&C debug_error!("Failed to verify federation request from {origin}: {e}"); if request.parts.uri.to_string().contains('@') { warn!( - "Request uri contained '@' character. Make sure your reverse proxy gives Conduit the raw uri (apache: \ - use nocanon)" + "Request uri contained '@' character. 
Make sure your reverse proxy gives conduwuit the raw uri \ + (apache: use nocanon)" ); } From 065396f8f502e1b206c37b0d7dea92f79bfd8634 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 27 Oct 2024 12:37:44 -0400 Subject: [PATCH 142/245] better document allow_inbound_profile_lookup_federation_requests Signed-off-by: strawberry --- src/api/server/publicrooms.rs | 3 ++- src/api/server/query.rs | 6 +++++- src/core/config/mod.rs | 21 ++++++++++++++++----- src/service/globals/mod.rs | 4 ---- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/api/server/publicrooms.rs b/src/api/server/publicrooms.rs index af8a58464..f6c418592 100644 --- a/src/api/server/publicrooms.rs +++ b/src/api/server/publicrooms.rs @@ -20,7 +20,8 @@ pub(crate) async fn get_public_rooms_filtered_route( ) -> Result { if !services .globals - .allow_public_room_directory_over_federation() + .config + .allow_public_room_directory_over_federation { return Err(Error::BadRequest(ErrorKind::forbidden(), "Room directory is not public")); } diff --git a/src/api/server/query.rs b/src/api/server/query.rs index 348b8c6e9..bf515b3c7 100644 --- a/src/api/server/query.rs +++ b/src/api/server/query.rs @@ -63,7 +63,11 @@ pub(crate) async fn get_room_information_route( pub(crate) async fn get_profile_information_route( State(services): State, body: Ruma, ) -> Result { - if !services.globals.allow_profile_lookup_federation_requests() { + if !services + .globals + .config + .allow_inbound_profile_lookup_federation_requests + { return Err(Error::BadRequest( ErrorKind::forbidden(), "Profile lookup over federation is not allowed on this homeserver.", diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 04e44fd76..7a5c6d08d 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -458,11 +458,16 @@ pub struct Config { /// obtain the profiles of our local users from /// `/_matrix/federation/v1/query/profile` /// - /// This is inherently false if `allow_federation` is disabled + /// 
Increases privacy of your local user's such as display names, but some + /// remote users may get a false "this user does not exist" error when they + /// try to invite you to a DM or room. Also can protect against profile + /// spiders. /// - /// Defaults to true - #[serde(default = "true_fn")] - pub allow_profile_lookup_federation_requests: bool, + /// Defaults to true. + /// + /// This is inherently false if `allow_federation` is disabled + #[serde(default = "true_fn", alias = "allow_profile_lookup_federation_requests")] + pub allow_inbound_profile_lookup_federation_requests: bool, /// controls whether users are allowed to create rooms. /// appservices and admins are always allowed to create rooms @@ -1530,6 +1535,10 @@ impl fmt::Display for Config { line("Allow encryption", &self.allow_encryption.to_string()); line("Allow federation", &self.allow_federation.to_string()); line("Federation loopback", &self.federation_loopback.to_string()); + line( + "Require authentication for profile requests", + &self.require_auth_for_profile_requests.to_string(), + ); line( "Allow incoming federated presence requests (updates)", &self.allow_incoming_presence.to_string(), @@ -1577,7 +1586,9 @@ impl fmt::Display for Config { line("Allow device name federation", &self.allow_device_name_federation.to_string()); line( "Allow incoming profile lookup federation requests", - &self.allow_profile_lookup_federation_requests.to_string(), + &self + .allow_inbound_profile_lookup_federation_requests + .to_string(), ); line( "Auto deactivate banned room join attempts", diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 157c39440..0a7dda9f2 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -212,10 +212,6 @@ impl Service { pub fn turn_username(&self) -> &String { &self.config.turn_username } - pub fn allow_profile_lookup_federation_requests(&self) -> bool { - self.config.allow_profile_lookup_federation_requests - } - pub fn 
notification_push_path(&self) -> &String { &self.config.notification_push_path } pub fn emergency_password(&self) -> &Option { &self.config.emergency_password } From 85890ed42502a4672d21218b07fc7366f7027ef3 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 27 Oct 2024 13:21:16 -0400 Subject: [PATCH 143/245] remove some unnecessary HTML from admin commands Signed-off-by: strawberry --- src/admin/debug/commands.rs | 2 +- src/admin/federation/commands.rs | 28 +++++----------------------- src/admin/room/directory.rs | 29 +++++------------------------ src/admin/server/commands.rs | 5 ++++- src/core/config/mod.rs | 2 +- 5 files changed, 16 insertions(+), 50 deletions(-) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 0fd3c91bf..2aa6078fc 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -106,7 +106,7 @@ pub(super) async fn get_pdu(&self, event_id: Box) -> Result) -> Result { @@ -108,33 +108,15 @@ pub(super) async fn remote_user_in_rooms(&self, user_id: Box) -> Result< rooms.sort_by_key(|r| r.1); rooms.reverse(); - let output_plain = format!( - "Rooms {user_id} shares with us ({}):\n{}", + let output = format!( + "Rooms {user_id} shares with us ({}):\n```\n{}\n```", rooms.len(), rooms .iter() - .map(|(id, members, name)| format!("{id}\tMembers: {members}\tName: {name}")) + .map(|(id, members, name)| format!("{id} | Members: {members} | Name: {name}")) .collect::>() .join("\n") ); - let output_html = format!( - "\n\t\t\n{}
    Rooms {user_id} shares with us \ - ({})
    idmembersname
    ", - rooms.len(), - rooms - .iter() - .fold(String::new(), |mut output, (id, members, name)| { - writeln!( - output, - "{}\t{}\t{}", - id, - members, - escape_html(name) - ) - .expect("should be able to write to string buffer"); - output - }) - ); - Ok(RoomMessageEventContent::text_html(output_plain, output_html)) + Ok(RoomMessageEventContent::text_markdown(output)) } diff --git a/src/admin/room/directory.rs b/src/admin/room/directory.rs index 1080356a8..0bdaf56d7 100644 --- a/src/admin/room/directory.rs +++ b/src/admin/room/directory.rs @@ -1,11 +1,9 @@ -use std::fmt::Write; - use clap::Subcommand; use conduit::Result; use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, RoomId}; -use crate::{escape_html, get_room_info, Command, PAGE_SIZE}; +use crate::{get_room_info, Command, PAGE_SIZE}; #[derive(Debug, Subcommand)] pub(crate) enum RoomDirectoryCommand { @@ -68,32 +66,15 @@ pub(super) async fn process(command: RoomDirectoryCommand, context: &Command<'_> return Ok(RoomMessageEventContent::text_plain("No more rooms.")); }; - let output_plain = format!( - "Rooms:\n{}", + let output = format!( + "Rooms (page {page}):\n```\n{}\n```", rooms .iter() - .map(|(id, members, name)| format!("{id}\tMembers: {members}\tName: {name}")) + .map(|(id, members, name)| format!("{id} | Members: {members} | Name: {name}")) .collect::>() .join("\n") ); - let output_html = format!( - "\n\t\t\n{}
    Room directory - page \ - {page}
    idmembersname
    ", - rooms - .iter() - .fold(String::new(), |mut output, (id, members, name)| { - writeln!( - output, - "{}\t{}\t{}", - escape_html(id.as_ref()), - members, - escape_html(name.as_ref()) - ) - .expect("should be able to write to string buffer"); - output - }) - ); - Ok(RoomMessageEventContent::text_html(output_plain, output_html)) + Ok(RoomMessageEventContent::text_markdown(output)) }, } } diff --git a/src/admin/server/commands.rs b/src/admin/server/commands.rs index de6ad98ad..f5879b037 100644 --- a/src/admin/server/commands.rs +++ b/src/admin/server/commands.rs @@ -21,7 +21,10 @@ pub(super) async fn uptime(&self) -> Result { #[admin_command] pub(super) async fn show_config(&self) -> Result { // Construct and send the response - Ok(RoomMessageEventContent::text_plain(format!("{}", self.services.globals.config))) + Ok(RoomMessageEventContent::text_markdown(format!( + "```\n{}\n```", + self.services.globals.config + ))) } #[admin_command] diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 7a5c6d08d..512cb48b4 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -1441,7 +1441,7 @@ impl Config { impl fmt::Display for Config { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "Active config values:\n\n").expect("wrote line to formatter stream"); + writeln!(f, "Active config values:\n").expect("wrote line to formatter stream"); let mut line = |key: &str, val: &str| { writeln!(f, "{key}: {val}").expect("wrote line to formatter stream"); }; From 0a281241efdc536d950f24edd4805364a8bbfd97 Mon Sep 17 00:00:00 2001 From: strawberry Date: Mon, 28 Oct 2024 16:53:53 -0400 Subject: [PATCH 144/245] bump few dependencies, bump ruwuma Signed-off-by: strawberry --- Cargo.lock | 68 +++++++++++++++++++++++++++--------------------------- Cargo.toml | 10 ++++---- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a8acce7d3..44856753f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -290,7 +290,7 
@@ dependencies = [ "hyper", "hyper-util", "pin-project-lite", - "rustls 0.23.15", + "rustls 0.23.16", "rustls-pemfile", "rustls-pki-types", "tokio", @@ -310,7 +310,7 @@ dependencies = [ "http", "http-body-util", "pin-project", - "rustls 0.23.15", + "rustls 0.23.16", "tokio", "tokio-rustls", "tokio-util", @@ -770,7 +770,7 @@ dependencies = [ "hyper-util", "log", "ruma", - "rustls 0.23.15", + "rustls 0.23.16", "sd-notify", "sentry", "sentry-tower", @@ -1202,9 +1202,9 @@ dependencies = [ [[package]] name = "fdeflate" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8090f921a24b04994d9929e204f50b498a33ea6ba559ffaa05e04f7ee7fb5ab" +checksum = "07c6f4c64c1d33a3111c4466f7365ebdcc37c5bd1ea0d62aae2e3d722aacbedb" dependencies = [ "simd-adler32", ] @@ -1708,7 +1708,7 @@ dependencies = [ "http", "hyper", "hyper-util", - "rustls 0.23.15", + "rustls 0.23.16", "rustls-native-certs", "rustls-pki-types", "tokio", @@ -2771,7 +2771,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash 2.0.0", - "rustls 0.23.15", + "rustls 0.23.16", "socket2", "thiserror", "tokio", @@ -2788,7 +2788,7 @@ dependencies = [ "rand", "ring", "rustc-hash 2.0.0", - "rustls 0.23.15", + "rustls 0.23.16", "slab", "thiserror", "tinyvec", @@ -2902,9 +2902,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" -version = "0.12.8" +version = "0.12.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b" +checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" dependencies = [ "async-compression", "base64 0.22.1", @@ -2928,7 +2928,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.15", + "rustls 0.23.16", "rustls-native-certs", "rustls-pemfile", "rustls-pki-types", @@ -2977,7 +2977,7 @@ dependencies = [ [[package]] name = "ruma" version = 
"0.10.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" dependencies = [ "assign", "js_int", @@ -2999,7 +2999,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" dependencies = [ "js_int", "ruma-common", @@ -3011,7 +3011,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" dependencies = [ "as_variant", "assign", @@ -3034,7 +3034,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" dependencies = [ "as_variant", "base64 0.22.1", @@ -3064,7 +3064,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" dependencies = [ "as_variant", "indexmap 
2.6.0", @@ -3088,7 +3088,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" dependencies = [ "bytes", "http", @@ -3106,7 +3106,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" dependencies = [ "js_int", "thiserror", @@ -3115,7 +3115,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" dependencies = [ "js_int", "ruma-common", @@ -3125,7 +3125,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" dependencies = [ "cfg-if", "once_cell", @@ -3141,7 +3141,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" +source = 
"git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" dependencies = [ "js_int", "ruma-common", @@ -3153,7 +3153,7 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.3.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" dependencies = [ "headers", "http", @@ -3166,7 +3166,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -3182,7 +3182,7 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.11.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=39c1addd37a4eed612ac1135edc2cccd9d331d5e#39c1addd37a4eed612ac1135edc2cccd9d331d5e" +source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" dependencies = [ "futures-util", "itertools 0.13.0", @@ -3258,9 +3258,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "aa260229e6538e52293eeb577aabd09945a09d6d9cc0fc550ed7529056c2e32a" dependencies = [ "bitflags 2.6.0", "errno", @@ -3285,9 +3285,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.15" +version = "0.23.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"5fbb44d7acc4e873d613422379f69f237a1b141928c02f6bc6ccfddddc2d7993" +checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" dependencies = [ "aws-lc-rs", "log", @@ -3563,18 +3563,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.213" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ea7893ff5e2466df8d720bb615088341b295f849602c6956047f8f80f0e9bc1" +checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.213" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e85ad2009c50b58e87caa8cd6dac16bdf511bbfb7af6c33df902396aa480fa5" +checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ "proc-macro2", "quote", @@ -4106,7 +4106,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.15", + "rustls 0.23.16", "rustls-pki-types", "tokio", ] @@ -4472,7 +4472,7 @@ dependencies = [ "base64 0.22.1", "log", "once_cell", - "rustls 0.23.15", + "rustls 0.23.16", "rustls-pki-types", "url", "webpki-roots", diff --git a/Cargo.toml b/Cargo.toml index 2f9f196b6..e406c9e15 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -128,10 +128,10 @@ features = [ ] [workspace.dependencies.rustls] -version = "0.23.13" +version = "0.23.16" [workspace.dependencies.reqwest] -version = "0.12.8" +version = "0.12.9" default-features = false features = [ "rustls-tls-native-roots", @@ -141,7 +141,7 @@ features = [ ] [workspace.dependencies.serde] -version = "1.0.209" +version = "1.0.214" default-features = false features = ["rc"] @@ -257,7 +257,7 @@ features = [ ] [workspace.dependencies.hyper-util] -# 0.1.9 causes DNS issues +# 0.1.9 and above causes DNS issues version = "=0.1.8" default-features = 
false features = [ @@ -315,7 +315,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = "https://github.com/girlbossceo/ruwuma" #branch = "conduwuit-changes" -rev = "39c1addd37a4eed612ac1135edc2cccd9d331d5e" +rev = "dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" features = [ "compat", "rand", From c71db93e225b44f15c652f4fbe0befaad508e48e Mon Sep 17 00:00:00 2001 From: strawberry Date: Mon, 28 Oct 2024 18:28:56 -0400 Subject: [PATCH 145/245] implement admin command to force join list of local users Signed-off-by: strawberry --- src/admin/user/commands.rs | 139 ++++++++++++++++++++++++++++++++++++- src/admin/user/mod.rs | 15 ++++ 2 files changed, 153 insertions(+), 1 deletion(-) diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index fb6d2bf1b..d6946b4e0 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -1,7 +1,11 @@ use std::{collections::BTreeMap, fmt::Write as _}; use api::client::{full_user_deactivate, join_room_by_id_helper, leave_room}; -use conduit::{error, info, is_equal_to, utils, warn, PduBuilder, Result}; +use conduit::{ + debug_warn, error, info, is_equal_to, + utils::{self, ReadyExt}, + warn, PduBuilder, Result, +}; use conduit_api::client::{leave_all_rooms, update_avatar_url, update_displayname}; use futures::StreamExt; use ruma::{ @@ -376,6 +380,139 @@ pub(super) async fn list_joined_rooms(&self, user_id: String) -> Result Result { + const REASON: &str = "Bulk force joining this room as initiated by the server admin."; + + if self.body.len() < 2 || !self.body[0].trim().starts_with("```") || self.body.last().unwrap_or(&"").trim() != "```" + { + return Ok(RoomMessageEventContent::text_plain( + "Expected code block in command body. 
Add --help for details.", + )); + } + + if !yes_i_want_to_do_this { + return Ok(RoomMessageEventContent::notice_markdown( + "You must pass the --yes-i-want-to-do-this-flag to ensure you really want to force bulk join all \ + specified local users.", + )); + } + + let Ok(admin_room) = self.services.admin.get_admin_room().await else { + return Ok(RoomMessageEventContent::notice_markdown( + "There is not an admin room to check for server admins.", + )); + }; + + let (room_id, servers) = self + .services + .rooms + .alias + .resolve_with_servers(&room_id, None) + .await?; + + if !self + .services + .rooms + .state_cache + .server_in_room(self.services.globals.server_name(), &room_id) + .await + { + return Ok(RoomMessageEventContent::notice_markdown("We are not joined in this room.")); + } + + let server_admins: Vec<_> = self + .services + .rooms + .state_cache + .active_local_users_in_room(&admin_room) + .map(ToOwned::to_owned) + .collect() + .await; + + if !self + .services + .rooms + .state_cache + .room_members(&room_id) + .ready_any(|user_id| server_admins.contains(&user_id.to_owned())) + .await + { + return Ok(RoomMessageEventContent::notice_markdown( + "There is not a single server admin in the room.", + )); + } + + let usernames = self + .body + .to_vec() + .drain(1..self.body.len().saturating_sub(1)) + .collect::>(); + + let mut user_ids: Vec = Vec::with_capacity(usernames.len()); + + for username in usernames { + match parse_active_local_user_id(self.services, username).await { + Ok(user_id) => { + // don't make the server service account join + if user_id == self.services.globals.server_user { + self.services + .admin + .send_message(RoomMessageEventContent::text_plain(format!( + "{username} is the server service account, skipping over" + ))) + .await + .ok(); + continue; + } + + user_ids.push(user_id); + }, + Err(e) => { + self.services + .admin + .send_message(RoomMessageEventContent::text_plain(format!( + "{username} is not a valid username, skipping over: 
{e}" + ))) + .await + .ok(); + continue; + }, + } + } + + let mut failed_joins: usize = 0; + let mut successful_joins: usize = 0; + + for user_id in user_ids { + match join_room_by_id_helper( + self.services, + &user_id, + &room_id, + Some(String::from(REASON)), + &servers, + None, + &None, + ) + .await + { + Ok(_res) => { + successful_joins = successful_joins.saturating_add(1); + }, + Err(e) => { + debug_warn!("Failed force joining {user_id} to {room_id} during bulk join: {e}"); + failed_joins = failed_joins.saturating_add(1); + }, + }; + } + + Ok(RoomMessageEventContent::notice_markdown(format!( + "{successful_joins} local users have been joined to {room_id}. {failed_joins} joins failed.", + ))) +} + #[admin_command] pub(super) async fn force_join_room( &self, user_id: String, room_id: OwnedRoomOrAliasId, diff --git a/src/admin/user/mod.rs b/src/admin/user/mod.rs index e7bb5c732..e15682692 100644 --- a/src/admin/user/mod.rs +++ b/src/admin/user/mod.rs @@ -124,4 +124,19 @@ pub(super) enum UserCommand { RedactEvent { event_id: Box, }, + + /// - Force joins a specified list of local users to join the specified + /// room. + /// + /// Specify a codeblock of usernames. + /// + /// At least 1 server admin must be in the room to prevent abuse. + /// + /// Requires the `--yes-i-want-to-do-this` flag. 
+ ForceJoinListOfLocalUsers { + room_id: OwnedRoomOrAliasId, + + #[arg(long)] + yes_i_want_to_do_this: bool, + }, } From 567a4cb4417726d4400f81e1bedea12e46fac439 Mon Sep 17 00:00:00 2001 From: strawberry Date: Mon, 28 Oct 2024 19:06:53 -0400 Subject: [PATCH 146/245] implement admin command to force join all local users to room Signed-off-by: strawberry --- src/admin/user/commands.rs | 100 +++++++++++++++++++++++++++++++++++-- src/admin/user/mod.rs | 14 +++++- 2 files changed, 109 insertions(+), 5 deletions(-) diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index d6946b4e0..531ce490d 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -18,7 +18,7 @@ use ruma::{ tag::{TagEvent, TagEventContent, TagInfo}, RoomAccountDataEventType, StateEventType, }, - EventId, OwnedRoomId, OwnedRoomOrAliasId, OwnedUserId, RoomId, + EventId, OwnedRoomId, OwnedRoomOrAliasId, OwnedUserId, RoomId, UserId, }; use crate::{ @@ -27,6 +27,7 @@ use crate::{ }; const AUTO_GEN_PASSWORD_LENGTH: usize = 25; +const BULK_JOIN_REASON: &str = "Bulk force joining this room as initiated by the server admin."; #[admin_command] pub(super) async fn list_users(&self) -> Result { @@ -384,8 +385,6 @@ pub(super) async fn list_joined_rooms(&self, user_id: String) -> Result Result { - const REASON: &str = "Bulk force joining this room as initiated by the server admin."; - if self.body.len() < 2 || !self.body[0].trim().starts_with("```") || self.body.last().unwrap_or(&"").trim() != "```" { return Ok(RoomMessageEventContent::text_plain( @@ -491,7 +490,100 @@ pub(super) async fn force_join_list_of_local_users( self.services, &user_id, &room_id, - Some(String::from(REASON)), + Some(String::from(BULK_JOIN_REASON)), + &servers, + None, + &None, + ) + .await + { + Ok(_res) => { + successful_joins = successful_joins.saturating_add(1); + }, + Err(e) => { + debug_warn!("Failed force joining {user_id} to {room_id} during bulk join: {e}"); + failed_joins = 
failed_joins.saturating_add(1); + }, + }; + } + + Ok(RoomMessageEventContent::notice_markdown(format!( + "{successful_joins} local users have been joined to {room_id}. {failed_joins} joins failed.", + ))) +} + +#[admin_command] +pub(super) async fn force_join_all_local_users( + &self, room_id: OwnedRoomOrAliasId, yes_i_want_to_do_this: bool, +) -> Result { + if !yes_i_want_to_do_this { + return Ok(RoomMessageEventContent::notice_markdown( + "You must pass the --yes-i-want-to-do-this-flag to ensure you really want to force bulk join all local \ + users.", + )); + } + + let Ok(admin_room) = self.services.admin.get_admin_room().await else { + return Ok(RoomMessageEventContent::notice_markdown( + "There is not an admin room to check for server admins.", + )); + }; + + let (room_id, servers) = self + .services + .rooms + .alias + .resolve_with_servers(&room_id, None) + .await?; + + if !self + .services + .rooms + .state_cache + .server_in_room(self.services.globals.server_name(), &room_id) + .await + { + return Ok(RoomMessageEventContent::notice_markdown("We are not joined in this room.")); + } + + let server_admins: Vec<_> = self + .services + .rooms + .state_cache + .active_local_users_in_room(&admin_room) + .map(ToOwned::to_owned) + .collect() + .await; + + if !self + .services + .rooms + .state_cache + .room_members(&room_id) + .ready_any(|user_id| server_admins.contains(&user_id.to_owned())) + .await + { + return Ok(RoomMessageEventContent::notice_markdown( + "There is not a single server admin in the room.", + )); + } + + let mut failed_joins: usize = 0; + let mut successful_joins: usize = 0; + + for user_id in &self + .services + .users + .list_local_users() + .map(UserId::to_owned) + .collect::>() + .await + { + match join_room_by_id_helper( + self.services, + user_id, + &room_id, + Some(String::from(BULK_JOIN_REASON)), &servers, None, &None, diff --git a/src/admin/user/mod.rs b/src/admin/user/mod.rs index e15682692..649cdfb87 100644 --- a/src/admin/user/mod.rs 
+++ b/src/admin/user/mod.rs @@ -130,7 +130,7 @@ pub(super) enum UserCommand { /// /// Specify a codeblock of usernames. /// - /// At least 1 server admin must be in the room to prevent abuse. + /// At least 1 server admin must be in the room to reduce abuse. /// /// Requires the `--yes-i-want-to-do-this` flag. ForceJoinListOfLocalUsers { @@ -139,4 +139,16 @@ pub(super) enum UserCommand { #[arg(long)] yes_i_want_to_do_this: bool, }, + + /// - Force joins all local users to the specified room. + /// + /// At least 1 server admin must be in the room to reduce abuse. + /// + /// Requires the `--yes-i-want-to-do-this` flag. + ForceJoinAllLocalUsers { + room_id: OwnedRoomOrAliasId, + + #[arg(long)] + yes_i_want_to_do_this: bool, + }, } From 354dc9e703a16dfca96bfe59cba0ba83e65725cf Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 29 Oct 2024 00:08:41 +0000 Subject: [PATCH 147/245] add map accessor to Database; move cork interface Signed-off-by: Jason Volk --- src/admin/debug/commands.rs | 2 +- src/database/cork.rs | 16 +++++++++++++++- src/database/database.rs | 26 ++++++++++---------------- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 2aa6078fc..db1028588 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -838,7 +838,7 @@ pub(super) async fn database_stats( let map_name = map.as_ref().map_or(EMPTY, String::as_str); let mut out = String::new(); - for (name, map) in self.services.db.iter_maps() { + for (name, map) in self.services.db.iter() { if !map_name.is_empty() && *map_name != *name { continue; } diff --git a/src/database/cork.rs b/src/database/cork.rs index 26c520a28..5fe5fd7ab 100644 --- a/src/database/cork.rs +++ b/src/database/cork.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::Engine; +use crate::{Database, Engine}; pub struct Cork { db: Arc, @@ -8,6 +8,20 @@ pub struct Cork { sync: bool, } +impl Database { + #[inline] + #[must_use] + pub fn 
cork(&self) -> Cork { Cork::new(&self.db, false, false) } + + #[inline] + #[must_use] + pub fn cork_and_flush(&self) -> Cork { Cork::new(&self.db, true, false) } + + #[inline] + #[must_use] + pub fn cork_and_sync(&self) -> Cork { Cork::new(&self.db, true, true) } +} + impl Cork { #[inline] pub(super) fn new(db: &Arc, flush: bool, sync: bool) -> Self { diff --git a/src/database/database.rs b/src/database/database.rs index 4c29c840c..099df87d4 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -1,9 +1,8 @@ use std::{ops::Index, sync::Arc}; -use conduit::{Result, Server}; +use conduit::{err, Result, Server}; use crate::{ - cork::Cork, maps, maps::{Maps, MapsKey, MapsVal}, Engine, Map, @@ -11,7 +10,7 @@ use crate::{ pub struct Database { pub db: Arc, - map: Maps, + maps: Maps, } impl Database { @@ -20,24 +19,19 @@ impl Database { let db = Engine::open(server)?; Ok(Arc::new(Self { db: db.clone(), - map: maps::open(&db)?, + maps: maps::open(&db)?, })) } #[inline] - #[must_use] - pub fn cork(&self) -> Cork { Cork::new(&self.db, false, false) } - - #[inline] - #[must_use] - pub fn cork_and_flush(&self) -> Cork { Cork::new(&self.db, true, false) } - - #[inline] - #[must_use] - pub fn cork_and_sync(&self) -> Cork { Cork::new(&self.db, true, true) } + pub fn get(&self, name: &str) -> Result<&Arc> { + self.maps + .get(name) + .ok_or_else(|| err!(Request(NotFound("column not found")))) + } #[inline] - pub fn iter_maps(&self) -> impl Iterator + Send + '_ { self.map.iter() } + pub fn iter(&self) -> impl Iterator + Send + '_ { self.maps.iter() } #[inline] #[must_use] @@ -52,7 +46,7 @@ impl Index<&str> for Database { type Output = Arc; fn index(&self, name: &str) -> &Self::Output { - self.map + self.maps .get(name) .expect("column in database does not exist") } From 8ed9d49b73923c39524d442ed9c7878d99ff2189 Mon Sep 17 00:00:00 2001 From: strawberry Date: Thu, 31 Oct 2024 14:41:35 -0400 Subject: [PATCH 148/245] skip new flakey complement test Signed-off-by: 
strawberry --- bin/complement | 2 +- tests/test_results/complement/test_results.jsonl | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/bin/complement b/bin/complement index 601edb5a7..a1db4b325 100755 --- a/bin/complement +++ b/bin/complement @@ -18,7 +18,7 @@ RESULTS_FILE="$3" OCI_IMAGE="complement-conduwuit:main" # Complement tests that are skipped due to flakiness/reliability issues -SKIPPED_COMPLEMENT_TESTS='-skip=TestClientSpacesSummary.*|TestJoinFederatedRoomFromApplicationServiceBridgeUser.*|TestJumpToDateEndpoint.*' +SKIPPED_COMPLEMENT_TESTS='-skip=TestClientSpacesSummary.*|TestJoinFederatedRoomFromApplicationServiceBridgeUser.*|TestJumpToDateEndpoint.*|TestUnbanViaInvite.*' # $COMPLEMENT_SRC needs to be a directory to Complement source code if [ -f "$COMPLEMENT_SRC" ]; then diff --git a/tests/test_results/complement/test_results.jsonl b/tests/test_results/complement/test_results.jsonl index ff695bb74..575a22fe8 100644 --- a/tests/test_results/complement/test_results.jsonl +++ b/tests/test_results/complement/test_results.jsonl @@ -225,7 +225,6 @@ {"Action":"pass","Test":"TestToDeviceMessagesOverFederation/good_connectivity"} {"Action":"pass","Test":"TestToDeviceMessagesOverFederation/interrupted_connectivity"} {"Action":"fail","Test":"TestToDeviceMessagesOverFederation/stopped_server"} -{"Action":"pass","Test":"TestUnbanViaInvite"} {"Action":"fail","Test":"TestUnknownEndpoints"} {"Action":"pass","Test":"TestUnknownEndpoints/Client-server_endpoints"} {"Action":"fail","Test":"TestUnknownEndpoints/Key_endpoints"} From 240c78e8101da122e35986c7c1414b4f2d655d31 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 1 Nov 2024 00:54:21 -0400 Subject: [PATCH 149/245] strong-type URL for URL previews to Url type Signed-off-by: strawberry --- src/api/client/media.rs | 25 ++++++++++++++++++------- src/api/client/media_legacy.rs | 25 ++++++++++++++++--------- src/service/media/preview.rs | 32 ++++++++++++-------------------- 3 files changed, 46 
insertions(+), 36 deletions(-) diff --git a/src/api/client/media.rs b/src/api/client/media.rs index 120127116..716936184 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -11,6 +11,7 @@ use conduit_service::{ media::{Dim, FileMeta, CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN, MXC_LENGTH}, Services, }; +use reqwest::Url; use ruma::{ api::client::{ authenticated_media::{ @@ -165,23 +166,33 @@ pub(crate) async fn get_media_preview_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let url = &body.url; - if !services.media.url_preview_allowed(url) { + let url = Url::parse(&body.url).map_err(|e| { + err!(Request(InvalidParam( + debug_warn!(%sender_user, %url, "Requested URL is not valid: {e}") + ))) + })?; + + if !services.media.url_preview_allowed(&url) { return Err!(Request(Forbidden( debug_warn!(%sender_user, %url, "URL is not allowed to be previewed") ))); } - let preview = services.media.get_url_preview(url).await.map_err(|error| { - err!(Request(Unknown( - debug_error!(%sender_user, %url, ?error, "Failed to fetch URL preview.") - ))) - })?; + let preview = services + .media + .get_url_preview(&url) + .await + .map_err(|error| { + err!(Request(Unknown( + debug_error!(%sender_user, %url, "Failed to fetch URL preview: {error}") + ))) + })?; serde_json::value::to_raw_value(&preview) .map(get_media_preview::v1::Response::from_raw_value) .map_err(|error| { err!(Request(Unknown( - debug_error!(%sender_user, %url, ?error, "Failed to parse URL preview.") + debug_error!(%sender_user, %url, "Failed to parse URL preview: {error}") ))) }) } diff --git a/src/api/client/media_legacy.rs b/src/api/client/media_legacy.rs index e87b9a2b2..f6837462e 100644 --- a/src/api/client/media_legacy.rs +++ b/src/api/client/media_legacy.rs @@ -8,6 +8,7 @@ use conduit::{ Err, Result, }; use conduit_service::media::{Dim, FileMeta, CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN}; +use reqwest::Url; use ruma::{ api::client::media::{ create_content, 
get_content, get_content_as_filename, get_content_thumbnail, get_media_config, @@ -55,25 +56,31 @@ pub(crate) async fn get_media_preview_legacy_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let url = &body.url; - if !services.media.url_preview_allowed(url) { + let url = Url::parse(&body.url).map_err(|e| { + err!(Request(InvalidParam( + debug_warn!(%sender_user, %url, "Requested URL is not valid: {e}") + ))) + })?; + + if !services.media.url_preview_allowed(&url) { return Err!(Request(Forbidden( debug_warn!(%sender_user, %url, "URL is not allowed to be previewed") ))); } - let preview = services.media.get_url_preview(url).await.map_err(|e| { + let preview = services.media.get_url_preview(&url).await.map_err(|e| { err!(Request(Unknown( debug_error!(%sender_user, %url, "Failed to fetch a URL preview: {e}") ))) })?; - let res = serde_json::value::to_raw_value(&preview).map_err(|e| { - err!(Request(Unknown( - debug_error!(%sender_user, %url, "Failed to parse a URL preview: {e}") - ))) - })?; - - Ok(get_media_preview::v3::Response::from_raw_value(res)) + serde_json::value::to_raw_value(&preview) + .map(get_media_preview::v3::Response::from_raw_value) + .map_err(|error| { + err!(Request(Unknown( + debug_error!(%sender_user, %url, "Failed to parse URL preview: {error}") + ))) + }) } /// # `GET /_matrix/media/v1/preview_url` diff --git a/src/service/media/preview.rs b/src/service/media/preview.rs index 6b1473838..acc9d8ed1 100644 --- a/src/service/media/preview.rs +++ b/src/service/media/preview.rs @@ -1,6 +1,6 @@ use std::{io::Cursor, time::SystemTime}; -use conduit::{debug, utils, warn, Err, Result}; +use conduit::{debug, utils, Err, Result}; use conduit_core::implement; use image::ImageReader as ImgReader; use ipaddress::IPAddress; @@ -70,30 +70,30 @@ pub async fn download_image(&self, url: &str) -> Result { } #[implement(Service)] -pub async fn get_url_preview(&self, url: &str) -> Result { - if let Ok(preview) = 
self.db.get_url_preview(url).await { +pub async fn get_url_preview(&self, url: &Url) -> Result { + if let Ok(preview) = self.db.get_url_preview(url.as_str()).await { return Ok(preview); } // ensure that only one request is made per URL - let _request_lock = self.url_preview_mutex.lock(url).await; + let _request_lock = self.url_preview_mutex.lock(url.as_str()).await; - match self.db.get_url_preview(url).await { + match self.db.get_url_preview(url.as_str()).await { Ok(preview) => Ok(preview), Err(_) => self.request_url_preview(url).await, } } #[implement(Service)] -async fn request_url_preview(&self, url: &str) -> Result { - if let Ok(ip) = IPAddress::parse(url) { +async fn request_url_preview(&self, url: &Url) -> Result { + if let Ok(ip) = IPAddress::parse(url.host_str().expect("URL previously validated")) { if !self.services.globals.valid_cidr_range(&ip) { return Err!(BadServerResponse("Requesting from this address is forbidden")); } } let client = &self.services.client.url_preview; - let response = client.head(url).send().await?; + let response = client.head(url.as_str()).send().await?; if let Some(remote_addr) = response.remote_addr() { if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { @@ -111,12 +111,12 @@ async fn request_url_preview(&self, url: &str) -> Result { return Err!(Request(Unknown("Unknown Content-Type"))); }; let data = match content_type { - html if html.starts_with("text/html") => self.download_html(url).await?, - img if img.starts_with("image/") => self.download_image(url).await?, + html if html.starts_with("text/html") => self.download_html(url.as_str()).await?, + img if img.starts_with("image/") => self.download_image(url.as_str()).await?, _ => return Err!(Request(Unknown("Unsupported Content-Type"))), }; - self.set_url_preview(url, &data).await?; + self.set_url_preview(url.as_str(), &data).await?; Ok(data) } @@ -159,15 +159,7 @@ async fn download_html(&self, url: &str) -> Result { } #[implement(Service)] -pub fn 
url_preview_allowed(&self, url_str: &str) -> bool { - let url: Url = match Url::parse(url_str) { - Ok(u) => u, - Err(e) => { - warn!("Failed to parse URL from a str: {}", e); - return false; - }, - }; - +pub fn url_preview_allowed(&self, url: &Url) -> bool { if ["http", "https"] .iter() .all(|&scheme| scheme != url.scheme().to_lowercase()) From 6cbaef2d12b24765dc16c0478b58c2e76dd972cd Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 2 Nov 2024 13:12:14 -0400 Subject: [PATCH 150/245] always set RUST_BACKTRACE=full in OCI images Signed-off-by: strawberry --- nix/pkgs/complement/default.nix | 1 + nix/pkgs/oci-image/default.nix | 3 +++ 2 files changed, 4 insertions(+) diff --git a/nix/pkgs/complement/default.nix b/nix/pkgs/complement/default.nix index 80e9ce273..399c4449b 100644 --- a/nix/pkgs/complement/default.nix +++ b/nix/pkgs/complement/default.nix @@ -96,6 +96,7 @@ dockerTools.buildImage { Env = [ "SSL_CERT_FILE=/complement/ca/ca.crt" "CONDUWUIT_CONFIG=${./config.toml}" + "RUST_BACKTRACE=full" ]; ExposedPorts = { diff --git a/nix/pkgs/oci-image/default.nix b/nix/pkgs/oci-image/default.nix index 5078523bc..9b6413106 100644 --- a/nix/pkgs/oci-image/default.nix +++ b/nix/pkgs/oci-image/default.nix @@ -24,5 +24,8 @@ dockerTools.buildLayeredImage { Cmd = [ "${lib.getExe main}" ]; + Env = [ + "RUST_BACKTRACE=full" + ]; }; } From ee6af6c90e5f941584429b9e890bffa358b23720 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 2 Nov 2024 18:46:20 -0400 Subject: [PATCH 151/245] drop report delay response range to 2-5 secs Signed-off-by: strawberry --- src/api/client/report.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/api/client/report.rs b/src/api/client/report.rs index 143c13e56..32a254d8f 100644 --- a/src/api/client/report.rs +++ b/src/api/client/report.rs @@ -180,7 +180,7 @@ async fn is_event_report_valid( /// random delay sending a response per spec suggestion regarding /// enumerating for potential events existing in our server. 
async fn delay_response() { - let time_to_wait = rand::thread_rng().gen_range(3..10); + let time_to_wait = rand::thread_rng().gen_range(2..5); debug_info!("Got successful /report request, waiting {time_to_wait} seconds before sending successful response."); sleep(Duration::from_secs(time_to_wait)).await; } From 9466aeb08876472f49da6ce4b2fb673ff3598c04 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 2 Nov 2024 18:52:25 -0400 Subject: [PATCH 152/245] remove some unnecessary debug prints on notices Signed-off-by: strawberry --- src/api/client/account.rs | 18 ++++++++++-------- src/api/client/report.rs | 25 ++++++++++++++----------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/src/api/client/account.rs b/src/api/client/account.rs index 97d36839d..87e73c5a5 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -100,8 +100,8 @@ pub(crate) async fn register_route( if !services.globals.allow_registration() && body.appservice_info.is_none() { info!( "Registration disabled and request not from known appservice, rejecting registration attempt for username \ - {:?}", - body.username + \"{}\"", + body.username.as_deref().unwrap_or("") ); return Err(Error::BadRequest(ErrorKind::forbidden(), "Registration has been disabled.")); } @@ -114,8 +114,8 @@ pub(crate) async fn register_route( { info!( "Guest registration disabled / registration enabled with token configured, rejecting guest registration \ - attempt, initial device name: {:?}", - body.initial_device_display_name + attempt, initial device name: \"{}\"", + body.initial_device_display_name.as_deref().unwrap_or("") ); return Err(Error::BadRequest( ErrorKind::GuestAccessForbidden, @@ -128,8 +128,8 @@ pub(crate) async fn register_route( if is_guest && services.users.count().await < 2 { warn!( "Guest account attempted to register before a real admin user has been registered, rejecting \ - registration. 
Guest's initial device name: {:?}", - body.initial_device_display_name + registration. Guest's initial device name: \"{}\"", + body.initial_device_display_name.as_deref().unwrap_or("") ); return Err(Error::BadRequest(ErrorKind::forbidden(), "Registration temporarily disabled.")); } @@ -312,12 +312,14 @@ pub(crate) async fn register_route( debug_info!(%user_id, %device_id, "User account was created"); - let device_display_name = body.initial_device_display_name.clone().unwrap_or_default(); + let device_display_name = body.initial_device_display_name.as_deref().unwrap_or(""); // log in conduit admin channel if a non-guest user registered if body.appservice_info.is_none() && !is_guest { if !device_display_name.is_empty() { - info!("New user \"{user_id}\" registered on this server with device display name: {device_display_name}"); + info!( + "New user \"{user_id}\" registered on this server with device display name: \"{device_display_name}\"" + ); if services.globals.config.admin_room_notices { services diff --git a/src/api/client/report.rs b/src/api/client/report.rs index 32a254d8f..e20fa8c22 100644 --- a/src/api/client/report.rs +++ b/src/api/client/report.rs @@ -33,10 +33,18 @@ pub(crate) async fn report_room_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); info!( - "Received room report by user {sender_user} for room {} with reason: {:?}", - body.room_id, body.reason + "Received room report by user {sender_user} for room {} with reason: \"{}\"", + body.room_id, + body.reason.as_deref().unwrap_or("") ); + if body.reason.as_ref().is_some_and(|s| s.len() > 750) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Reason too long, should be 750 characters or fewer", + )); + }; + delay_response().await; if !services @@ -50,13 +58,6 @@ pub(crate) async fn report_room_route( ))); } - if body.reason.as_ref().is_some_and(|s| s.len() > 750) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Reason too long, should be 
750 characters or fewer", - )); - }; - // send admin room message that we received the report with an @room ping for // urgency services @@ -85,8 +86,10 @@ pub(crate) async fn report_event_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); info!( - "Received event report by user {sender_user} for room {} and event ID {}, with reason: {:?}", - body.room_id, body.event_id, body.reason + "Received event report by user {sender_user} for room {} and event ID {}, with reason: \"{}\"", + body.room_id, + body.event_id, + body.reason.as_deref().unwrap_or("") ); delay_response().await; From 6f37a251fb5945c9f431ffeb3f32fcb5d3bbc470 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 2 Nov 2024 20:55:40 -0400 Subject: [PATCH 153/245] allow taking room aliases for `auto_join_rooms` config option Signed-off-by: strawberry --- src/admin/user/commands.rs | 20 ++++++++++++++++---- src/api/client/account.rs | 13 +++++++++---- src/core/config/mod.rs | 11 ++++++----- 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index 531ce490d..444a7f372 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -108,24 +108,29 @@ pub(super) async fn create_user(&self, username: String, password: Option { + self.services + .admin + .send_message(RoomMessageEventContent::text_plain(format!( + "Failed to automatically join room {room} for user {user_id}: {e}" + ))) + .await + .ok(); // don't return this error so we don't fail registrations error!("Failed to automatically join room {room} for user {user_id}: {e}"); }, diff --git a/src/api/client/account.rs b/src/api/client/account.rs index 87e73c5a5..c340f5295 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -398,23 +398,28 @@ pub(crate) async fn register_route( && (services.globals.allow_guests_auto_join_rooms() || !is_guest) { for room in &services.globals.config.auto_join_rooms { + let Ok(room_id) = 
services.rooms.alias.resolve(room).await else { + error!("Failed to resolve room alias to room ID when attempting to auto join {room}, skipping"); + continue; + }; + if !services .rooms .state_cache - .server_in_room(services.globals.server_name(), room) + .server_in_room(services.globals.server_name(), &room_id) .await { warn!("Skipping room {room} to automatically join as we have never joined before."); continue; } - if let Some(room_id_server_name) = room.server_name() { + if let Some(room_server_name) = room.server_name() { if let Err(e) = join_room_by_id_helper( &services, &user_id, - room, + &room_id, Some("Automatically joining this room upon registration".to_owned()), - &[room_id_server_name.to_owned(), services.globals.server_name().to_owned()], + &[services.globals.server_name().to_owned(), room_server_name.to_owned()], None, &body.appservice_info, ) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 512cb48b4..a6216da20 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -18,7 +18,8 @@ pub use figment::{value::Value as FigmentValue, Figment}; use itertools::Itertools; use regex::RegexSet; use ruma::{ - api::client::discovery::discover_support::ContactRole, OwnedRoomId, OwnedServerName, OwnedUserId, RoomVersionId, + api::client::discovery::discover_support::ContactRole, OwnedRoomOrAliasId, OwnedServerName, OwnedUserId, + RoomVersionId, }; use serde::{de::IgnoredAny, Deserialize}; use url::Url; @@ -653,13 +654,13 @@ pub struct Config { #[serde(default = "default_turn_ttl")] pub turn_ttl: u64, - /// List/vector of room **IDs** that conduwuit will make newly registered - /// users join. The room IDs specified must be rooms that you have joined - /// at least once on the server, and must be public. + /// List/vector of room IDs or room aliases that conduwuit will make newly + /// registered users join. The rooms specified must be rooms that you + /// have joined at least once on the server, and must be public. 
/// /// No default. #[serde(default = "Vec::new")] - pub auto_join_rooms: Vec, + pub auto_join_rooms: Vec, /// Config option to automatically deactivate the account of any user who /// attempts to join a: From 038787106365cebb0c538af043cc90028e200cc3 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 2 Nov 2024 21:20:36 -0400 Subject: [PATCH 154/245] add workaround for matrix-appservice-irc using historical localparts see https://github.com/matrix-org/matrix-appservice-irc/issues/1780 Signed-off-by: strawberry --- src/api/client/account.rs | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/api/client/account.rs b/src/api/client/account.rs index c340f5295..5ed4b3127 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -48,10 +48,19 @@ pub(crate) async fn get_register_available_route( State(services): State, InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { + // workaround for https://github.com/matrix-org/matrix-appservice-irc/issues/1780 due to inactivity of fixing the issue + let is_matrix_appservice_irc = body.appservice_info.as_ref().is_some_and(|appservice| { + appservice.registration.id == "irc" + || appservice.registration.id.contains("matrix-appservice-irc") + || appservice.registration.id.contains("matrix_appservice_irc") + }); + // Validate user id let user_id = UserId::parse_with_server_name(body.username.to_lowercase(), services.globals.server_name()) .ok() - .filter(|user_id| !user_id.is_historical() && services.globals.user_is_local(user_id)) + .filter(|user_id| { + (!user_id.is_historical() || is_matrix_appservice_irc) && services.globals.user_is_local(user_id) + }) .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; // Check if username is creative enough @@ -134,12 +143,22 @@ pub(crate) async fn register_route( return Err(Error::BadRequest(ErrorKind::forbidden(), "Registration temporarily disabled.")); } + // workaround for 
https://github.com/matrix-org/matrix-appservice-irc/issues/1780 due to inactivity of fixing the issue + let is_matrix_appservice_irc = body.appservice_info.as_ref().is_some_and(|appservice| { + appservice.registration.id == "irc" + || appservice.registration.id.contains("matrix-appservice-irc") + || appservice.registration.id.contains("matrix_appservice_irc") + }); + let user_id = match (&body.username, is_guest) { (Some(username), false) => { let proposed_user_id = UserId::parse_with_server_name(username.to_lowercase(), services.globals.server_name()) .ok() - .filter(|user_id| !user_id.is_historical() && services.globals.user_is_local(user_id)) + .filter(|user_id| { + (!user_id.is_historical() || is_matrix_appservice_irc) + && services.globals.user_is_local(user_id) + }) .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; if services.users.exists(&proposed_user_id).await { From 1fbfc983e9606752770286f42d5812a34a820e63 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 29 Oct 2024 10:53:56 +0000 Subject: [PATCH 155/245] optimize FedDest::Named port Signed-off-by: Jason Volk --- src/service/resolver/actual.rs | 39 +++++++++++++++++++++++++--------- src/service/resolver/fed.rs | 26 +++++++++++++++++------ src/service/resolver/tests.rs | 4 ++-- 3 files changed, 51 insertions(+), 18 deletions(-) diff --git a/src/service/resolver/actual.rs b/src/service/resolver/actual.rs index 660498f7d..61eedca51 100644 --- a/src/service/resolver/actual.rs +++ b/src/service/resolver/actual.rs @@ -9,9 +9,9 @@ use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; use ipaddress::IPAddress; use ruma::ServerName; -use crate::resolver::{ +use super::{ cache::{CachedDest, CachedOverride}, - fed::{add_port_to_hostname, get_ip_with_port, FedDest}, + fed::{add_port_to_hostname, get_ip_with_port, FedDest, PortString}, }; #[derive(Clone, Debug)] @@ -77,12 +77,12 @@ impl super::Service { let host = if let Ok(addr) = host.parse::() { FedDest::Literal(addr) } 
else if let Ok(addr) = host.parse::() { - FedDest::Named(addr.to_string(), ":8448".to_owned()) + FedDest::Named(addr.to_string(), FedDest::default_port()) } else if let Some(pos) = host.find(':') { let (host, port) = host.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) + FedDest::Named(host.to_owned(), port.try_into().unwrap_or_else(|_| FedDest::default_port())) } else { - FedDest::Named(host, ":8448".to_owned()) + FedDest::Named(host, FedDest::default_port()) }; debug!("Actual destination: {actual_dest:?} hostname: {host:?}"); @@ -103,7 +103,10 @@ impl super::Service { let (host, port) = dest.as_str().split_at(pos); self.conditional_query_and_cache_override(host, host, port.parse::().unwrap_or(8448), cache) .await?; - Ok(FedDest::Named(host.to_owned(), port.to_owned())) + Ok(FedDest::Named( + host.to_owned(), + port.try_into().unwrap_or_else(|_| FedDest::default_port()), + )) } async fn actual_dest_3(&self, host: &mut String, cache: bool, delegated: String) -> Result { @@ -136,7 +139,10 @@ impl super::Service { let (host, port) = delegated.split_at(pos); self.conditional_query_and_cache_override(host, host, port.parse::().unwrap_or(8448), cache) .await?; - Ok(FedDest::Named(host.to_owned(), port.to_owned())) + Ok(FedDest::Named( + host.to_owned(), + port.try_into().unwrap_or_else(|_| FedDest::default_port()), + )) } async fn actual_dest_3_3(&self, cache: bool, delegated: String, overrider: FedDest) -> Result { @@ -145,7 +151,13 @@ impl super::Service { self.conditional_query_and_cache_override(&delegated, &overrider.hostname(), force_port.unwrap_or(8448), cache) .await?; if let Some(port) = force_port { - Ok(FedDest::Named(delegated, format!(":{port}"))) + Ok(FedDest::Named( + delegated, + format!(":{port}") + .as_str() + .try_into() + .unwrap_or_else(|_| FedDest::default_port()), + )) } else { Ok(add_port_to_hostname(&delegated)) } @@ -164,7 +176,11 @@ impl super::Service { self.conditional_query_and_cache_override(host, &overrider.hostname(), 
force_port.unwrap_or(8448), cache) .await?; if let Some(port) = force_port { - Ok(FedDest::Named(host.to_owned(), format!(":{port}"))) + let port = format!(":{port}"); + Ok(FedDest::Named( + host.to_owned(), + PortString::from(port.as_str()).unwrap_or_else(|_| FedDest::default_port()), + )) } else { Ok(add_port_to_hostname(host)) } @@ -269,7 +285,10 @@ impl super::Service { srv.iter().next().map(|result| { FedDest::Named( result.target().to_string().trim_end_matches('.').to_owned(), - format!(":{}", result.port()), + format!(":{}", result.port()) + .as_str() + .try_into() + .unwrap_or_else(|_| FedDest::default_port()), ) }) } diff --git a/src/service/resolver/fed.rs b/src/service/resolver/fed.rs index 79f71f13a..9c348b47e 100644 --- a/src/service/resolver/fed.rs +++ b/src/service/resolver/fed.rs @@ -4,12 +4,19 @@ use std::{ net::{IpAddr, SocketAddr}, }; +use arrayvec::ArrayString; + #[derive(Clone, Debug, PartialEq, Eq)] pub enum FedDest { Literal(SocketAddr), - Named(String, String), + Named(String, PortString), } +/// numeric or service-name +pub type PortString = ArrayString<16>; + +const DEFAULT_PORT: &str = ":8448"; + pub(crate) fn get_ip_with_port(dest_str: &str) -> Option { if let Ok(dest) = dest_str.parse::() { Some(FedDest::Literal(dest)) @@ -20,13 +27,16 @@ pub(crate) fn get_ip_with_port(dest_str: &str) -> Option { } } -pub(crate) fn add_port_to_hostname(dest_str: &str) -> FedDest { - let (host, port) = match dest_str.find(':') { - None => (dest_str, ":8448"), - Some(pos) => dest_str.split_at(pos), +pub(crate) fn add_port_to_hostname(dest: &str) -> FedDest { + let (host, port) = match dest.find(':') { + None => (dest, DEFAULT_PORT), + Some(pos) => dest.split_at(pos), }; - FedDest::Named(host.to_owned(), port.to_owned()) + FedDest::Named( + host.to_owned(), + PortString::from(port).unwrap_or_else(|_| FedDest::default_port()), + ) } impl FedDest { @@ -60,6 +70,10 @@ impl FedDest { Self::Named(_, port) => port[1..].parse().ok(), } } + + #[inline] + 
#[must_use] + pub fn default_port() -> PortString { PortString::from(DEFAULT_PORT).expect("default port string") } } impl fmt::Display for FedDest { diff --git a/src/service/resolver/tests.rs b/src/service/resolver/tests.rs index 55cf0345d..870f5eabf 100644 --- a/src/service/resolver/tests.rs +++ b/src/service/resolver/tests.rs @@ -30,7 +30,7 @@ fn ips_keep_custom_ports() { fn hostnames_get_default_ports() { assert_eq!( add_port_to_hostname("example.com"), - FedDest::Named(String::from("example.com"), String::from(":8448")) + FedDest::Named(String::from("example.com"), ":8448".try_into().unwrap()) ); } @@ -38,6 +38,6 @@ fn hostnames_get_default_ports() { fn hostnames_keep_custom_ports() { assert_eq!( add_port_to_hostname("example.com:1337"), - FedDest::Named(String::from("example.com"), String::from(":1337")) + FedDest::Named(String::from("example.com"), ":1337".try_into().unwrap()) ); } From ad117641b88330aff3d1ba7ced57939df7862659 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 29 Oct 2024 00:08:41 +0000 Subject: [PATCH 156/245] add tuple-apply macro with length argument for now Signed-off-by: Jason Volk --- src/core/utils/mod.rs | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index 26b0484e0..b8640f3af 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -53,6 +53,25 @@ macro_rules! extract_variant { }; } +#[macro_export] +macro_rules! apply { + (1, $($idx:tt)+) => { + |t| (($($idx)+)(t.0),) + }; + + (2, $($idx:tt)+) => { + |t| (($($idx)+)(t.0), ($($idx)+)(t.1),) + }; + + (3, $($idx:tt)+) => { + |t| (($($idx)+)(t.0), ($($idx)+)(t.1), ($($idx)+)(t.2),) + }; + + (4, $($idx:tt)+) => { + |t| (($($idx)+)(t.0), ($($idx)+)(t.1), ($($idx)+)(t.2), ($($idx)+4)(t.3)) + }; +} + #[macro_export] macro_rules! at { ($idx:tt) => { @@ -112,6 +131,14 @@ macro_rules! is_not_empty { }; } +/// Functor for equality i.e. (a, b).map(is_equal!()) +#[macro_export] +macro_rules! 
is_equal { + () => { + |a, b| a == b + }; +} + /// Functor for truthy #[macro_export] macro_rules! is_true { From ed76797b55c8f32e11fcb0a3b8d0d29a4b93b6b8 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 29 Oct 2024 03:10:18 +0000 Subject: [PATCH 157/245] add raw_ overloads for prefix/from counting Signed-off-by: Jason Volk --- src/database/database.rs | 3 +++ src/database/map/count.rs | 24 ++++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/src/database/database.rs b/src/database/database.rs index 099df87d4..bf8c88555 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -33,6 +33,9 @@ impl Database { #[inline] pub fn iter(&self) -> impl Iterator + Send + '_ { self.maps.iter() } + #[inline] + pub fn keys(&self) -> impl Iterator + Send + '_ { self.maps.keys() } + #[inline] #[must_use] pub fn is_read_only(&self) -> bool { self.db.secondary || self.db.read_only } diff --git a/src/database/map/count.rs b/src/database/map/count.rs index dab45b7a9..3e92279c0 100644 --- a/src/database/map/count.rs +++ b/src/database/map/count.rs @@ -21,6 +21,18 @@ where self.keys_from_raw(from).count() } +/// Count the number of entries in the map starting from a lower-bound. +/// +/// - From is a raw +#[implement(super::Map)] +#[inline] +pub fn raw_count_from<'a, P>(&'a self, from: &'a P) -> impl Future + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.raw_keys_from(from).count() +} + /// Count the number of entries in the map matching a prefix. /// /// - Prefix is structured key @@ -32,3 +44,15 @@ where { self.keys_prefix_raw(prefix).count() } + +/// Count the number of entries in the map matching a prefix. 
+/// +/// - Prefix is raw +#[implement(super::Map)] +#[inline] +pub fn raw_count_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Future + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.raw_keys_prefix(prefix).count() +} From a7cb1c59518e8398e8a6aaedd784b9b7a222a2fd Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 29 Oct 2024 23:31:53 +0000 Subject: [PATCH 158/245] slightly optimize request signing/verifying Signed-off-by: Jason Volk --- src/api/router/auth.rs | 40 +++++++----- src/service/sending/send.rs | 121 ++++++++++++++++++------------------ 2 files changed, 88 insertions(+), 73 deletions(-) diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 31e71f2ff..2552ddedc 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -220,20 +220,32 @@ async fn auth_server(services: &Services, request: &mut Request, body: Option<&C .expect("all requests have a path") .to_string(); - let signature: [Member; 1] = [(x_matrix.key.to_string(), Value::String(x_matrix.sig.to_string()))]; - let signatures: [Member; 1] = [(origin.to_string(), Value::Object(signature.into()))]; - let authorization: [Member; 5] = [ - ("destination".into(), Value::String(destination.into())), - ("method".into(), Value::String(request.parts.method.to_string())), - ("origin".into(), Value::String(origin.to_string())), - ("signatures".into(), Value::Object(signatures.into())), - ("uri".into(), Value::String(signature_uri)), - ]; + let signature: [Member; 1] = [(x_matrix.key.as_str().into(), Value::String(x_matrix.sig.to_string()))]; - let mut authorization: Object = authorization.into(); - if let Some(body) = body { - authorization.insert("content".to_owned(), body.clone()); - } + let signatures: [Member; 1] = [(origin.as_str().into(), Value::Object(signature.into()))]; + + let authorization: Object = if let Some(body) = body.cloned() { + let authorization: [Member; 6] = [ + ("content".into(), body), + ("destination".into(), 
Value::String(destination.into())), + ("method".into(), Value::String(request.parts.method.as_str().into())), + ("origin".into(), Value::String(origin.as_str().into())), + ("signatures".into(), Value::Object(signatures.into())), + ("uri".into(), Value::String(signature_uri)), + ]; + + authorization.into() + } else { + let authorization: [Member; 5] = [ + ("destination".into(), Value::String(destination.into())), + ("method".into(), Value::String(request.parts.method.as_str().into())), + ("origin".into(), Value::String(origin.as_str().into())), + ("signatures".into(), Value::Object(signatures.into())), + ("uri".into(), Value::String(signature_uri)), + ]; + + authorization.into() + }; let key = services .server_keys @@ -242,7 +254,7 @@ async fn auth_server(services: &Services, request: &mut Request, body: Option<&C .map_err(|e| err!(Request(Forbidden(warn!("Failed to fetch signing keys: {e}")))))?; let keys: PubKeys = [(x_matrix.key.to_string(), key.key)].into(); - let keys: PubKeyMap = [(origin.to_string(), keys)].into(); + let keys: PubKeyMap = [(origin.as_str().into(), keys)].into(); if let Err(e) = ruma::signatures::verify_json(&keys, authorization) { debug_error!("Failed to verify federation request from {origin}: {e}"); if request.parts.uri.to_string().contains('@') { diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 2fbb39190..939d6e73d 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -14,7 +14,7 @@ use ruma::{ }, serde::Base64, server_util::authorization::XMatrix, - ServerName, + CanonicalJsonObject, CanonicalJsonValue, ServerName, ServerSigningKeyId, }; use crate::{ @@ -74,7 +74,7 @@ impl super::Service { .try_into_http_request::>(actual.string().as_str(), SATIR, &VERSIONS) .map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))?; - self.sign_request::(dest, &mut http_request); + self.sign_request(&mut http_request, dest); let request = Request::try_from(http_request)?; 
self.validate_url(request.url())?; @@ -178,68 +178,71 @@ where } #[implement(super::Service)] -fn sign_request(&self, dest: &ServerName, http_request: &mut http::Request>) -where - T: OutgoingRequest + Debug + Send, -{ - let mut req_map = serde_json::Map::with_capacity(8); - if !http_request.body().is_empty() { - req_map.insert( - "content".to_owned(), - serde_json::from_slice(http_request.body()).expect("body is valid json, we just created it"), - ); +fn sign_request(&self, http_request: &mut http::Request>, dest: &ServerName) { + type Member = (String, Value); + type Value = CanonicalJsonValue; + type Object = CanonicalJsonObject; + + let origin = self.services.globals.server_name(); + let body = http_request.body(); + let uri = http_request + .uri() + .path_and_query() + .expect("http::Request missing path_and_query"); + + let mut req: Object = if !body.is_empty() { + let content: CanonicalJsonValue = serde_json::from_slice(body).expect("failed to serialize body"); + + let authorization: [Member; 5] = [ + ("content".into(), content), + ("destination".into(), dest.as_str().into()), + ("method".into(), http_request.method().as_str().into()), + ("origin".into(), origin.as_str().into()), + ("uri".into(), uri.to_string().into()), + ]; + + authorization.into() + } else { + let authorization: [Member; 4] = [ + ("destination".into(), dest.as_str().into()), + ("method".into(), http_request.method().as_str().into()), + ("origin".into(), origin.as_str().into()), + ("uri".into(), uri.to_string().into()), + ]; + + authorization.into() }; - req_map.insert("method".to_owned(), T::METADATA.method.to_string().into()); - req_map.insert( - "uri".to_owned(), - http_request - .uri() - .path_and_query() - .expect("all requests have a path") - .to_string() - .into(), - ); - req_map.insert("origin".to_owned(), self.services.globals.server_name().to_string().into()); - req_map.insert("destination".to_owned(), dest.as_str().into()); - - let mut req_json = 
serde_json::from_value(req_map.into()).expect("valid JSON is valid BTreeMap"); self.services .server_keys - .sign_json(&mut req_json) - .expect("our request json is what ruma expects"); - - let req_json: serde_json::Map = - serde_json::from_slice(&serde_json::to_vec(&req_json).unwrap()).unwrap(); + .sign_json(&mut req) + .expect("request signing failed"); - let signatures = req_json["signatures"] + let signatures = req["signatures"] .as_object() - .expect("signatures object") + .and_then(|object| object[origin.as_str()].as_object()) + .expect("origin signatures object"); + + let key: &ServerSigningKeyId = signatures + .keys() + .next() + .map(|k| k.as_str().try_into()) + .expect("at least one signature from this origin") + .expect("keyid is json string"); + + let sig: Base64 = signatures .values() - .map(|v| { - v.as_object() - .expect("server signatures object") - .iter() - .map(|(k, v)| (k, v.as_str().expect("server signature string"))) - }); - - for signature_server in signatures { - for s in signature_server { - let key = - s.0.as_str() - .try_into() - .expect("valid homeserver signing key ID"); - let sig = Base64::parse(s.1).expect("valid base64"); - - http_request.headers_mut().insert( - AUTHORIZATION, - HeaderValue::from(&XMatrix::new( - self.services.globals.server_name().to_owned(), - dest.to_owned(), - key, - sig, - )), - ); - } - } + .next() + .map(|s| s.as_str().map(Base64::parse)) + .expect("at least one signature from this origin") + .expect("signature is json string") + .expect("signature is valid base64"); + + let x_matrix = XMatrix::new(origin.into(), dest.into(), key.into(), sig); + let authorization = HeaderValue::from(&x_matrix); + let authorization = http_request + .headers_mut() + .insert(AUTHORIZATION, authorization); + + debug_assert!(authorization.is_none(), "Authorization header already present"); } From 9775694423943135bc6015ebb102a21288ec05a1 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 30 Oct 2024 05:08:04 +0000 Subject: [PATCH 
159/245] inline database stream interface functions lt 64B Signed-off-by: Jason Volk --- src/database/stream.rs | 3 +++ src/database/stream/items.rs | 2 ++ src/database/stream/items_rev.rs | 2 ++ src/database/stream/keys.rs | 3 +++ src/database/stream/keys_rev.rs | 3 +++ 5 files changed, 13 insertions(+) diff --git a/src/database/stream.rs b/src/database/stream.rs index d9b74215d..a2a72e44c 100644 --- a/src/database/stream.rs +++ b/src/database/stream.rs @@ -71,6 +71,7 @@ impl<'a> State<'a> { self } + #[inline] fn seek_fwd(&mut self) { if !exchange(&mut self.init, false) { self.inner.next(); @@ -79,6 +80,7 @@ impl<'a> State<'a> { } } + #[inline] fn seek_rev(&mut self) { if !exchange(&mut self.init, false) { self.inner.prev(); @@ -95,6 +97,7 @@ impl<'a> State<'a> { fn status(&self) -> Option { self.inner.status().map_err(map_err).err() } + #[inline] fn valid(&self) -> bool { self.inner.valid() } } diff --git a/src/database/stream/items.rs b/src/database/stream/items.rs index 31d5e9e8d..54f8bc5c9 100644 --- a/src/database/stream/items.rs +++ b/src/database/stream/items.rs @@ -28,6 +28,7 @@ impl<'a> Cursor<'a, KeyVal<'a>> for Items<'a> { fn fetch(&self) -> Option> { self.state.fetch().map(keyval_longevity) } + #[inline] fn seek(&mut self) { self.state.seek_fwd(); } } @@ -40,5 +41,6 @@ impl<'a> Stream for Items<'a> { } impl FusedStream for Items<'_> { + #[inline] fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } } diff --git a/src/database/stream/items_rev.rs b/src/database/stream/items_rev.rs index ab57a2506..26492db8c 100644 --- a/src/database/stream/items_rev.rs +++ b/src/database/stream/items_rev.rs @@ -28,6 +28,7 @@ impl<'a> Cursor<'a, KeyVal<'a>> for ItemsRev<'a> { fn fetch(&self) -> Option> { self.state.fetch().map(keyval_longevity) } + #[inline] fn seek(&mut self) { self.state.seek_rev(); } } @@ -40,5 +41,6 @@ impl<'a> Stream for ItemsRev<'a> { } impl FusedStream for ItemsRev<'_> { + #[inline] fn is_terminated(&self) -> bool { 
!self.state.init && !self.state.valid() } } diff --git a/src/database/stream/keys.rs b/src/database/stream/keys.rs index 1c5d12e30..91884c8dc 100644 --- a/src/database/stream/keys.rs +++ b/src/database/stream/keys.rs @@ -26,8 +26,10 @@ impl<'a> Keys<'a> { impl<'a> Cursor<'a, Key<'a>> for Keys<'a> { fn state(&self) -> &State<'a> { &self.state } + #[inline] fn fetch(&self) -> Option> { self.state.fetch_key().map(slice_longevity) } + #[inline] fn seek(&mut self) { self.state.seek_fwd(); } } @@ -40,5 +42,6 @@ impl<'a> Stream for Keys<'a> { } impl FusedStream for Keys<'_> { + #[inline] fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } } diff --git a/src/database/stream/keys_rev.rs b/src/database/stream/keys_rev.rs index 267074837..59f66c2e5 100644 --- a/src/database/stream/keys_rev.rs +++ b/src/database/stream/keys_rev.rs @@ -26,8 +26,10 @@ impl<'a> KeysRev<'a> { impl<'a> Cursor<'a, Key<'a>> for KeysRev<'a> { fn state(&self) -> &State<'a> { &self.state } + #[inline] fn fetch(&self) -> Option> { self.state.fetch_key().map(slice_longevity) } + #[inline] fn seek(&mut self) { self.state.seek_rev(); } } @@ -40,5 +42,6 @@ impl<'a> Stream for KeysRev<'a> { } impl FusedStream for KeysRev<'_> { + #[inline] fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } } From 0eb67cfea00c19e7d0cf1981acc805986dfd05d6 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 30 Oct 2024 06:41:03 +0000 Subject: [PATCH 160/245] additional bool extensions for Result/Option conversion Signed-off-by: Jason Volk --- src/core/utils/bool.rs | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/src/core/utils/bool.rs b/src/core/utils/bool.rs index d5fa85aa8..e9f399d49 100644 --- a/src/core/utils/bool.rs +++ b/src/core/utils/bool.rs @@ -2,6 +2,23 @@ /// Boolean extensions and chain.starters pub trait BoolExt { + #[must_use] + fn clone_or(self, err: T, t: &T) -> T; + + #[must_use] + fn copy_or(self, err: T, 
t: T) -> T; + + #[must_use] + fn expect(self, msg: &str) -> Self; + + #[must_use] + fn expect_false(self, msg: &str) -> Self; + + fn into_option(self) -> Option<()>; + + #[allow(clippy::result_unit_err)] + fn into_result(self) -> Result<(), ()>; + fn map T>(self, f: F) -> T where Self: Sized; @@ -22,6 +39,24 @@ pub trait BoolExt { } impl BoolExt for bool { + #[inline] + fn clone_or(self, err: T, t: &T) -> T { self.map_or(err, || t.clone()) } + + #[inline] + fn copy_or(self, err: T, t: T) -> T { self.map_or(err, || t) } + + #[inline] + fn expect(self, msg: &str) -> Self { self.then_some(true).expect(msg) } + + #[inline] + fn expect_false(self, msg: &str) -> Self { (!self).then_some(false).expect(msg) } + + #[inline] + fn into_option(self) -> Option<()> { self.then_some(()) } + + #[inline] + fn into_result(self) -> Result<(), ()> { self.ok_or(()) } + #[inline] fn map T>(self, f: F) -> T where @@ -40,10 +75,10 @@ impl BoolExt for bool { fn map_or_else T>(self, err: F, f: F) -> T { self.then(f).unwrap_or_else(err) } #[inline] - fn ok_or(self, err: E) -> Result<(), E> { self.then_some(()).ok_or(err) } + fn ok_or(self, err: E) -> Result<(), E> { self.into_option().ok_or(err) } #[inline] - fn ok_or_else E>(self, err: F) -> Result<(), E> { self.then_some(()).ok_or_else(err) } + fn ok_or_else E>(self, err: F) -> Result<(), E> { self.into_option().ok_or_else(err) } #[inline] fn or T>(self, f: F) -> Option { (!self).then(f) } From 7fcc6d11a4993a174d4b0998e276bdc150594a15 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 30 Oct 2024 11:04:23 +0000 Subject: [PATCH 161/245] de-wrap state_accessor.server_can_see_event Signed-off-by: Jason Volk --- src/service/rooms/state_accessor/mod.rs | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 561db18a5..a2cc27e85 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs 
@@ -160,20 +160,18 @@ impl Service { /// Whether a server is allowed to see an event through federation, based on /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, origin, room_id, event_id))] - pub async fn server_can_see_event( - &self, origin: &ServerName, room_id: &RoomId, event_id: &EventId, - ) -> Result { + pub async fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_id: &EventId) -> bool { let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else { - return Ok(true); + return true; }; if let Some(visibility) = self .server_visibility_cache .lock() - .unwrap() + .expect("locked") .get_mut(&(origin.to_owned(), shortstatehash)) { - return Ok(*visibility); + return *visibility; } let history_visibility = self @@ -211,10 +209,10 @@ impl Service { self.server_visibility_cache .lock() - .unwrap() + .expect("locked") .insert((origin.to_owned(), shortstatehash), visibility); - Ok(visibility) + visibility } /// Whether a user is allowed to see an event, based on @@ -228,7 +226,7 @@ impl Service { if let Some(visibility) = self .user_visibility_cache .lock() - .unwrap() + .expect("locked") .get_mut(&(user_id.to_owned(), shortstatehash)) { return *visibility; @@ -262,7 +260,7 @@ impl Service { self.user_visibility_cache .lock() - .unwrap() + .expect("locked") .insert((user_id.to_owned(), shortstatehash), visibility); visibility From e49aee61c1f276fd613b8c61cde158b778c40763 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 30 Oct 2024 07:01:50 +0000 Subject: [PATCH 162/245] consolidate and parallelize api/server access check prologues Signed-off-by: Jason Volk --- Cargo.toml | 2 +- src/api/server/backfill.rs | 29 +++++--------- src/api/server/event.rs | 39 ++++++------------ src/api/server/event_auth.rs | 25 ++++-------- src/api/server/get_missing_events.rs | 32 +++++---------- src/api/server/mod.rs | 3 ++ src/api/server/state.rs | 27 +++++-------- src/api/server/state_ids.rs | 29 
+++++--------- src/api/server/utils.rs | 60 ++++++++++++++++++++++++++++ 9 files changed, 123 insertions(+), 123 deletions(-) create mode 100644 src/api/server/utils.rs diff --git a/Cargo.toml b/Cargo.toml index e406c9e15..043790f8f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -213,7 +213,7 @@ features = [ [workspace.dependencies.futures] version = "0.3.30" default-features = false -features = ["std"] +features = ["std", "async-await"] [workspace.dependencies.tokio] version = "1.40.0" diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index 088b891a2..281bf2a23 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -2,13 +2,13 @@ use std::cmp; use axum::extract::State; use conduit::{ - is_equal_to, utils::{IterStream, ReadyExt}, - Err, PduCount, Result, + PduCount, Result, }; use futures::{FutureExt, StreamExt}; use ruma::{api::federation::backfill::get_backfill, uint, user_id, MilliSecondsSinceUnixEpoch}; +use super::AccessCheck; use crate::Ruma; /// # `GET /_matrix/federation/v1/backfill/` @@ -18,24 +18,14 @@ use crate::Ruma; pub(crate) async fn get_backfill_route( State(services): State, body: Ruma, ) -> Result { - services - .rooms - .event_handler - .acl_check(body.origin(), &body.room_id) - .await?; - - if !services - .rooms - .state_accessor - .is_world_readable(&body.room_id) - .await && !services - .rooms - .state_cache - .server_in_room(body.origin(), &body.room_id) - .await - { - return Err!(Request(Forbidden("Server is not in room."))); + AccessCheck { + services: &services, + origin: body.origin(), + room_id: &body.room_id, + event_id: None, } + .check() + .await?; let until = body .v @@ -70,7 +60,6 @@ pub(crate) async fn get_backfill_route( .state_accessor .server_can_see_event(origin, &pdu.room_id, &pdu.event_id) .await - .is_ok_and(is_equal_to!(true)) { return None; } diff --git a/src/api/server/event.rs b/src/api/server/event.rs index 64ce3e401..29d5d8703 100644 --- a/src/api/server/event.rs +++ 
b/src/api/server/event.rs @@ -1,7 +1,8 @@ use axum::extract::State; -use conduit::{err, Err, Result}; +use conduit::{err, Result}; use ruma::{api::federation::event::get_event, MilliSecondsSinceUnixEpoch, RoomId}; +use super::AccessCheck; use crate::Ruma; /// # `GET /_matrix/federation/v1/event/{eventId}` @@ -20,35 +21,21 @@ pub(crate) async fn get_event_route( .await .map_err(|_| err!(Request(NotFound("Event not found."))))?; - let room_id_str = event + let room_id: &RoomId = event .get("room_id") .and_then(|val| val.as_str()) - .ok_or_else(|| err!(Database("Invalid event in database.")))?; + .ok_or_else(|| err!(Database("Invalid event in database.")))? + .try_into() + .map_err(|_| err!(Database("Invalid room_id in event in database.")))?; - let room_id = - <&RoomId>::try_from(room_id_str).map_err(|_| err!(Database("Invalid room_id in event in database.")))?; - - if !services - .rooms - .state_accessor - .is_world_readable(room_id) - .await && !services - .rooms - .state_cache - .server_in_room(body.origin(), room_id) - .await - { - return Err!(Request(Forbidden("Server is not in room."))); - } - - if !services - .rooms - .state_accessor - .server_can_see_event(body.origin(), room_id, &body.event_id) - .await? 
- { - return Err!(Request(Forbidden("Server is not allowed to see event."))); + AccessCheck { + services: &services, + origin: body.origin(), + room_id, + event_id: Some(&body.event_id), } + .check() + .await?; Ok(get_event::v1::Response { origin: services.globals.server_name().to_owned(), diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index 8fe96f813..faeb2b997 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -8,6 +8,7 @@ use ruma::{ RoomId, }; +use super::AccessCheck; use crate::Ruma; /// # `GET /_matrix/federation/v1/event_auth/{roomId}/{eventId}` @@ -18,24 +19,14 @@ use crate::Ruma; pub(crate) async fn get_event_authorization_route( State(services): State, body: Ruma, ) -> Result { - services - .rooms - .event_handler - .acl_check(body.origin(), &body.room_id) - .await?; - - if !services - .rooms - .state_accessor - .is_world_readable(&body.room_id) - .await && !services - .rooms - .state_cache - .server_in_room(body.origin(), &body.room_id) - .await - { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + AccessCheck { + services: &services, + origin: body.origin(), + room_id: &body.room_id, + event_id: None, } + .check() + .await?; let event = services .rooms diff --git a/src/api/server/get_missing_events.rs b/src/api/server/get_missing_events.rs index aee4fbe90..7dff44dcc 100644 --- a/src/api/server/get_missing_events.rs +++ b/src/api/server/get_missing_events.rs @@ -5,6 +5,7 @@ use ruma::{ CanonicalJsonValue, EventId, RoomId, }; +use super::AccessCheck; use crate::Ruma; /// # `POST /_matrix/federation/v1/get_missing_events/{roomId}` @@ -13,29 +14,16 @@ use crate::Ruma; pub(crate) async fn get_missing_events_route( State(services): State, body: Ruma, ) -> Result { - services - .rooms - .event_handler - .acl_check(body.origin(), &body.room_id) - .await?; - - if !services - .rooms - .state_accessor - .is_world_readable(&body.room_id) - .await && !services - .rooms - 
.state_cache - .server_in_room(body.origin(), &body.room_id) - .await - { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room")); + AccessCheck { + services: &services, + origin: body.origin(), + room_id: &body.room_id, + event_id: None, } + .check() + .await?; - let limit = body - .limit - .try_into() - .expect("UInt could not be converted to usize"); + let limit = body.limit.try_into()?; let mut queued_events = body.latest_events.clone(); // the vec will never have more entries the limit @@ -70,7 +58,7 @@ pub(crate) async fn get_missing_events_route( .rooms .state_accessor .server_can_see_event(body.origin(), &body.room_id, &queued_events[i]) - .await? + .await { i = i.saturating_add(1); continue; diff --git a/src/api/server/mod.rs b/src/api/server/mod.rs index 9a184f237..9b7d91cba 100644 --- a/src/api/server/mod.rs +++ b/src/api/server/mod.rs @@ -41,3 +41,6 @@ pub(super) use state_ids::*; pub(super) use user::*; pub(super) use version::*; pub(super) use well_known::*; + +mod utils; +use utils::AccessCheck; diff --git a/src/api/server/state.rs b/src/api/server/state.rs index 59bb6c7b1..06a44a999 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -1,10 +1,11 @@ use std::borrow::Borrow; use axum::extract::State; -use conduit::{err, result::LogErr, utils::IterStream, Err, Result}; +use conduit::{err, result::LogErr, utils::IterStream, Result}; use futures::{FutureExt, StreamExt, TryStreamExt}; use ruma::api::federation::event::get_room_state; +use super::AccessCheck; use crate::Ruma; /// # `GET /_matrix/federation/v1/state/{roomId}` @@ -13,24 +14,14 @@ use crate::Ruma; pub(crate) async fn get_room_state_route( State(services): State, body: Ruma, ) -> Result { - services - .rooms - .event_handler - .acl_check(body.origin(), &body.room_id) - .await?; - - if !services - .rooms - .state_accessor - .is_world_readable(&body.room_id) - .await && !services - .rooms - .state_cache - .server_in_room(body.origin(), &body.room_id) - 
.await - { - return Err!(Request(Forbidden("Server is not in room."))); + AccessCheck { + services: &services, + origin: body.origin(), + room_id: &body.room_id, + event_id: None, } + .check() + .await?; let shortstatehash = services .rooms diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index 957a2a86e..52d8e7cca 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -1,11 +1,12 @@ use std::borrow::Borrow; use axum::extract::State; -use conduit::{err, Err}; +use conduit::{err, Result}; use futures::StreamExt; use ruma::api::federation::event::get_room_state_ids; -use crate::{Result, Ruma}; +use super::AccessCheck; +use crate::Ruma; /// # `GET /_matrix/federation/v1/state_ids/{roomId}` /// @@ -14,24 +15,14 @@ use crate::{Result, Ruma}; pub(crate) async fn get_room_state_ids_route( State(services): State, body: Ruma, ) -> Result { - services - .rooms - .event_handler - .acl_check(body.origin(), &body.room_id) - .await?; - - if !services - .rooms - .state_accessor - .is_world_readable(&body.room_id) - .await && !services - .rooms - .state_cache - .server_in_room(body.origin(), &body.room_id) - .await - { - return Err!(Request(Forbidden("Server is not in room."))); + AccessCheck { + services: &services, + origin: body.origin(), + room_id: &body.room_id, + event_id: None, } + .check() + .await?; let shortstatehash = services .rooms diff --git a/src/api/server/utils.rs b/src/api/server/utils.rs new file mode 100644 index 000000000..278465caa --- /dev/null +++ b/src/api/server/utils.rs @@ -0,0 +1,60 @@ +use conduit::{implement, is_false, Err, Result}; +use conduit_service::Services; +use futures::{future::OptionFuture, join, FutureExt}; +use ruma::{EventId, RoomId, ServerName}; + +pub(super) struct AccessCheck<'a> { + pub(super) services: &'a Services, + pub(super) origin: &'a ServerName, + pub(super) room_id: &'a RoomId, + pub(super) event_id: Option<&'a EventId>, +} + +#[implement(AccessCheck, params = "<'_>")] +pub(super) 
async fn check(&self) -> Result { + let acl_check = self + .services + .rooms + .event_handler + .acl_check(self.origin, self.room_id) + .map(|result| result.is_ok()); + + let world_readable = self + .services + .rooms + .state_accessor + .is_world_readable(self.room_id); + + let server_in_room = self + .services + .rooms + .state_cache + .server_in_room(self.origin, self.room_id); + + let server_can_see: OptionFuture<_> = self + .event_id + .map(|event_id| { + self.services + .rooms + .state_accessor + .server_can_see_event(self.origin, self.room_id, event_id) + }) + .into(); + + let (world_readable, server_in_room, server_can_see, acl_check) = + join!(world_readable, server_in_room, server_can_see, acl_check); + + if !acl_check { + return Err!(Request(Forbidden("Server access denied."))); + } + + if !world_readable && !server_in_room { + return Err!(Request(Forbidden("Server is not in room."))); + } + + if server_can_see.is_some_and(is_false!()) { + return Err!(Request(Forbidden("Server is not allowed to see event."))); + } + + Ok(()) +} From 6b0eb7608d06fbfce663e4775196cfd3c7bae643 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 31 Oct 2024 07:33:16 +0000 Subject: [PATCH 163/245] add Filter extension to Result Signed-off-by: Jason Volk --- src/core/utils/result.rs | 5 +++-- src/core/utils/result/filter.rs | 21 +++++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) create mode 100644 src/core/utils/result/filter.rs diff --git a/src/core/utils/result.rs b/src/core/utils/result.rs index 9a60d19e2..fb1b7b959 100644 --- a/src/core/utils/result.rs +++ b/src/core/utils/result.rs @@ -1,4 +1,5 @@ mod debug_inspect; +mod filter; mod flat_ok; mod into_is_ok; mod log_debug_err; @@ -8,8 +9,8 @@ mod not_found; mod unwrap_infallible; pub use self::{ - debug_inspect::DebugInspect, flat_ok::FlatOk, into_is_ok::IntoIsOk, log_debug_err::LogDebugErr, log_err::LogErr, - map_expect::MapExpect, not_found::NotFound, unwrap_infallible::UnwrapInfallible, + 
debug_inspect::DebugInspect, filter::Filter, flat_ok::FlatOk, into_is_ok::IntoIsOk, log_debug_err::LogDebugErr, + log_err::LogErr, map_expect::MapExpect, not_found::NotFound, unwrap_infallible::UnwrapInfallible, }; pub type Result = std::result::Result; diff --git a/src/core/utils/result/filter.rs b/src/core/utils/result/filter.rs new file mode 100644 index 000000000..f11d36329 --- /dev/null +++ b/src/core/utils/result/filter.rs @@ -0,0 +1,21 @@ +use super::Result; + +pub trait Filter { + /// Similar to Option::filter + #[must_use] + fn filter(self, predicate: P) -> Self + where + P: FnOnce(&T) -> Result<(), U>, + E: From; +} + +impl Filter for Result { + #[inline] + fn filter(self, predicate: P) -> Self + where + P: FnOnce(&T) -> Result<(), U>, + E: From, + { + self.and_then(move |t| predicate(&t).map(move |()| t).map_err(Into::into)) + } +} From 0bc6fdd5897c9a216ced0433532eb47a97f142ac Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 31 Oct 2024 08:19:37 +0000 Subject: [PATCH 164/245] Refactor ShortStateInfo et al to properly named structures Signed-off-by: Jason Volk --- src/admin/debug/commands.rs | 9 ++- src/api/client/membership.rs | 14 ++++- src/service/rooms/event_handler/mod.rs | 15 +++-- src/service/rooms/short/mod.rs | 1 + src/service/rooms/state/mod.rs | 8 +-- src/service/rooms/state_accessor/data.rs | 6 +- src/service/rooms/state_compressor/mod.rs | 77 +++++++++++++++-------- 7 files changed, 89 insertions(+), 41 deletions(-) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index db1028588..754c98408 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -12,6 +12,7 @@ use ruma::{ events::room::message::RoomMessageEventContent, CanonicalJsonObject, EventId, OwnedRoomOrAliasId, RoomId, RoomVersionId, ServerName, }; +use service::rooms::state_compressor::HashSetCompressStateEvent; use tracing_subscriber::EnvFilter; use crate::admin_command; @@ -632,7 +633,11 @@ pub(super) async fn 
force_set_room_state_from_server( .await?; info!("Forcing new room state"); - let (short_state_hash, new, removed) = self + let HashSetCompressStateEvent { + shortstatehash: short_state_hash, + added, + removed, + } = self .services .rooms .state_compressor @@ -643,7 +648,7 @@ pub(super) async fn force_set_room_state_from_server( self.services .rooms .state - .force_state(room_id.clone().as_ref(), short_state_hash, new, removed, &state_lock) + .force_state(room_id.clone().as_ref(), short_state_hash, added, removed, &state_lock) .await?; info!( diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 27de60c68..c41e93fa3 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -39,7 +39,11 @@ use ruma::{ state_res, CanonicalJsonObject, CanonicalJsonValue, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, RoomVersionId, ServerName, UserId, }; -use service::{appservice::RegistrationInfo, rooms::state::RoomMutexGuard, Services}; +use service::{ + appservice::RegistrationInfo, + rooms::{state::RoomMutexGuard, state_compressor::HashSetCompressStateEvent}, + Services, +}; use crate::{client::full_user_deactivate, Ruma}; @@ -941,7 +945,11 @@ async fn join_room_by_id_helper_remote( .await; debug!("Saving compressed state"); - let (statehash_before_join, new, removed) = services + let HashSetCompressStateEvent { + shortstatehash: statehash_before_join, + added, + removed, + } = services .rooms .state_compressor .save_state(room_id, Arc::new(compressed)) @@ -951,7 +959,7 @@ async fn join_room_by_id_helper_remote( services .rooms .state - .force_state(room_id, statehash_before_join, new, removed, &state_lock) + .force_state(room_id, statehash_before_join, added, removed, &state_lock) .await?; info!("Updating joined counts for new room"); diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index ec04e748e..adebd3323 100644 --- a/src/service/rooms/event_handler/mod.rs +++ 
b/src/service/rooms/event_handler/mod.rs @@ -33,8 +33,11 @@ use ruma::{ RoomId, RoomVersionId, ServerName, UserId, }; -use super::state_compressor::CompressedStateEvent; -use crate::{globals, rooms, sending, server_keys, Dep}; +use crate::{ + globals, rooms, + rooms::state_compressor::{CompressedStateEvent, HashSetCompressStateEvent}, + sending, server_keys, Dep, +}; pub struct Service { services: Services, @@ -692,7 +695,11 @@ impl Service { // Set the new room state to the resolved state debug!("Forcing new room state"); - let (sstatehash, new, removed) = self + let HashSetCompressStateEvent { + shortstatehash, + added, + removed, + } = self .services .state_compressor .save_state(room_id, new_room_state) @@ -700,7 +707,7 @@ impl Service { self.services .state - .force_state(room_id, sstatehash, new, removed, &state_lock) + .force_state(room_id, shortstatehash, added, removed, &state_lock) .await?; } diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 02c449cc3..620116054 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -24,6 +24,7 @@ struct Services { globals: Dep, } +pub type ShortStateKey = ShortId; pub type ShortEventId = ShortId; pub type ShortRoomId = ShortId; pub type ShortId = u64; diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 6abaa1980..34fab0798 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -182,12 +182,12 @@ impl Service { let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = state_ids_compressed - .difference(&parent_stateinfo.1) + .difference(&parent_stateinfo.full_state) .copied() .collect(); let statediffremoved: HashSet<_> = parent_stateinfo - .1 + .full_state .difference(&state_ids_compressed) .copied() .collect(); @@ -259,7 +259,7 @@ impl Service { let replaces = states_parents .last() .map(|info| { - info.1 + info.full_state .iter() 
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) }) @@ -421,7 +421,7 @@ impl Service { })? .pop() .expect("there is always one layer") - .1; + .full_state; let mut ret = HashMap::new(); for compressed in full_state.iter() { diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index adc26f000..f77a6d80b 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -45,7 +45,7 @@ impl Data { .map_err(|e| err!(Database("Missing state IDs: {e}")))? .pop() .expect("there is always one layer") - .1; + .full_state; let mut result = HashMap::new(); let mut i: u8 = 0; @@ -78,7 +78,7 @@ impl Data { .await? .pop() .expect("there is always one layer") - .1; + .full_state; let mut result = HashMap::new(); let mut i: u8 = 0; @@ -123,7 +123,7 @@ impl Data { .map_err(|e| err!(Database(error!(?event_type, ?state_key, "Missing state: {e:?}"))))? .pop() .expect("there is always one layer") - .1; + .full_state; let compressed = full_state .iter() diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index be66c5970..1f351f40a 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -10,7 +10,7 @@ use database::Map; use lru_cache::LruCache; use ruma::{EventId, RoomId}; -use crate::{rooms, Dep}; +use crate::{rooms, rooms::short::ShortId, Dep}; pub struct Service { pub stateinfo_cache: Mutex, @@ -27,24 +27,33 @@ struct Data { shortstatehash_statediff: Arc, } +#[derive(Clone)] struct StateDiff { parent: Option, added: Arc>, removed: Arc>, } +#[derive(Clone, Default)] +pub struct ShortStateInfo { + pub shortstatehash: ShortStateHash, + pub full_state: Arc>, + pub added: Arc>, + pub removed: Arc>, +} + +#[derive(Clone, Default)] +pub struct HashSetCompressStateEvent { + pub shortstatehash: ShortStateHash, + pub added: Arc>, + pub removed: Arc>, +} + +pub type ShortStateHash = ShortId; +pub(crate) 
type CompressedStateEvent = [u8; 2 * size_of::()]; type StateInfoLruCache = LruCache; type ShortStateInfoVec = Vec; type ParentStatesVec = Vec; -type ShortStateInfo = ( - u64, // sstatehash - Arc>, // full state - Arc>, // added - Arc>, // removed -); - -type HashSetCompressStateEvent = (u64, Arc>, Arc>); -pub type CompressedStateEvent = [u8; 2 * size_of::()]; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { @@ -95,14 +104,19 @@ impl Service { if let Some(parent) = parent { let mut response = Box::pin(self.load_shortstatehash_info(parent)).await?; - let mut state = (*response.last().expect("at least one response").1).clone(); + let mut state = (*response.last().expect("at least one response").full_state).clone(); state.extend(added.iter().copied()); let removed = (*removed).clone(); for r in &removed { state.remove(r); } - response.push((shortstatehash, Arc::new(state), added, Arc::new(removed))); + response.push(ShortStateInfo { + shortstatehash, + full_state: Arc::new(state), + added, + removed: Arc::new(removed), + }); self.stateinfo_cache .lock() @@ -111,7 +125,13 @@ impl Service { Ok(response) } else { - let response = vec![(shortstatehash, added.clone(), added, removed)]; + let response = vec![ShortStateInfo { + shortstatehash, + full_state: added.clone(), + added, + removed, + }]; + self.stateinfo_cache .lock() .expect("locked") @@ -185,8 +205,8 @@ impl Service { // To many layers, we have to go deeper let parent = parent_states.pop().expect("parent must have a state"); - let mut parent_new = (*parent.2).clone(); - let mut parent_removed = (*parent.3).clone(); + let mut parent_new = (*parent.added).clone(); + let mut parent_removed = (*parent.removed).clone(); for removed in statediffremoved.iter() { if !parent_new.remove(removed) { @@ -236,14 +256,14 @@ impl Service { // 2. 
We replace a layer above let parent = parent_states.pop().expect("parent must have a state"); - let parent_2_len = parent.2.len(); - let parent_3_len = parent.3.len(); - let parent_diff = checked!(parent_2_len + parent_3_len)?; + let parent_added_len = parent.added.len(); + let parent_removed_len = parent.removed.len(); + let parent_diff = checked!(parent_added_len + parent_removed_len)?; if checked!(diffsum * diffsum)? >= checked!(2 * diff_to_sibling * parent_diff)? { // Diff too big, we replace above layer(s) - let mut parent_new = (*parent.2).clone(); - let mut parent_removed = (*parent.3).clone(); + let mut parent_new = (*parent.added).clone(); + let mut parent_removed = (*parent.removed).clone(); for removed in statediffremoved.iter() { if !parent_new.remove(removed) { @@ -275,7 +295,7 @@ impl Service { self.save_statediff( shortstatehash, &StateDiff { - parent: Some(parent.0), + parent: Some(parent.shortstatehash), added: statediffnew, removed: statediffremoved, }, @@ -311,7 +331,10 @@ impl Service { .await; if Some(new_shortstatehash) == previous_shortstatehash { - return Ok((new_shortstatehash, Arc::new(HashSet::new()), Arc::new(HashSet::new()))); + return Ok(HashSetCompressStateEvent { + shortstatehash: new_shortstatehash, + ..Default::default() + }); } let states_parents = if let Some(p) = previous_shortstatehash { @@ -322,12 +345,12 @@ impl Service { let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = new_state_ids_compressed - .difference(&parent_stateinfo.1) + .difference(&parent_stateinfo.full_state) .copied() .collect(); let statediffremoved: HashSet<_> = parent_stateinfo - .1 + .full_state .difference(&new_state_ids_compressed) .copied() .collect(); @@ -347,7 +370,11 @@ impl Service { )?; }; - Ok((new_shortstatehash, statediffnew, statediffremoved)) + Ok(HashSetCompressStateEvent { + shortstatehash: new_shortstatehash, + added: statediffnew, + removed: statediffremoved, + }) } 
async fn get_statediff(&self, shortstatehash: u64) -> Result { From f746be82c158e0bae2a3311953d49c6c8be2c910 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 31 Oct 2024 08:41:43 +0000 Subject: [PATCH 165/245] typename some loose u64 ShortId's Signed-off-by: Jason Volk --- src/service/rooms/short/mod.rs | 33 ++++++++++++----------- src/service/rooms/state_accessor/data.rs | 12 ++++----- src/service/rooms/state_compressor/mod.rs | 25 +++++++++-------- 3 files changed, 37 insertions(+), 33 deletions(-) diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 620116054..a903ef22a 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -24,6 +24,7 @@ struct Services { globals: Dep, } +pub type ShortStateHash = ShortId; pub type ShortStateKey = ShortId; pub type ShortEventId = ShortId; pub type ShortRoomId = ShortId; @@ -50,7 +51,7 @@ impl crate::Service for Service { } #[implement(Service)] -pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 { +pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> ShortEventId { const BUFSIZE: usize = size_of::(); if let Ok(shorteventid) = self @@ -78,7 +79,7 @@ pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 { } #[implement(Service)] -pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec { +pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec { self.db .eventid_shorteventid .get_batch_blocking(event_ids.iter()) @@ -106,7 +107,7 @@ pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> } #[implement(Service)] -pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { +pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { let key = (event_type, state_key); self.db .statekey_shortstatekey @@ -116,8 +117,8 @@ pub async fn get_shortstatekey(&self, 
event_type: &StateEventType, state_key: &s } #[implement(Service)] -pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> u64 { - const BUFSIZE: usize = size_of::(); +pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> ShortStateKey { + const BUFSIZE: usize = size_of::(); let key = (event_type, state_key); if let Ok(shortstatekey) = self @@ -145,8 +146,8 @@ pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, sta } #[implement(Service)] -pub async fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { - const BUFSIZE: usize = size_of::(); +pub async fn get_eventid_from_short(&self, shorteventid: ShortEventId) -> Result> { + const BUFSIZE: usize = size_of::(); self.db .shorteventid_eventid @@ -157,8 +158,8 @@ pub async fn get_eventid_from_short(&self, shorteventid: u64) -> Result Vec>> { - const BUFSIZE: usize = size_of::(); +pub async fn multi_get_eventid_from_short(&self, shorteventid: &[ShortEventId]) -> Vec>> { + const BUFSIZE: usize = size_of::(); let keys: Vec<[u8; BUFSIZE]> = shorteventid .iter() @@ -174,8 +175,8 @@ pub async fn multi_get_eventid_from_short(&self, shorteventid: &[u64]) -> Vec Result<(StateEventType, String)> { - const BUFSIZE: usize = size_of::(); +pub async fn get_statekey_from_short(&self, shortstatekey: ShortStateKey) -> Result<(StateEventType, String)> { + const BUFSIZE: usize = size_of::(); self.db .shortstatekey_statekey @@ -191,8 +192,8 @@ pub async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(State /// Returns (shortstatehash, already_existed) #[implement(Service)] -pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, bool) { - const BUFSIZE: usize = size_of::(); +pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (ShortStateHash, bool) { + const BUFSIZE: usize = size_of::(); if let Ok(shortstatehash) = self .db @@ -215,19 +216,19 @@ pub async fn 
get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, boo } #[implement(Service)] -pub async fn get_shortroomid(&self, room_id: &RoomId) -> Result { +pub async fn get_shortroomid(&self, room_id: &RoomId) -> Result { self.db.roomid_shortroomid.get(room_id).await.deserialized() } #[implement(Service)] -pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> u64 { +pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> ShortRoomId { self.db .roomid_shortroomid .get(room_id) .await .deserialized() .unwrap_or_else(|_| { - const BUFSIZE: usize = size_of::(); + const BUFSIZE: usize = size_of::(); let short = self.services.globals.next_count().unwrap(); debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed"); diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index f77a6d80b..9c96785f4 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -5,7 +5,7 @@ use database::{Deserialized, Map}; use futures::TryFutureExt; use ruma::{events::StateEventType, EventId, RoomId}; -use crate::{rooms, Dep}; +use crate::{rooms, rooms::short::ShortStateHash, Dep}; pub(super) struct Data { eventid_shorteventid: Arc, @@ -36,7 +36,7 @@ impl Data { } #[allow(unused_qualifications)] // async traits - pub(super) async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { + pub(super) async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result>> { let full_state = self .services .state_compressor @@ -69,7 +69,7 @@ impl Data { #[allow(unused_qualifications)] // async traits pub(super) async fn state_full( - &self, shortstatehash: u64, + &self, shortstatehash: ShortStateHash, ) -> Result>> { let full_state = self .services @@ -107,7 +107,7 @@ impl Data { /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). 
#[allow(clippy::unused_self)] pub(super) async fn state_get_id( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, ) -> Result> { let shortstatekey = self .services @@ -147,7 +147,7 @@ impl Data { /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). pub(super) async fn state_get( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, ) -> Result> { self.state_get_id(shortstatehash, event_type, state_key) .and_then(|event_id| async move { self.services.timeline.get_pdu(&event_id).await }) @@ -155,7 +155,7 @@ impl Data { } /// Returns the state hash for this pdu. - pub(super) async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result { + pub(super) async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result { self.eventid_shorteventid .get(event_id) .and_then(|shorteventid| self.shorteventid_shortstatehash.get(&shorteventid)) diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 1f351f40a..e213490ba 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -10,7 +10,11 @@ use database::Map; use lru_cache::LruCache; use ruma::{EventId, RoomId}; -use crate::{rooms, rooms::short::ShortId, Dep}; +use crate::{ + rooms, + rooms::short::{ShortStateHash, ShortStateKey}, + Dep, +}; pub struct Service { pub stateinfo_cache: Mutex, @@ -49,9 +53,8 @@ pub struct HashSetCompressStateEvent { pub removed: Arc>, } -pub type ShortStateHash = ShortId; pub(crate) type CompressedStateEvent = [u8; 2 * size_of::()]; -type StateInfoLruCache = LruCache; +type StateInfoLruCache = LruCache; type ShortStateInfoVec = Vec; type ParentStatesVec = Vec; @@ -86,7 +89,7 @@ impl crate::Service for Service { impl Service { /// Returns a stack with info on 
shortstatehash, full state, added diff and /// removed diff for the selected shortstatehash and each parent layer. - pub async fn load_shortstatehash_info(&self, shortstatehash: u64) -> Result { + pub async fn load_shortstatehash_info(&self, shortstatehash: ShortStateHash) -> Result { if let Some(r) = self .stateinfo_cache .lock() @@ -141,7 +144,7 @@ impl Service { } } - pub async fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> CompressedStateEvent { + pub async fn compress_state_event(&self, shortstatekey: ShortStateKey, event_id: &EventId) -> CompressedStateEvent { let mut v = shortstatekey.to_be_bytes().to_vec(); v.extend_from_slice( &self @@ -159,7 +162,7 @@ impl Service { #[inline] pub async fn parse_compressed_state_event( &self, compressed_event: &CompressedStateEvent, - ) -> Result<(u64, Arc)> { + ) -> Result<(ShortStateKey, Arc)> { use utils::u64_from_u8; let shortstatekey = u64_from_u8(&compressed_event[0..size_of::()]); @@ -192,7 +195,7 @@ impl Service { /// added diff and removed diff for each parent layer #[tracing::instrument(skip_all, level = "debug")] pub fn save_state_from_diff( - &self, shortstatehash: u64, statediffnew: Arc>, + &self, shortstatehash: ShortStateHash, statediffnew: Arc>, statediffremoved: Arc>, diff_to_sibling: usize, mut parent_states: ParentStatesVec, ) -> Result { @@ -377,9 +380,9 @@ impl Service { }) } - async fn get_statediff(&self, shortstatehash: u64) -> Result { - const BUFSIZE: usize = size_of::(); - const STRIDE: usize = size_of::(); + async fn get_statediff(&self, shortstatehash: ShortStateHash) -> Result { + const BUFSIZE: usize = size_of::(); + const STRIDE: usize = size_of::(); let value = self .db @@ -418,7 +421,7 @@ impl Service { }) } - fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) { + fn save_statediff(&self, shortstatehash: ShortStateHash, diff: &StateDiff) { let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); for new in diff.added.iter() { 
value.extend_from_slice(&new[..]); From 1f1e2d547cceda68da71855186abc38a3a1ba713 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 31 Oct 2024 11:49:00 +0000 Subject: [PATCH 166/245] optimize override ips; utilize all ips from cache Signed-off-by: Jason Volk --- src/service/resolver/actual.rs | 6 +++--- src/service/resolver/cache.rs | 10 +++++++--- src/service/resolver/dns.rs | 12 +++++------- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/service/resolver/actual.rs b/src/service/resolver/actual.rs index 61eedca51..5dc03d141 100644 --- a/src/service/resolver/actual.rs +++ b/src/service/resolver/actual.rs @@ -10,7 +10,7 @@ use ipaddress::IPAddress; use ruma::ServerName; use super::{ - cache::{CachedDest, CachedOverride}, + cache::{CachedDest, CachedOverride, MAX_IPS}, fed::{add_port_to_hostname, get_ip_with_port, FedDest, PortString}, }; @@ -266,9 +266,9 @@ impl super::Service { } self.set_cached_override( - overname.to_owned(), + overname, CachedOverride { - ips: override_ip.iter().collect(), + ips: override_ip.into_iter().take(MAX_IPS).collect(), port, expire: CachedOverride::default_expire(), }, diff --git a/src/service/resolver/cache.rs b/src/service/resolver/cache.rs index 465b59855..a13399dc8 100644 --- a/src/service/resolver/cache.rs +++ b/src/service/resolver/cache.rs @@ -5,6 +5,7 @@ use std::{ time::SystemTime, }; +use arrayvec::ArrayVec; use conduit::{trace, utils::rand}; use ruma::{OwnedServerName, ServerName}; @@ -24,7 +25,7 @@ pub struct CachedDest { #[derive(Clone, Debug)] pub struct CachedOverride { - pub ips: Vec, + pub ips: IpAddrs, pub port: u16, pub expire: SystemTime, } @@ -32,6 +33,9 @@ pub struct CachedOverride { pub type WellKnownMap = HashMap; pub type TlsNameMap = HashMap; +pub type IpAddrs = ArrayVec; +pub(crate) const MAX_IPS: usize = 3; + impl Cache { pub(super) fn new() -> Arc { Arc::new(Self { @@ -61,13 +65,13 @@ impl super::Service { .cloned() } - pub fn set_cached_override(&self, name: String, over: 
CachedOverride) -> Option { + pub fn set_cached_override(&self, name: &str, over: CachedOverride) -> Option { trace!(?name, ?over, "set cached override"); self.cache .overrides .write() .expect("locked for writing") - .insert(name, over) + .insert(name.into(), over) } #[must_use] diff --git a/src/service/resolver/dns.rs b/src/service/resolver/dns.rs index 89129e03e..d3e9f5c93 100644 --- a/src/service/resolver/dns.rs +++ b/src/service/resolver/dns.rs @@ -1,4 +1,4 @@ -use std::{iter, net::SocketAddr, sync::Arc, time::Duration}; +use std::{net::SocketAddr, sync::Arc, time::Duration}; use conduit::{err, Result, Server}; use futures::FutureExt; @@ -101,14 +101,12 @@ impl Resolve for Hooked { } async fn cached_to_reqwest(cached: CachedOverride) -> ResolvingResult { - let first_ip = cached + let addrs = cached .ips - .first() - .expect("must provide at least one override"); - - let saddr = SocketAddr::new(*first_ip, cached.port); + .into_iter() + .map(move |ip| SocketAddr::new(ip, cached.port)); - Ok(Box::new(iter::once(saddr))) + Ok(Box::new(addrs)) } async fn resolve_to_reqwest(resolver: Arc, name: Name) -> ResolvingResult { From ba1c13468942dce17d78e394ffabe2796cfed577 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 1 Nov 2024 22:16:14 +0000 Subject: [PATCH 167/245] move migrations out of globals service Signed-off-by: Jason Volk --- src/service/globals/mod.rs | 1 - src/service/media/migrations.rs | 4 ++-- src/service/{globals => }/migrations.rs | 0 src/service/mod.rs | 1 + src/service/services.rs | 2 +- 5 files changed, 4 insertions(+), 4 deletions(-) rename src/service/{globals => }/migrations.rs (100%) diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 0a7dda9f2..bd9569642 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,5 +1,4 @@ mod data; -pub(super) mod migrations; use std::{ collections::HashMap, diff --git a/src/service/media/migrations.rs b/src/service/media/migrations.rs index 0e358d443..f1c6da7d8 
100644 --- a/src/service/media/migrations.rs +++ b/src/service/media/migrations.rs @@ -13,7 +13,7 @@ use conduit::{ warn, Config, Result, }; -use crate::{globals, Services}; +use crate::{migrations, Services}; /// Migrates a media directory from legacy base64 file names to sha2 file names. /// All errors are fatal. Upon success the database is keyed to not perform this @@ -50,7 +50,7 @@ pub(crate) async fn migrate_sha256_media(services: &Services) -> Result<()> { // Apply fix from when sha256_media was backward-incompat and bumped the schema // version from 13 to 14. For users satisfying these conditions we can go back. - if services.globals.db.database_version().await == 14 && globals::migrations::DATABASE_VERSION == 13 { + if services.globals.db.database_version().await == 14 && migrations::DATABASE_VERSION == 13 { services.globals.db.bump_database_version(13)?; } diff --git a/src/service/globals/migrations.rs b/src/service/migrations.rs similarity index 100% rename from src/service/globals/migrations.rs rename to src/service/migrations.rs diff --git a/src/service/mod.rs b/src/service/mod.rs index 604e34045..c7dcc0c61 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,6 +1,7 @@ #![allow(refining_impl_trait)] mod manager; +mod migrations; mod service; pub mod services; diff --git a/src/service/services.rs b/src/service/services.rs index c0af42499..b86e7a721 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -114,7 +114,7 @@ impl Services { debug_info!("Starting services..."); self.admin.set_services(Some(Arc::clone(self)).as_ref()); - globals::migrations::migrations(self).await?; + super::migrations::migrations(self).await?; self.manager .lock() .await From 87424370364fa391214d4a55feaf9d3606add9e0 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 1 Nov 2024 22:43:26 +0000 Subject: [PATCH 168/245] wrap unimplemented ser/de branches with internal macro Signed-off-by: Jason Volk --- src/database/de.rs | 34 
++++++++++++++++--------------- src/database/mod.rs | 2 +- src/database/ser.rs | 48 ++++++++++++++++++++++---------------------- src/database/util.rs | 23 +++++++++++++++++++++ 4 files changed, 66 insertions(+), 41 deletions(-) diff --git a/src/database/de.rs b/src/database/de.rs index 0e074fdba..d7dc11022 100644 --- a/src/database/de.rs +++ b/src/database/de.rs @@ -5,6 +5,8 @@ use serde::{ Deserialize, }; +use crate::util::unhandled; + /// Deserialize into T from buffer. pub(crate) fn from_slice<'a, T>(buf: &'a [u8]) -> Result where @@ -192,7 +194,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { match name { "Ignore" => self.record_ignore(), "IgnoreAll" => self.record_ignore_all(), - _ => unimplemented!("Unrecognized deserialization Directive {name:?}"), + _ => unhandled!("Unrecognized deserialization Directive {name:?}"), }; visitor.visit_unit() @@ -214,27 +216,27 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - unimplemented!("deserialize Enum not implemented") + unhandled!("deserialize Enum not implemented") } fn deserialize_option>(self, _visitor: V) -> Result { - unimplemented!("deserialize Option not implemented") + unhandled!("deserialize Option not implemented") } fn deserialize_bool>(self, _visitor: V) -> Result { - unimplemented!("deserialize bool not implemented") + unhandled!("deserialize bool not implemented") } fn deserialize_i8>(self, _visitor: V) -> Result { - unimplemented!("deserialize i8 not implemented") + unhandled!("deserialize i8 not implemented") } fn deserialize_i16>(self, _visitor: V) -> Result { - unimplemented!("deserialize i16 not implemented") + unhandled!("deserialize i16 not implemented") } fn deserialize_i32>(self, _visitor: V) -> Result { - unimplemented!("deserialize i32 not implemented") + unhandled!("deserialize i32 not implemented") } fn deserialize_i64>(self, visitor: V) -> Result { @@ -244,15 +246,15 @@ impl<'a, 'de: 'a> de::Deserializer<'de> 
for &'a mut Deserializer<'de> { } fn deserialize_u8>(self, _visitor: V) -> Result { - unimplemented!("deserialize u8 not implemented; try dereferencing the Handle for [u8] access instead") + unhandled!("deserialize u8 not implemented; try dereferencing the Handle for [u8] access instead") } fn deserialize_u16>(self, _visitor: V) -> Result { - unimplemented!("deserialize u16 not implemented") + unhandled!("deserialize u16 not implemented") } fn deserialize_u32>(self, _visitor: V) -> Result { - unimplemented!("deserialize u32 not implemented") + unhandled!("deserialize u32 not implemented") } fn deserialize_u64>(self, visitor: V) -> Result { @@ -262,15 +264,15 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } fn deserialize_f32>(self, _visitor: V) -> Result { - unimplemented!("deserialize f32 not implemented") + unhandled!("deserialize f32 not implemented") } fn deserialize_f64>(self, _visitor: V) -> Result { - unimplemented!("deserialize f64 not implemented") + unhandled!("deserialize f64 not implemented") } fn deserialize_char>(self, _visitor: V) -> Result { - unimplemented!("deserialize char not implemented") + unhandled!("deserialize char not implemented") } fn deserialize_str>(self, visitor: V) -> Result { @@ -291,11 +293,11 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } fn deserialize_byte_buf>(self, _visitor: V) -> Result { - unimplemented!("deserialize Byte Buf not implemented") + unhandled!("deserialize Byte Buf not implemented") } fn deserialize_unit>(self, _visitor: V) -> Result { - unimplemented!("deserialize Unit not implemented") + unhandled!("deserialize Unit not implemented") } // this only used for $serde_json::private::RawValue at this time; see MapAccess @@ -305,7 +307,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } fn deserialize_ignored_any>(self, _visitor: V) -> Result { - unimplemented!("deserialize Ignored Any not implemented") + unhandled!("deserialize 
Ignored Any not implemented") } fn deserialize_any>(self, visitor: V) -> Result { diff --git a/src/database/mod.rs b/src/database/mod.rs index dcd66a1ee..f09c4a712 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -11,7 +11,7 @@ mod opts; mod ser; mod stream; mod tests; -mod util; +pub(crate) mod util; mod watchers; pub(crate) use self::{ diff --git a/src/database/ser.rs b/src/database/ser.rs index 0cc5c886c..961d2700b 100644 --- a/src/database/ser.rs +++ b/src/database/ser.rs @@ -4,6 +4,8 @@ use arrayvec::ArrayVec; use conduit::{debug::type_name, err, result::DebugInspect, utils::exchange, Error, Result}; use serde::{ser, Serialize}; +use crate::util::unhandled; + #[inline] pub fn serialize_to_array(val: T) -> Result> where @@ -146,17 +148,15 @@ impl ser::Serializer for &mut Serializer<'_, W> { fn serialize_tuple_variant( self, _name: &'static str, _idx: u32, _var: &'static str, _len: usize, ) -> Result { - unimplemented!("serialize Tuple Variant not implemented") + unhandled!("serialize Tuple Variant not implemented") } fn serialize_map(self, _len: Option) -> Result { - unimplemented!( - "serialize Map not implemented; did you mean to use database::Json() around your serde_json::Value?" - ) + unhandled!("serialize Map not implemented; did you mean to use database::Json() around your serde_json::Value?") } fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { - unimplemented!( + unhandled!( "serialize Struct not implemented at this time; did you mean to use database::Json() around your struct?" 
) } @@ -164,7 +164,7 @@ impl ser::Serializer for &mut Serializer<'_, W> { fn serialize_struct_variant( self, _name: &'static str, _idx: u32, _var: &'static str, _len: usize, ) -> Result { - unimplemented!("serialize Struct Variant not implemented") + unhandled!("serialize Struct Variant not implemented") } #[allow(clippy::needless_borrows_for_generic_args)] // buggy @@ -179,14 +179,14 @@ impl ser::Serializer for &mut Serializer<'_, W> { match name { "Json" => serde_json::to_writer(&mut self.out, value).map_err(Into::into), - _ => unimplemented!("Unrecognized serialization Newtype {name:?}"), + _ => unhandled!("Unrecognized serialization Newtype {name:?}"), } } fn serialize_newtype_variant( self, _name: &'static str, _idx: u32, _var: &'static str, _value: &T, ) -> Result { - unimplemented!("serialize Newtype Variant not implemented") + unhandled!("serialize Newtype Variant not implemented") } fn serialize_unit_struct(self, name: &'static str) -> Result { @@ -197,14 +197,14 @@ impl ser::Serializer for &mut Serializer<'_, W> { "Separator" => { self.separator()?; }, - _ => unimplemented!("Unrecognized serialization directive: {name:?}"), + _ => unhandled!("Unrecognized serialization directive: {name:?}"), }; Ok(()) } fn serialize_unit_variant(self, _name: &'static str, _idx: u32, _var: &'static str) -> Result { - unimplemented!("serialize Unit Variant not implemented") + unhandled!("serialize Unit Variant not implemented") } fn serialize_some(self, val: &T) -> Result { val.serialize(self) } @@ -234,29 +234,29 @@ impl ser::Serializer for &mut Serializer<'_, W> { self.write(v) } - fn serialize_f64(self, _v: f64) -> Result { unimplemented!("serialize f64 not implemented") } + fn serialize_f64(self, _v: f64) -> Result { unhandled!("serialize f64 not implemented") } - fn serialize_f32(self, _v: f32) -> Result { unimplemented!("serialize f32 not implemented") } + fn serialize_f32(self, _v: f32) -> Result { unhandled!("serialize f32 not implemented") } fn serialize_i64(self, 
v: i64) -> Result { self.write(&v.to_be_bytes()) } fn serialize_i32(self, v: i32) -> Result { self.write(&v.to_be_bytes()) } - fn serialize_i16(self, _v: i16) -> Result { unimplemented!("serialize i16 not implemented") } + fn serialize_i16(self, _v: i16) -> Result { unhandled!("serialize i16 not implemented") } - fn serialize_i8(self, _v: i8) -> Result { unimplemented!("serialize i8 not implemented") } + fn serialize_i8(self, _v: i8) -> Result { unhandled!("serialize i8 not implemented") } fn serialize_u64(self, v: u64) -> Result { self.write(&v.to_be_bytes()) } fn serialize_u32(self, v: u32) -> Result { self.write(&v.to_be_bytes()) } - fn serialize_u16(self, _v: u16) -> Result { unimplemented!("serialize u16 not implemented") } + fn serialize_u16(self, _v: u16) -> Result { unhandled!("serialize u16 not implemented") } fn serialize_u8(self, v: u8) -> Result { self.write(&[v]) } - fn serialize_bool(self, _v: bool) -> Result { unimplemented!("serialize bool not implemented") } + fn serialize_bool(self, _v: bool) -> Result { unhandled!("serialize bool not implemented") } - fn serialize_unit(self) -> Result { unimplemented!("serialize unit not implemented") } + fn serialize_unit(self) -> Result { unhandled!("serialize unit not implemented") } } impl ser::SerializeSeq for &mut Serializer<'_, W> { @@ -309,14 +309,14 @@ impl ser::SerializeMap for &mut Serializer<'_, W> { type Ok = (); fn serialize_key(&mut self, _key: &T) -> Result { - unimplemented!("serialize Map Key not implemented") + unhandled!("serialize Map Key not implemented") } fn serialize_value(&mut self, _val: &T) -> Result { - unimplemented!("serialize Map Val not implemented") + unhandled!("serialize Map Val not implemented") } - fn end(self) -> Result { unimplemented!("serialize Map End not implemented") } + fn end(self) -> Result { unhandled!("serialize Map End not implemented") } } impl ser::SerializeStruct for &mut Serializer<'_, W> { @@ -324,10 +324,10 @@ impl ser::SerializeStruct for &mut 
Serializer<'_, W> { type Ok = (); fn serialize_field(&mut self, _key: &'static str, _val: &T) -> Result { - unimplemented!("serialize Struct Field not implemented") + unhandled!("serialize Struct Field not implemented") } - fn end(self) -> Result { unimplemented!("serialize Struct End not implemented") } + fn end(self) -> Result { unhandled!("serialize Struct End not implemented") } } impl ser::SerializeStructVariant for &mut Serializer<'_, W> { @@ -335,8 +335,8 @@ impl ser::SerializeStructVariant for &mut Serializer<'_, W> { type Ok = (); fn serialize_field(&mut self, _key: &'static str, _val: &T) -> Result { - unimplemented!("serialize Struct Variant Field not implemented") + unhandled!("serialize Struct Variant Field not implemented") } - fn end(self) -> Result { unimplemented!("serialize Struct Variant End not implemented") } + fn end(self) -> Result { unhandled!("serialize Struct Variant End not implemented") } } diff --git a/src/database/util.rs b/src/database/util.rs index d36e183f4..ae0763812 100644 --- a/src/database/util.rs +++ b/src/database/util.rs @@ -1,6 +1,29 @@ use conduit::{err, Result}; use rocksdb::{Direction, IteratorMode}; +//#[cfg(debug_assertions)] +macro_rules! unhandled { + ($msg:literal) => { + unimplemented!($msg) + }; +} + +// activate when stable; we're not ready for this yet +#[cfg(disable)] // #[cfg(not(debug_assertions))] +macro_rules! unhandled { + ($msg:literal) => { + // SAFETY: Eliminates branches for serializing and deserializing types never + // encountered in the codebase. This can promote optimization and reduce + // codegen. The developer must verify for every invoking callsite that the + // unhandled type is in no way involved and could not possibly be encountered. 
+ unsafe { + std::hint::unreachable_unchecked(); + } + }; +} + +pub(crate) use unhandled; + #[inline] pub(crate) fn _into_direction(mode: &IteratorMode<'_>) -> Direction { use Direction::{Forward, Reverse}; From f191b4bad4cc6d15584f6c33cac57925e7b67abf Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 2 Nov 2024 04:54:28 +0000 Subject: [PATCH 169/245] add map_expect for stream Signed-off-by: Jason Volk --- src/core/utils/result/map_expect.rs | 8 ++++---- src/core/utils/stream/expect.rs | 11 ++++++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/core/utils/result/map_expect.rs b/src/core/utils/result/map_expect.rs index 8ce9195fe..9cd498f7f 100644 --- a/src/core/utils/result/map_expect.rs +++ b/src/core/utils/result/map_expect.rs @@ -2,14 +2,14 @@ use std::fmt::Debug; use super::Result; -pub trait MapExpect { +pub trait MapExpect<'a, T> { /// Calls expect(msg) on the mapped Result value. This is similar to /// map(Result::unwrap) but composes an expect call and message without /// requiring a closure. 
- fn map_expect(self, msg: &str) -> Option; + fn map_expect(self, msg: &'a str) -> T; } -impl MapExpect for Option> { +impl<'a, T, E: Debug> MapExpect<'a, Option> for Option> { #[inline] - fn map_expect(self, msg: &str) -> Option { self.map(|result| result.expect(msg)) } + fn map_expect(self, msg: &'a str) -> Option { self.map(|result| result.expect(msg)) } } diff --git a/src/core/utils/stream/expect.rs b/src/core/utils/stream/expect.rs index 3ab7181a8..68ac24ced 100644 --- a/src/core/utils/stream/expect.rs +++ b/src/core/utils/stream/expect.rs @@ -4,14 +4,19 @@ use crate::Result; pub trait TryExpect<'a, Item> { fn expect_ok(self) -> impl Stream + Send + 'a; + + fn map_expect(self, msg: &'a str) -> impl Stream + Send + 'a; } impl<'a, T, Item> TryExpect<'a, Item> for T where T: Stream> + TryStream + Send + 'a, + Item: 'a, { #[inline] - fn expect_ok(self: T) -> impl Stream + Send + 'a { - self.map(|res| res.expect("stream expectation failure")) - } + fn expect_ok(self: T) -> impl Stream + Send + 'a { self.map_expect("stream expectation failure") } + + //TODO: move to impl MapExpect + #[inline] + fn map_expect(self, msg: &'a str) -> impl Stream + Send + 'a { self.map(|res| res.expect(msg)) } } From 52f09fdb51f895940a1d0895cd05775ed9aeacbd Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 2 Nov 2024 01:59:06 +0000 Subject: [PATCH 170/245] add database migration for missing referencedevents separator Signed-off-by: Jason Volk --- src/service/migrations.rs | 59 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/src/service/migrations.rs b/src/service/migrations.rs index c953e7b1d..45323fa23 100644 --- a/src/service/migrations.rs +++ b/src/service/migrations.rs @@ -1,7 +1,12 @@ +use std::cmp; + use conduit::{ - debug_info, debug_warn, error, info, + debug, debug_info, debug_warn, error, info, result::NotFound, - utils::{stream::TryIgnore, IterStream, ReadyExt}, + utils::{ + stream::{TryExpect, TryIgnore}, + IterStream, 
ReadyExt, + }, warn, Err, Result, }; use futures::{FutureExt, StreamExt}; @@ -120,6 +125,14 @@ async fn migrate(services: &Services) -> Result<()> { retroactively_fix_bad_data_from_roomuserid_joined(services).await?; } + if db["global"] + .get(b"fix_referencedevents_missing_sep") + .await + .is_not_found() + { + fix_referencedevents_missing_sep(services).await?; + } + let version_match = services.globals.db.database_version().await == DATABASE_VERSION || services.globals.db.database_version().await == CONDUIT_DATABASE_VERSION; @@ -444,3 +457,45 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) info!("Finished fixing"); Ok(()) } + +async fn fix_referencedevents_missing_sep(services: &Services) -> Result { + warn!("Fixing missing record separator between room_id and event_id in referencedevents"); + + let db = &services.db; + let cork = db.cork_and_sync(); + + let referencedevents = db["referencedevents"].clone(); + + let totals: (usize, usize) = (0, 0); + let (total, fixed) = referencedevents + .raw_stream() + .expect_ok() + .enumerate() + .ready_fold(totals, |mut a, (i, (key, val))| { + debug_assert!(val.is_empty(), "expected no value"); + + let has_sep = key.contains(&database::SEP); + + if !has_sep { + let key_str = std::str::from_utf8(key).expect("key not utf-8"); + let room_id_len = key_str.find('$').expect("missing '$' in key"); + let (room_id, event_id) = key_str.split_at(room_id_len); + debug!(?a, "fixing {room_id}, {event_id}"); + + let new_key = (room_id, event_id); + referencedevents.put_raw(new_key, val); + referencedevents.remove(key); + } + + a.0 = cmp::max(i, a.0); + a.1 = a.1.saturating_add((!has_sep).into()); + a + }) + .await; + + drop(cork); + info!(?total, ?fixed, "Fixed missing record separators in 'referencedevents'."); + + db["global"].insert(b"fix_referencedevents_missing_sep", []); + db.db.cleanup() +} From 8d251003a25aa697de052f01515e5cc23ce999e4 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 3 Nov 
2024 12:42:43 +0000 Subject: [PATCH 171/245] reduce Error-related codegen; add PoisonError Signed-off-by: Jason Volk --- src/core/config/mod.rs | 1 + src/core/error/err.rs | 1 + src/core/error/mod.rs | 10 +++++++++- src/core/error/panic.rs | 4 ++++ src/core/error/response.rs | 1 + 5 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index a6216da20..43cca4b8a 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -2001,6 +2001,7 @@ fn default_rocksdb_stats_level() -> u8 { 1 } // I know, it's a great name #[must_use] +#[inline] pub fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 } fn default_ip_range_denylist() -> Vec { diff --git a/src/core/error/err.rs b/src/core/error/err.rs index 82bb40b05..baeb992d2 100644 --- a/src/core/error/err.rs +++ b/src/core/error/err.rs @@ -137,6 +137,7 @@ macro_rules! err_log { let visit = &mut |vs: ValueSet<'_>| { struct Visitor<'a>(&'a mut String); impl Visit for Visitor<'_> { + #[inline] fn record_debug(&mut self, field: &Field, val: &dyn fmt::Debug) { if field.name() == "message" { write!(self.0, "{:?}", val).expect("stream error"); diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs index 42250a0c6..302d0f87e 100644 --- a/src/core/error/mod.rs +++ b/src/core/error/mod.rs @@ -4,7 +4,7 @@ mod panic; mod response; mod serde; -use std::{any::Any, borrow::Cow, convert::Infallible, fmt}; +use std::{any::Any, borrow::Cow, convert::Infallible, fmt, sync::PoisonError}; pub use self::log::*; use crate::error; @@ -59,6 +59,8 @@ pub enum Error { JsTryFromInt(#[from] ruma::JsTryFromIntError), // js_int re-export #[error(transparent)] Path(#[from] axum::extract::rejection::PathRejection), + #[error("Mutex poisoned: {0}")] + Poison(Cow<'static, str>), #[error("Regex error: {0}")] Regex(#[from] regex::Error), #[error("Request error: {0}")] @@ -184,6 +186,12 @@ impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 
write!(f, "{}", self.message()) } } +impl From> for Error { + #[cold] + #[inline(never)] + fn from(e: PoisonError) -> Self { Self::Poison(e.to_string().into()) } +} + #[allow(clippy::fallible_impl_from)] impl From for Error { #[cold] diff --git a/src/core/error/panic.rs b/src/core/error/panic.rs index c070f7866..bec25132d 100644 --- a/src/core/error/panic.rs +++ b/src/core/error/panic.rs @@ -10,11 +10,14 @@ impl UnwindSafe for Error {} impl RefUnwindSafe for Error {} impl Error { + #[inline] pub fn panic(self) -> ! { panic_any(self.into_panic()) } #[must_use] + #[inline] pub fn from_panic(e: Box) -> Self { Self::Panic(debug::panic_str(&e), e) } + #[inline] pub fn into_panic(self) -> Box { match self { Self::Panic(_, e) | Self::PanicAny(e) => e, @@ -24,6 +27,7 @@ impl Error { } /// Get the panic message string. + #[inline] pub fn panic_str(self) -> Option<&'static str> { self.is_panic() .then_some(debug::panic_str(&self.into_panic())) diff --git a/src/core/error/response.rs b/src/core/error/response.rs index 7568a1c01..21fbdcf22 100644 --- a/src/core/error/response.rs +++ b/src/core/error/response.rs @@ -26,6 +26,7 @@ impl axum::response::IntoResponse for Error { } impl From for UiaaResponse { + #[inline] fn from(error: Error) -> Self { if let Error::Uiaa(uiaainfo) = error { return Self::AuthResponse(uiaainfo); From 768e81741cbd2bb16e18edb0782b64d270a86dfd Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 3 Nov 2024 11:22:38 +0000 Subject: [PATCH 172/245] use FnMut for ready_try_for_each extension Signed-off-by: Jason Volk --- src/core/utils/stream/try_ready.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/core/utils/stream/try_ready.rs b/src/core/utils/stream/try_ready.rs index feb380675..3fbcbc454 100644 --- a/src/core/utils/stream/try_ready.rs +++ b/src/core/utils/stream/try_ready.rs @@ -24,7 +24,7 @@ where self, f: F, ) -> TryForEach>, impl FnMut(S::Ok) -> Ready>> where - F: Fn(S::Ok) -> Result<(), E>; + F: FnMut(S::Ok) -> 
Result<(), E>; } impl TryReadyExt for S @@ -42,10 +42,10 @@ where #[inline] fn ready_try_for_each( - self, f: F, + self, mut f: F, ) -> TryForEach>, impl FnMut(S::Ok) -> Ready>> where - F: Fn(S::Ok) -> Result<(), E>, + F: FnMut(S::Ok) -> Result<(), E>, { self.try_for_each(move |t| ready(f(t))) } From 4a94a4c945740b3a5ee605af61397f37060d91bc Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 4 Nov 2024 18:20:32 +0000 Subject: [PATCH 173/245] rename pdu/id to pdu/event_id Signed-off-by: Jason Volk --- src/core/pdu/{id.rs => event_id.rs} | 0 src/core/pdu/mod.rs | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) rename src/core/pdu/{id.rs => event_id.rs} (100%) diff --git a/src/core/pdu/id.rs b/src/core/pdu/event_id.rs similarity index 100% rename from src/core/pdu/id.rs rename to src/core/pdu/event_id.rs diff --git a/src/core/pdu/mod.rs b/src/core/pdu/mod.rs index 9c3aaf9b6..53fcd0a95 100644 --- a/src/core/pdu/mod.rs +++ b/src/core/pdu/mod.rs @@ -2,8 +2,8 @@ mod builder; mod content; mod count; mod event; +mod event_id; mod filter; -mod id; mod redact; mod strip; mod unsigned; @@ -20,7 +20,7 @@ pub use self::{ builder::{Builder, Builder as PduBuilder}, count::PduCount, event::Event, - id::*, + event_id::*, }; use crate::Result; From 78aeb620bc4ecd1d2feadd72043a08a037615553 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 3 Nov 2024 18:23:35 +0000 Subject: [PATCH 174/245] add broad timeout on acquire_origins keys operation Signed-off-by: Jason Volk --- src/service/server_keys/acquire.rs | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/src/service/server_keys/acquire.rs b/src/service/server_keys/acquire.rs index 25b676b8f..cdaf28b4a 100644 --- a/src/service/server_keys/acquire.rs +++ b/src/service/server_keys/acquire.rs @@ -1,15 +1,17 @@ use std::{ borrow::Borrow, collections::{BTreeMap, BTreeSet}, + time::Duration, }; -use conduit::{debug, debug_warn, error, implement, result::FlatOk, warn}; +use 
conduit::{debug, debug_error, debug_warn, error, implement, result::FlatOk, trace, warn}; use futures::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::federation::discovery::ServerSigningKeys, serde::Raw, CanonicalJsonObject, OwnedServerName, OwnedServerSigningKeyId, ServerName, ServerSigningKeyId, }; use serde_json::value::RawValue as RawJsonValue; +use tokio::time::{timeout_at, Instant}; use super::key_exists; @@ -136,8 +138,12 @@ async fn acquire_origins(&self, batch: I) -> Batch where I: Iterator)> + Send, { + let timeout = Instant::now() + .checked_add(Duration::from_secs(45)) + .expect("timeout overflows"); + let mut requests: FuturesUnordered<_> = batch - .map(|(origin, key_ids)| self.acquire_origin(origin, key_ids)) + .map(|(origin, key_ids)| self.acquire_origin(origin, key_ids, timeout)) .collect(); let mut missing = Batch::new(); @@ -152,11 +158,22 @@ where #[implement(super::Service)] async fn acquire_origin( - &self, origin: OwnedServerName, mut key_ids: Vec, + &self, origin: OwnedServerName, mut key_ids: Vec, timeout: Instant, ) -> (OwnedServerName, Vec) { - if let Ok(server_keys) = self.server_request(&origin).await { - self.add_signing_keys(server_keys.clone()).await; - key_ids.retain(|key_id| !key_exists(&server_keys, key_id)); + match timeout_at(timeout, self.server_request(&origin)).await { + Err(e) => debug_warn!(?origin, "timed out: {e}"), + Ok(Err(e)) => debug_error!(?origin, "{e}"), + Ok(Ok(server_keys)) => { + trace!( + %origin, + ?key_ids, + ?server_keys, + "received server_keys" + ); + + self.add_signing_keys(server_keys.clone()).await; + key_ids.retain(|key_id| !key_exists(&server_keys, key_id)); + }, } (origin, key_ids) From 2e4d9cb37cf7d47a9506ee3697775ddfadcb1d56 Mon Sep 17 00:00:00 2001 From: Kirill Hmelnitski Date: Thu, 31 Oct 2024 23:39:20 +0300 Subject: [PATCH 175/245] fix thread pagination refactor logic increase fetch limit for first relates apply other format Co-authored-by: Jason Volk Signed-off-by: Jason Volk --- 
src/api/client/relations.rs | 12 +-- src/service/rooms/pdu_metadata/data.rs | 41 +++---- src/service/rooms/pdu_metadata/mod.rs | 141 +++++++++++++------------ 3 files changed, 98 insertions(+), 96 deletions(-) diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index d43847300..0456924c2 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -20,8 +20,8 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( &body.event_id, body.event_type.clone().into(), body.rel_type.clone().into(), - body.from.as_ref(), - body.to.as_ref(), + body.from.as_deref(), + body.to.as_deref(), body.limit, body.recurse, body.dir, @@ -51,8 +51,8 @@ pub(crate) async fn get_relating_events_with_rel_type_route( &body.event_id, None, body.rel_type.clone().into(), - body.from.as_ref(), - body.to.as_ref(), + body.from.as_deref(), + body.to.as_deref(), body.limit, body.recurse, body.dir, @@ -82,8 +82,8 @@ pub(crate) async fn get_relating_events_route( &body.event_id, None, None, - body.from.as_ref(), - body.to.as_ref(), + body.from.as_deref(), + body.to.as_deref(), body.limit, body.recurse, body.dir, diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index 4d570e6db..51a43714b 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -8,7 +8,7 @@ use conduit::{ }; use database::Map; use futures::{Stream, StreamExt}; -use ruma::{EventId, RoomId, UserId}; +use ruma::{api::Direction, EventId, RoomId, UserId}; use crate::{rooms, Dep}; @@ -45,9 +45,9 @@ impl Data { self.tofrom_relation.aput_raw::(key, []); } - pub(super) fn relations_until<'a>( - &'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount, - ) -> impl Stream + Send + 'a + '_ { + pub(super) fn get_relations<'a>( + &'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount, dir: Direction, + ) -> impl Stream + Send + '_ { let prefix = 
target.to_be_bytes().to_vec(); let mut current = prefix.clone(); let count_raw = match until { @@ -59,22 +59,23 @@ impl Data { }; current.extend_from_slice(&count_raw.to_be_bytes()); - self.tofrom_relation - .rev_raw_keys_from(¤t) - .ignore_err() - .ready_take_while(move |key| key.starts_with(&prefix)) - .map(|to_from| utils::u64_from_u8(&to_from[(size_of::())..])) - .filter_map(move |from| async move { - let mut pduid = shortroomid.to_be_bytes().to_vec(); - pduid.extend_from_slice(&from.to_be_bytes()); - let mut pdu = self.services.timeline.get_pdu_from_id(&pduid).await.ok()?; - - if pdu.sender != user_id { - pdu.remove_transaction_id().log_err().ok(); - } - - Some((PduCount::Normal(from), pdu)) - }) + match dir { + Direction::Forward => self.tofrom_relation.raw_keys_from(¤t).boxed(), + Direction::Backward => self.tofrom_relation.rev_raw_keys_from(¤t).boxed(), + } + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix)) + .map(|to_from| utils::u64_from_u8(&to_from[(size_of::())..])) + .filter_map(move |from| async move { + let mut pduid = shortroomid.to_be_bytes().to_vec(); + pduid.extend_from_slice(&from.to_be_bytes()); + let mut pdu = self.services.timeline.get_pdu_from_id(&pduid).await.ok()?; + if pdu.sender != user_id { + pdu.remove_transaction_id().log_err().ok(); + } + + Some((PduCount::Normal(from), pdu)) + }) } pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) { diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index fb85d031b..b1cf2049c 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -1,12 +1,16 @@ mod data; use std::sync::Arc; -use conduit::{utils::stream::IterStream, PduCount, Result}; -use futures::StreamExt; +use conduit::{ + at, + utils::{result::FlatOk, stream::ReadyExt, IterStream}, + PduCount, Result, +}; +use futures::{FutureExt, StreamExt}; use ruma::{ api::{client::relations::get_relating_events, Direction}, 
events::{relation::RelationType, TimelineEventType}, - uint, EventId, RoomId, UInt, UserId, + EventId, RoomId, UInt, UserId, }; use serde::Deserialize; @@ -63,24 +67,24 @@ impl Service { #[allow(clippy::too_many_arguments)] pub async fn paginate_relations_with_filter( &self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: Option, - filter_rel_type: Option, from: Option<&String>, to: Option<&String>, limit: Option, + filter_rel_type: Option, from: Option<&str>, to: Option<&str>, limit: Option, recurse: bool, dir: Direction, ) -> Result { - let from = match from { - Some(from) => PduCount::try_from_string(from)?, - None => match dir { + let from = from + .map(PduCount::try_from_string) + .transpose()? + .unwrap_or_else(|| match dir { Direction::Forward => PduCount::min(), Direction::Backward => PduCount::max(), - }, - }; + }); - let to = to.and_then(|t| PduCount::try_from_string(t).ok()); + let to = to.map(PduCount::try_from_string).flat_ok(); - // Use limit or else 10, with maximum 100 - let limit = limit - .unwrap_or_else(|| uint!(10)) - .try_into() - .unwrap_or(10) + // Use limit or else 30, with maximum 100 + let limit: usize = limit + .map(TryInto::try_into) + .flat_ok() + .unwrap_or(30) .min(100); // Spec (v1.10) recommends depth of at least 3 @@ -90,68 +94,54 @@ impl Service { 1 }; - let relations_until: Vec = self - .relations_until(sender_user, room_id, target, from, depth) - .await?; - - // TODO: should be relations_after - let events: Vec<_> = relations_until + let events: Vec = self + .get_relations(sender_user, room_id, target, from, limit, depth, dir) + .await .into_iter() - .filter(move |(_, pdu): &PdusIterItem| { - if !filter_event_type.as_ref().map_or(true, |t| pdu.kind == *t) { - return false; - } - - let Ok(content) = pdu.get_content::() else { - return false; - }; - - filter_rel_type + .filter(|(_, pdu)| { + filter_event_type .as_ref() - .map_or(true, |r| *r == content.relates_to.rel_type) + .is_none_or(|kind| *kind == 
pdu.kind) + }) + .filter(|(_, pdu)| { + filter_rel_type.as_ref().is_none_or(|rel_type| { + pdu.get_content() + .map(|c: ExtractRelatesToEventId| c.relates_to.rel_type) + .is_ok_and(|r| r == *rel_type) + }) }) - .take(limit) - .take_while(|(k, _)| Some(*k) != to) .stream() .filter_map(|item| self.visibility_filter(sender_user, item)) + .ready_take_while(|(count, _)| Some(*count) != to) + .take(limit) .collect() + .boxed() .await; - let next_token = events.last().map(|(count, _)| count).copied(); - - let events_chunk: Vec<_> = match dir { - Direction::Forward => events - .into_iter() - .map(|(_, pdu)| pdu.to_message_like_event()) - .collect(), - Direction::Backward => events - .into_iter() - .rev() // relations are always most recent first - .map(|(_, pdu)| pdu.to_message_like_event()) - .collect(), - }; + let next_batch = match dir { + Direction::Backward => events.first(), + Direction::Forward => events.last(), + } + .map(at!(0)) + .map(|t| t.stringify()); Ok(get_relating_events::v1::Response { - chunk: events_chunk, - next_batch: next_token.map(|t| t.stringify()), + next_batch, prev_batch: Some(from.stringify()), recursion_depth: recurse.then_some(depth.into()), + chunk: events + .into_iter() + .map(at!(1)) + .map(|pdu| pdu.to_message_like_event()) + .collect(), }) } - async fn visibility_filter(&self, sender_user: &UserId, item: PdusIterItem) -> Option { - let (_, pdu) = &item; - - self.services - .state_accessor - .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) - .await - .then_some(item) - } - - pub async fn relations_until( - &self, user_id: &UserId, room_id: &RoomId, target: &EventId, until: PduCount, max_depth: u8, - ) -> Result> { + #[allow(clippy::too_many_arguments)] + pub async fn get_relations( + &self, user_id: &UserId, room_id: &RoomId, target: &EventId, until: PduCount, limit: usize, max_depth: u8, + dir: Direction, + ) -> Vec { let room_id = self.services.short.get_or_create_shortroomid(room_id).await; let target = match 
self.services.timeline.get_pdu_count(target).await { @@ -160,24 +150,24 @@ impl Service { _ => 0, // This will result in an empty iterator }; - let mut pdus: Vec = self + let mut pdus: Vec<_> = self .db - .relations_until(user_id, room_id, target, until) + .get_relations(user_id, room_id, target, until, dir) .collect() .await; - let mut stack: Vec<_> = pdus.clone().into_iter().map(|pdu| (pdu, 1)).collect(); + let mut stack: Vec<_> = pdus.iter().map(|pdu| (pdu.clone(), 1)).collect(); - while let Some(stack_pdu) = stack.pop() { + 'limit: while let Some(stack_pdu) = stack.pop() { let target = match stack_pdu.0 .0 { PduCount::Normal(c) => c, // TODO: Support backfilled relations PduCount::Backfilled(_) => 0, // This will result in an empty iterator }; - let relations: Vec = self + let relations: Vec<_> = self .db - .relations_until(user_id, room_id, target, until) + .get_relations(user_id, room_id, target, until, dir) .collect() .await; @@ -187,12 +177,23 @@ impl Service { } pdus.push(relation); + if pdus.len() >= limit { + break 'limit; + } } } - pdus.sort_by(|a, b| a.0.cmp(&b.0)); + pdus + } + + async fn visibility_filter(&self, sender_user: &UserId, item: PdusIterItem) -> Option { + let (_, pdu) = &item; - Ok(pdus) + self.services + .state_accessor + .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) + .await + .then_some(item) } #[inline] From 9da523c004aba6e9d1d51c73de0524d5f6433bbd Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 2 Nov 2024 06:12:54 +0000 Subject: [PATCH 176/245] refactor for stronger RawPduId type implement standard traits for PduCount enable serde for arrayvec typedef various shortid's pducount simplifications split parts of pdu_metadata service to core/pdu and api/relations remove some yields; improve var names/syntax tweak types for limit timeline limit arguments Signed-off-by: Jason Volk --- Cargo.lock | 3 + Cargo.toml | 1 + src/api/client/context.rs | 4 +- src/api/client/membership.rs | 7 +- src/api/client/message.rs | 20 +-- 
src/api/client/relations.rs | 199 +++++++++++++++------- src/api/client/sync/mod.rs | 11 +- src/api/client/sync/v3.rs | 63 +++---- src/api/client/sync/v4.rs | 24 ++- src/api/client/threads.rs | 40 ++--- src/api/server/send_join.rs | 4 +- src/api/server/send_leave.rs | 6 +- src/core/mod.rs | 2 +- src/core/pdu/count.rs | 140 ++++++++++++--- src/core/pdu/id.rs | 22 +++ src/core/pdu/mod.rs | 6 + src/core/pdu/raw_id.rs | 117 +++++++++++++ src/core/pdu/relation.rs | 22 +++ src/core/pdu/tests.rs | 19 +++ src/service/migrations.rs | 2 +- src/service/rooms/auth_chain/data.rs | 8 +- src/service/rooms/auth_chain/mod.rs | 18 +- src/service/rooms/event_handler/mod.rs | 13 +- src/service/rooms/pdu_metadata/data.rs | 49 +++--- src/service/rooms/pdu_metadata/mod.rs | 116 +------------ src/service/rooms/search/mod.rs | 36 ++-- src/service/rooms/short/mod.rs | 8 +- src/service/rooms/spaces/mod.rs | 6 +- src/service/rooms/state/mod.rs | 4 +- src/service/rooms/state_accessor/data.rs | 6 +- src/service/rooms/state_accessor/mod.rs | 30 ++-- src/service/rooms/state_compressor/mod.rs | 42 +++-- src/service/rooms/threads/data.rs | 52 +++--- src/service/rooms/threads/mod.rs | 6 +- src/service/rooms/timeline/data.rs | 124 ++++++-------- src/service/rooms/timeline/mod.rs | 75 ++++---- src/service/rooms/timeline/pduid.rs | 13 -- src/service/rooms/user/mod.rs | 6 +- src/service/sending/data.rs | 12 +- src/service/sending/mod.rs | 23 +-- src/service/sending/sender.rs | 6 +- 41 files changed, 794 insertions(+), 571 deletions(-) create mode 100644 src/core/pdu/id.rs create mode 100644 src/core/pdu/raw_id.rs create mode 100644 src/core/pdu/relation.rs create mode 100644 src/core/pdu/tests.rs delete mode 100644 src/service/rooms/timeline/pduid.rs diff --git a/Cargo.lock b/Cargo.lock index 44856753f..f729d3d4a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -76,6 +76,9 @@ name = "arrayvec" version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +dependencies = [ + "serde", +] [[package]] name = "as_variant" diff --git a/Cargo.toml b/Cargo.toml index 043790f8f..3ac1556c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ name = "conduit" [workspace.dependencies.arrayvec] version = "0.7.4" +features = ["std", "serde"] [workspace.dependencies.const-str] version = "0.5.7" diff --git a/src/api/client/context.rs b/src/api/client/context.rs index 9bf0c4670..5b492cb19 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -168,12 +168,12 @@ pub(crate) async fn get_context_route( start: events_before .last() - .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()) + .map_or_else(|| base_token.to_string(), |(count, _)| count.to_string()) .into(), end: events_after .last() - .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()) + .map_or_else(|| base_token.to_string(), |(count, _)| count.to_string()) .into(), events_before: events_before diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index c41e93fa3..fa71c0c85 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -1376,15 +1376,12 @@ pub(crate) async fn invite_helper( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - let pdu_id: Vec = services + let pdu_id = services .rooms .event_handler .handle_incoming_pdu(&origin, room_id, &event_id, value, true) .await? 
- .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not accept incoming PDU as timeline event.", - ))?; + .ok_or_else(|| err!(Request(InvalidParam("Could not accept incoming PDU as timeline event."))))?; services.sending.send_pdu_room(room_id, &pdu_id).await?; return Ok(()); diff --git a/src/api/client/message.rs b/src/api/client/message.rs index 4fc58d9f6..cb261a7f2 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -62,19 +62,17 @@ pub(crate) async fn get_message_events_route( let room_id = &body.room_id; let filter = &body.filter; - let from_default = match body.dir { - Direction::Forward => PduCount::min(), - Direction::Backward => PduCount::max(), - }; - - let from = body + let from: PduCount = body .from .as_deref() - .map(PduCount::try_from_string) + .map(str::parse) .transpose()? - .unwrap_or(from_default); + .unwrap_or_else(|| match body.dir { + Direction::Forward => PduCount::min(), + Direction::Backward => PduCount::max(), + }); - let to = body.to.as_deref().map(PduCount::try_from_string).flat_ok(); + let to: Option = body.to.as_deref().map(str::parse).flat_ok(); let limit: usize = body .limit @@ -156,8 +154,8 @@ pub(crate) async fn get_message_events_route( .collect(); Ok(get_message_events::v3::Response { - start: from.stringify(), - end: next_token.as_ref().map(PduCount::stringify), + start: from.to_string(), + end: next_token.as_ref().map(PduCount::to_string), chunk, state, }) diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index 0456924c2..ef7035e2f 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -1,34 +1,43 @@ use axum::extract::State; -use ruma::api::client::relations::{ - get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type, +use conduit::{ + at, + utils::{result::FlatOk, IterStream, ReadyExt}, + PduCount, Result, }; +use futures::{FutureExt, StreamExt}; +use ruma::{ + api::{ + client::relations::{ + 
get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type, + }, + Direction, + }, + events::{relation::RelationType, TimelineEventType}, + EventId, RoomId, UInt, UserId, +}; +use service::{rooms::timeline::PdusIterItem, Services}; -use crate::{Result, Ruma}; +use crate::Ruma; /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}` pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - - let res = services - .rooms - .pdu_metadata - .paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - body.event_type.clone().into(), - body.rel_type.clone().into(), - body.from.as_deref(), - body.to.as_deref(), - body.limit, - body.recurse, - body.dir, - ) - .await?; - - Ok(get_relating_events_with_rel_type_and_event_type::v1::Response { + paginate_relations_with_filter( + &services, + body.sender_user(), + &body.room_id, + &body.event_id, + body.event_type.clone().into(), + body.rel_type.clone().into(), + body.from.as_deref(), + body.to.as_deref(), + body.limit, + body.recurse, + body.dir, + ) + .await + .map(|res| get_relating_events_with_rel_type_and_event_type::v1::Response { chunk: res.chunk, next_batch: res.next_batch, prev_batch: res.prev_batch, @@ -40,26 +49,21 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( pub(crate) async fn get_relating_events_with_rel_type_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - - let res = services - .rooms - .pdu_metadata - .paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - None, - body.rel_type.clone().into(), - body.from.as_deref(), - body.to.as_deref(), - body.limit, - body.recurse, - body.dir, - ) - .await?; - - 
Ok(get_relating_events_with_rel_type::v1::Response { + paginate_relations_with_filter( + &services, + body.sender_user(), + &body.room_id, + &body.event_id, + None, + body.rel_type.clone().into(), + body.from.as_deref(), + body.to.as_deref(), + body.limit, + body.recurse, + body.dir, + ) + .await + .map(|res| get_relating_events_with_rel_type::v1::Response { chunk: res.chunk, next_batch: res.next_batch, prev_batch: res.prev_batch, @@ -71,22 +75,103 @@ pub(crate) async fn get_relating_events_with_rel_type_route( pub(crate) async fn get_relating_events_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_deref().expect("user is authenticated"); + paginate_relations_with_filter( + &services, + body.sender_user(), + &body.room_id, + &body.event_id, + None, + None, + body.from.as_deref(), + body.to.as_deref(), + body.limit, + body.recurse, + body.dir, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +async fn paginate_relations_with_filter( + services: &Services, sender_user: &UserId, room_id: &RoomId, target: &EventId, + filter_event_type: Option, filter_rel_type: Option, from: Option<&str>, + to: Option<&str>, limit: Option, recurse: bool, dir: Direction, +) -> Result { + let from: PduCount = from + .map(str::parse) + .transpose()? 
+ .unwrap_or_else(|| match dir { + Direction::Forward => PduCount::min(), + Direction::Backward => PduCount::max(), + }); - services + let to: Option = to.map(str::parse).flat_ok(); + + // Use limit or else 30, with maximum 100 + let limit: usize = limit + .map(TryInto::try_into) + .flat_ok() + .unwrap_or(30) + .min(100); + + // Spec (v1.10) recommends depth of at least 3 + let depth: u8 = if recurse { + 3 + } else { + 1 + }; + + let events: Vec = services .rooms .pdu_metadata - .paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - None, - None, - body.from.as_deref(), - body.to.as_deref(), - body.limit, - body.recurse, - body.dir, - ) + .get_relations(sender_user, room_id, target, from, limit, depth, dir) + .await + .into_iter() + .filter(|(_, pdu)| { + filter_event_type + .as_ref() + .is_none_or(|kind| *kind == pdu.kind) + }) + .filter(|(_, pdu)| { + filter_rel_type + .as_ref() + .is_none_or(|rel_type| pdu.relation_type_equal(rel_type)) + }) + .stream() + .filter_map(|item| visibility_filter(services, sender_user, item)) + .ready_take_while(|(count, _)| Some(*count) != to) + .take(limit) + .collect() + .boxed() + .await; + + let next_batch = match dir { + Direction::Backward => events.first(), + Direction::Forward => events.last(), + } + .map(at!(0)) + .as_ref() + .map(ToString::to_string); + + Ok(get_relating_events::v1::Response { + next_batch, + prev_batch: Some(from.to_string()), + recursion_depth: recurse.then_some(depth.into()), + chunk: events + .into_iter() + .map(at!(1)) + .map(|pdu| pdu.to_message_like_event()) + .collect(), + }) +} + +async fn visibility_filter(services: &Services, sender_user: &UserId, item: PdusIterItem) -> Option { + let (_, pdu) = &item; + + services + .rooms + .state_accessor + .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) .await + .then_some(item) } diff --git a/src/api/client/sync/mod.rs b/src/api/client/sync/mod.rs index ed22010c9..7aec7186f 100644 --- a/src/api/client/sync/mod.rs 
+++ b/src/api/client/sync/mod.rs @@ -1,10 +1,7 @@ mod v3; mod v4; -use conduit::{ - utils::{math::usize_from_u64_truncated, ReadyExt}, - PduCount, -}; +use conduit::{utils::ReadyExt, PduCount}; use futures::StreamExt; use ruma::{RoomId, UserId}; @@ -12,7 +9,7 @@ pub(crate) use self::{v3::sync_events_route, v4::sync_events_v4_route}; use crate::{service::Services, Error, PduEvent, Result}; async fn load_timeline( - services: &Services, sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: u64, + services: &Services, sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: usize, ) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { let last_timeline_count = services .rooms @@ -29,12 +26,12 @@ async fn load_timeline( .timeline .pdus_until(sender_user, room_id, PduCount::max()) .await? - .ready_take_while(|(pducount, _)| pducount > &roomsincecount); + .ready_take_while(|(pducount, _)| *pducount > roomsincecount); // Take the last events for the timeline let timeline_pdus: Vec<_> = non_timeline_pdus .by_ref() - .take(usize_from_u64_truncated(limit)) + .take(limit) .collect::>() .await .into_iter() diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index ccca1f85d..080489026 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -432,28 +432,26 @@ async fn handle_left_room( left_state_ids.insert(leave_shortstatekey, left_event_id); - let mut i: u8 = 0; - for (key, id) in left_state_ids { - if full_state || since_state_ids.get(&key) != Some(&id) { - let (event_type, state_key) = services.rooms.short.get_statekey_from_short(key).await?; + for (shortstatekey, event_id) in left_state_ids { + if full_state || since_state_ids.get(&shortstatekey) != Some(&event_id) { + let (event_type, state_key) = services + .rooms + .short + .get_statekey_from_short(shortstatekey) + .await?; + // TODO: Delete "element_hacks" when this is resolved: https://github.com/vector-im/element-web/issues/22565 if 
!lazy_load_enabled - || event_type != StateEventType::RoomMember - || full_state - // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 - || (cfg!(feature = "element_hacks") && *sender_user == state_key) + || event_type != StateEventType::RoomMember + || full_state + || (cfg!(feature = "element_hacks") && *sender_user == state_key) { - let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { - error!("Pdu in state not found: {}", id); + let Ok(pdu) = services.rooms.timeline.get_pdu(&event_id).await else { + error!("Pdu in state not found: {event_id}"); continue; }; left_state_events.push(pdu.to_sync_state_event()); - - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } } } } @@ -542,7 +540,7 @@ async fn load_joined_room( let insert_lock = services.rooms.timeline.mutex_insert.lock(room_id).await; drop(insert_lock); - let (timeline_pdus, limited) = load_timeline(services, sender_user, room_id, sincecount, 10).await?; + let (timeline_pdus, limited) = load_timeline(services, sender_user, room_id, sincecount, 10_usize).await?; let send_notification_counts = !timeline_pdus.is_empty() || services @@ -678,8 +676,7 @@ async fn load_joined_room( let mut state_events = Vec::new(); let mut lazy_loaded = HashSet::new(); - let mut i: u8 = 0; - for (shortstatekey, id) in current_state_ids { + for (shortstatekey, event_id) in current_state_ids { let (event_type, state_key) = services .rooms .short @@ -687,24 +684,22 @@ async fn load_joined_room( .await?; if event_type != StateEventType::RoomMember { - let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { - error!("Pdu in state not found: {id}"); + let Ok(pdu) = services.rooms.timeline.get_pdu(&event_id).await else { + error!("Pdu in state not found: {event_id}"); continue; }; + state_events.push(pdu); + continue; + } - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } else if 
!lazy_load_enabled - || full_state - || timeline_users.contains(&state_key) - // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 - || (cfg!(feature = "element_hacks") && *sender_user == state_key) + // TODO: Delete "element_hacks" when this is resolved: https://github.com/vector-im/element-web/issues/22565 + if !lazy_load_enabled + || full_state || timeline_users.contains(&state_key) + || (cfg!(feature = "element_hacks") && *sender_user == state_key) { - let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { - error!("Pdu in state not found: {id}"); + let Ok(pdu) = services.rooms.timeline.get_pdu(&event_id).await else { + error!("Pdu in state not found: {event_id}"); continue; }; @@ -712,12 +707,8 @@ async fn load_joined_room( if let Ok(uid) = UserId::parse(&state_key) { lazy_loaded.insert(uid); } - state_events.push(pdu); - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } + state_events.push(pdu); } } diff --git a/src/api/client/sync/v4.rs b/src/api/client/sync/v4.rs index f8ada81c9..11e3830cc 100644 --- a/src/api/client/sync/v4.rs +++ b/src/api/client/sync/v4.rs @@ -8,7 +8,7 @@ use axum::extract::State; use conduit::{ debug, error, extract_variant, utils::{ - math::{ruma_from_usize, usize_from_ruma}, + math::{ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, BoolExt, IterStream, ReadyExt, TryFutureExtExt, }, warn, Error, PduCount, Result, @@ -350,14 +350,16 @@ pub(crate) async fn sync_events_v4_route( new_known_rooms.extend(room_ids.iter().cloned()); for room_id in &room_ids { - let todo_room = todo_rooms - .entry(room_id.clone()) - .or_insert((BTreeSet::new(), 0, u64::MAX)); + let todo_room = + todo_rooms + .entry(room_id.clone()) + .or_insert((BTreeSet::new(), 0_usize, u64::MAX)); - let limit = list + let limit: usize = list .room_details .timeline_limit - .map_or(10, u64::from) + .map(u64::from) + .map_or(10, usize_from_u64_truncated) .min(100); 
todo_room @@ -406,8 +408,14 @@ pub(crate) async fn sync_events_v4_route( } let todo_room = todo_rooms .entry(room_id.clone()) - .or_insert((BTreeSet::new(), 0, u64::MAX)); - let limit = room.timeline_limit.map_or(10, u64::from).min(100); + .or_insert((BTreeSet::new(), 0_usize, u64::MAX)); + + let limit: usize = room + .timeline_limit + .map(u64::from) + .map_or(10, usize_from_u64_truncated) + .min(100); + todo_room.0.extend(room.required_state.iter().cloned()); todo_room.1 = todo_room.1.max(limit); // 0 means unknown because it got out of date diff --git a/src/api/client/threads.rs b/src/api/client/threads.rs index 50f6cdfb2..02cf79926 100644 --- a/src/api/client/threads.rs +++ b/src/api/client/threads.rs @@ -1,19 +1,14 @@ use axum::extract::State; -use conduit::PduEvent; +use conduit::{PduCount, PduEvent}; use futures::StreamExt; -use ruma::{ - api::client::{error::ErrorKind, threads::get_threads}, - uint, -}; +use ruma::{api::client::threads::get_threads, uint}; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET /_matrix/client/r0/rooms/{roomId}/threads` pub(crate) async fn get_threads_route( - State(services): State, body: Ruma, + State(services): State, ref body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - // Use limit or else 10, with maximum 100 let limit = body .limit @@ -22,38 +17,39 @@ pub(crate) async fn get_threads_route( .unwrap_or(10) .min(100); - let from = if let Some(from) = &body.from { - from.parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, ""))? - } else { - u64::MAX - }; + let from: PduCount = body + .from + .as_deref() + .map(str::parse) + .transpose()? 
+ .unwrap_or_else(PduCount::max); - let room_id = &body.room_id; - let threads: Vec<(u64, PduEvent)> = services + let threads: Vec<(PduCount, PduEvent)> = services .rooms .threads - .threads_until(sender_user, &body.room_id, from, &body.include) + .threads_until(body.sender_user(), &body.room_id, from, &body.include) .await? .take(limit) .filter_map(|(count, pdu)| async move { services .rooms .state_accessor - .user_can_see_event(sender_user, room_id, &pdu.event_id) + .user_can_see_event(body.sender_user(), &body.room_id, &pdu.event_id) .await .then_some((count, pdu)) }) .collect() .await; - let next_batch = threads.last().map(|(count, _)| count.to_string()); - Ok(get_threads::v1::Response { + next_batch: threads + .last() + .map(|(count, _)| count) + .map(ToString::to_string), + chunk: threads .into_iter() .map(|(_, pdu)| pdu.to_room_event()) .collect(), - next_batch, }) } diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index c3273bafb..f2ede9d0a 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -156,12 +156,12 @@ async fn create_join_event( .lock(room_id) .await; - let pdu_id: Vec = services + let pdu_id = services .rooms .event_handler .handle_incoming_pdu(&origin, room_id, &event_id, value.clone(), true) .await? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?; + .ok_or_else(|| err!(Request(InvalidParam("Could not accept as timeline event."))))?; drop(mutex_lock); diff --git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index 7b4a8aeef..448e5de34 100644 --- a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -1,7 +1,7 @@ #![allow(deprecated)] use axum::extract::State; -use conduit::{utils::ReadyExt, Error, Result}; +use conduit::{err, utils::ReadyExt, Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_leave_event}, events::{ @@ -142,12 +142,12 @@ async fn create_leave_event( .lock(room_id) .await; - let pdu_id: Vec = services + let pdu_id = services .rooms .event_handler .handle_incoming_pdu(origin, room_id, &event_id, value, true) .await? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?; + .ok_or_else(|| err!(Request(InvalidParam("Could not accept as timeline event."))))?; drop(mutex_lock); diff --git a/src/core/mod.rs b/src/core/mod.rs index 1b7b8fa13..4ab847307 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -17,7 +17,7 @@ pub use ::tracing; pub use config::Config; pub use error::Error; pub use info::{rustc_flags_capture, version, version::version}; -pub use pdu::{Event, PduBuilder, PduCount, PduEvent}; +pub use pdu::{Event, PduBuilder, PduCount, PduEvent, PduId, RawPduId}; pub use server::Server; pub use utils::{ctor, dtor, implement, result, result::Result}; diff --git a/src/core/pdu/count.rs b/src/core/pdu/count.rs index 094988b69..90e552e89 100644 --- a/src/core/pdu/count.rs +++ b/src/core/pdu/count.rs @@ -1,51 +1,145 @@ -use std::cmp::Ordering; +#![allow(clippy::cast_possible_wrap, clippy::cast_sign_loss, clippy::as_conversions)] -use ruma::api::client::error::ErrorKind; +use std::{cmp::Ordering, fmt, fmt::Display, str::FromStr}; -use crate::{Error, Result}; +use crate::{err, Error, 
Result}; #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] pub enum PduCount { - Backfilled(u64), Normal(u64), + Backfilled(i64), } impl PduCount { + #[inline] #[must_use] - pub fn min() -> Self { Self::Backfilled(u64::MAX) } + pub fn from_unsigned(unsigned: u64) -> Self { Self::from_signed(unsigned as i64) } + #[inline] #[must_use] - pub fn max() -> Self { Self::Normal(u64::MAX) } + pub fn from_signed(signed: i64) -> Self { + match signed { + i64::MIN..=0 => Self::Backfilled(signed), + _ => Self::Normal(signed as u64), + } + } + + #[inline] + #[must_use] + pub fn into_unsigned(self) -> u64 { + self.debug_assert_valid(); + match self { + Self::Normal(i) => i, + Self::Backfilled(i) => i as u64, + } + } + + #[inline] + #[must_use] + pub fn into_signed(self) -> i64 { + self.debug_assert_valid(); + match self { + Self::Normal(i) => i as i64, + Self::Backfilled(i) => i, + } + } + + #[inline] + #[must_use] + pub fn into_normal(self) -> Self { + self.debug_assert_valid(); + match self { + Self::Normal(i) => Self::Normal(i), + Self::Backfilled(_) => Self::Normal(0), + } + } + + #[inline] + pub fn checked_add(self, add: u64) -> Result { + Ok(match self { + Self::Normal(i) => Self::Normal( + i.checked_add(add) + .ok_or_else(|| err!(Arithmetic("PduCount::Normal overflow")))?, + ), + Self::Backfilled(i) => Self::Backfilled( + i.checked_add(add as i64) + .ok_or_else(|| err!(Arithmetic("PduCount::Backfilled overflow")))?, + ), + }) + } + + #[inline] + pub fn checked_sub(self, sub: u64) -> Result { + Ok(match self { + Self::Normal(i) => Self::Normal( + i.checked_sub(sub) + .ok_or_else(|| err!(Arithmetic("PduCount::Normal underflow")))?, + ), + Self::Backfilled(i) => Self::Backfilled( + i.checked_sub(sub as i64) + .ok_or_else(|| err!(Arithmetic("PduCount::Backfilled underflow")))?, + ), + }) + } - pub fn try_from_string(token: &str) -> Result { - if let Some(stripped_token) = token.strip_prefix('-') { - stripped_token.parse().map(PduCount::Backfilled) - } else { - 
token.parse().map(PduCount::Normal) + #[inline] + #[must_use] + pub fn saturating_add(self, add: u64) -> Self { + match self { + Self::Normal(i) => Self::Normal(i.saturating_add(add)), + Self::Backfilled(i) => Self::Backfilled(i.saturating_add(add as i64)), } - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid pagination token.")) } + #[inline] #[must_use] - pub fn stringify(&self) -> String { + pub fn saturating_sub(self, sub: u64) -> Self { match self { - Self::Backfilled(x) => format!("-{x}"), - Self::Normal(x) => x.to_string(), + Self::Normal(i) => Self::Normal(i.saturating_sub(sub)), + Self::Backfilled(i) => Self::Backfilled(i.saturating_sub(sub as i64)), + } + } + + #[inline] + #[must_use] + pub fn min() -> Self { Self::Backfilled(i64::MIN) } + + #[inline] + #[must_use] + pub fn max() -> Self { Self::Normal(i64::MAX as u64) } + + #[inline] + pub(crate) fn debug_assert_valid(&self) { + if let Self::Backfilled(i) = self { + debug_assert!(*i <= 0, "Backfilled sequence must be negative"); } } } +impl Display for PduCount { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + self.debug_assert_valid(); + match self { + Self::Normal(i) => write!(f, "{i}"), + Self::Backfilled(i) => write!(f, "{i}"), + } + } +} + +impl FromStr for PduCount { + type Err = Error; + + fn from_str(token: &str) -> Result { Ok(Self::from_signed(token.parse()?)) } +} + impl PartialOrd for PduCount { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl Ord for PduCount { - fn cmp(&self, other: &Self) -> Ordering { - match (self, other) { - (Self::Normal(s), Self::Normal(o)) => s.cmp(o), - (Self::Backfilled(s), Self::Backfilled(o)) => o.cmp(s), - (Self::Normal(_), Self::Backfilled(_)) => Ordering::Greater, - (Self::Backfilled(_), Self::Normal(_)) => Ordering::Less, - } - } + fn cmp(&self, other: &Self) -> Ordering { self.into_signed().cmp(&other.into_signed()) } +} + +impl Default for PduCount { + fn default() -> Self { 
Self::Normal(0) } } diff --git a/src/core/pdu/id.rs b/src/core/pdu/id.rs new file mode 100644 index 000000000..05d11904c --- /dev/null +++ b/src/core/pdu/id.rs @@ -0,0 +1,22 @@ +use super::{PduCount, RawPduId}; +use crate::utils::u64_from_u8x8; + +pub type ShortRoomId = ShortId; +pub type ShortEventId = ShortId; +pub type ShortId = u64; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct PduId { + pub shortroomid: ShortRoomId, + pub shorteventid: PduCount, +} + +impl From for PduId { + #[inline] + fn from(raw: RawPduId) -> Self { + Self { + shortroomid: u64_from_u8x8(raw.shortroomid()), + shorteventid: PduCount::from_unsigned(u64_from_u8x8(raw.shorteventid())), + } + } +} diff --git a/src/core/pdu/mod.rs b/src/core/pdu/mod.rs index 53fcd0a95..c785c99ea 100644 --- a/src/core/pdu/mod.rs +++ b/src/core/pdu/mod.rs @@ -4,8 +4,12 @@ mod count; mod event; mod event_id; mod filter; +mod id; +mod raw_id; mod redact; +mod relation; mod strip; +mod tests; mod unsigned; use std::{cmp::Ordering, sync::Arc}; @@ -21,6 +25,8 @@ pub use self::{ count::PduCount, event::Event, event_id::*, + id::*, + raw_id::*, }; use crate::Result; diff --git a/src/core/pdu/raw_id.rs b/src/core/pdu/raw_id.rs new file mode 100644 index 000000000..faba1cbf1 --- /dev/null +++ b/src/core/pdu/raw_id.rs @@ -0,0 +1,117 @@ +use arrayvec::ArrayVec; + +use super::{PduCount, PduId, ShortEventId, ShortId, ShortRoomId}; + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum RawPduId { + Normal(RawPduIdNormal), + Backfilled(RawPduIdBackfilled), +} + +type RawPduIdNormal = [u8; RawPduId::NORMAL_LEN]; +type RawPduIdBackfilled = [u8; RawPduId::BACKFILLED_LEN]; + +const INT_LEN: usize = size_of::(); + +impl RawPduId { + const BACKFILLED_LEN: usize = size_of::() + INT_LEN + size_of::(); + const MAX_LEN: usize = Self::BACKFILLED_LEN; + const NORMAL_LEN: usize = size_of::() + size_of::(); + + #[inline] + #[must_use] + pub fn pdu_count(&self) -> PduCount { + let id: PduId = (*self).into(); + 
id.shorteventid + } + + #[inline] + #[must_use] + pub fn shortroomid(self) -> [u8; INT_LEN] { + match self { + Self::Normal(raw) => raw[0..INT_LEN] + .try_into() + .expect("normal raw shortroomid array from slice"), + Self::Backfilled(raw) => raw[0..INT_LEN] + .try_into() + .expect("backfilled raw shortroomid array from slice"), + } + } + + #[inline] + #[must_use] + pub fn shorteventid(self) -> [u8; INT_LEN] { + match self { + Self::Normal(raw) => raw[INT_LEN..INT_LEN * 2] + .try_into() + .expect("normal raw shorteventid array from slice"), + Self::Backfilled(raw) => raw[INT_LEN * 2..INT_LEN * 3] + .try_into() + .expect("backfilled raw shorteventid array from slice"), + } + } + + #[inline] + #[must_use] + pub fn as_bytes(&self) -> &[u8] { + match self { + Self::Normal(ref raw) => raw, + Self::Backfilled(ref raw) => raw, + } + } +} + +impl AsRef<[u8]> for RawPduId { + #[inline] + fn as_ref(&self) -> &[u8] { self.as_bytes() } +} + +impl From<&[u8]> for RawPduId { + #[inline] + fn from(id: &[u8]) -> Self { + match id.len() { + Self::NORMAL_LEN => Self::Normal( + id[0..Self::NORMAL_LEN] + .try_into() + .expect("normal RawPduId from [u8]"), + ), + Self::BACKFILLED_LEN => Self::Backfilled( + id[0..Self::BACKFILLED_LEN] + .try_into() + .expect("backfilled RawPduId from [u8]"), + ), + _ => unimplemented!("unrecognized RawPduId length"), + } + } +} + +impl From for RawPduId { + #[inline] + fn from(id: PduId) -> Self { + const MAX_LEN: usize = RawPduId::MAX_LEN; + type RawVec = ArrayVec; + + let mut vec = RawVec::new(); + vec.extend(id.shortroomid.to_be_bytes()); + id.shorteventid.debug_assert_valid(); + match id.shorteventid { + PduCount::Normal(shorteventid) => { + vec.extend(shorteventid.to_be_bytes()); + Self::Normal( + vec.as_ref() + .try_into() + .expect("RawVec into RawPduId::Normal"), + ) + }, + PduCount::Backfilled(shorteventid) => { + vec.extend(0_u64.to_be_bytes()); + vec.extend(shorteventid.to_be_bytes()); + Self::Backfilled( + vec.as_ref() + .try_into() + 
.expect("RawVec into RawPduId::Backfilled"), + ) + }, + } + } +} diff --git a/src/core/pdu/relation.rs b/src/core/pdu/relation.rs new file mode 100644 index 000000000..ae156a3de --- /dev/null +++ b/src/core/pdu/relation.rs @@ -0,0 +1,22 @@ +use ruma::events::relation::RelationType; +use serde::Deserialize; + +use crate::implement; + +#[derive(Clone, Debug, Deserialize)] +struct ExtractRelType { + rel_type: RelationType, +} +#[derive(Clone, Debug, Deserialize)] +struct ExtractRelatesToEventId { + #[serde(rename = "m.relates_to")] + relates_to: ExtractRelType, +} + +#[implement(super::PduEvent)] +#[must_use] +pub fn relation_type_equal(&self, rel_type: &RelationType) -> bool { + self.get_content() + .map(|c: ExtractRelatesToEventId| c.relates_to.rel_type) + .is_ok_and(|r| r == *rel_type) +} diff --git a/src/core/pdu/tests.rs b/src/core/pdu/tests.rs new file mode 100644 index 000000000..30ec23ba7 --- /dev/null +++ b/src/core/pdu/tests.rs @@ -0,0 +1,19 @@ +#![cfg(test)] + +use super::PduCount; + +#[test] +fn backfilled_parse() { + let count: PduCount = "-987654".parse().expect("parse() failed"); + let backfilled = matches!(count, PduCount::Backfilled(_)); + + assert!(backfilled, "not backfilled variant"); +} + +#[test] +fn normal_parse() { + let count: PduCount = "987654".parse().expect("parse() failed"); + let backfilled = matches!(count, PduCount::Backfilled(_)); + + assert!(!backfilled, "backfilled variant"); +} diff --git a/src/service/migrations.rs b/src/service/migrations.rs index 45323fa23..d6c342f86 100644 --- a/src/service/migrations.rs +++ b/src/service/migrations.rs @@ -71,7 +71,7 @@ async fn fresh(services: &Services) -> Result<()> { db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []); // Create the admin room and server user on first run - crate::admin::create_admin_room(services).await?; + crate::admin::create_admin_room(services).boxed().await?; warn!( "Created new {} database with version {DATABASE_VERSION}", diff --git 
a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index 5c9dbda83..3c36928af 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -7,9 +7,11 @@ use conduit::{err, utils, utils::math::usize_from_f64, Err, Result}; use database::Map; use lru_cache::LruCache; +use crate::rooms::short::ShortEventId; + pub(super) struct Data { shorteventid_authchain: Arc, - pub(super) auth_chain_cache: Mutex, Arc<[u64]>>>, + pub(super) auth_chain_cache: Mutex, Arc<[ShortEventId]>>>, } impl Data { @@ -24,7 +26,7 @@ impl Data { } } - pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result> { + pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result> { debug_assert!(!key.is_empty(), "auth_chain key must not be empty"); // Check RAM cache @@ -63,7 +65,7 @@ impl Data { Ok(chain) } - pub(super) fn cache_auth_chain(&self, key: Vec, auth_chain: Arc<[u64]>) { + pub(super) fn cache_auth_chain(&self, key: Vec, auth_chain: Arc<[ShortEventId]>) { debug_assert!(!key.is_empty(), "auth_chain key must not be empty"); // Only persist single events in db diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 1387bc7d7..c22732c24 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -10,7 +10,7 @@ use futures::Stream; use ruma::{EventId, RoomId}; use self::data::Data; -use crate::{rooms, Dep}; +use crate::{rooms, rooms::short::ShortEventId, Dep}; pub struct Service { services: Services, @@ -64,7 +64,7 @@ impl Service { } #[tracing::instrument(skip_all, name = "auth_chain")] - pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result> { + pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result> { const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db? 
const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new(); @@ -97,7 +97,7 @@ impl Service { continue; } - let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); + let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await { trace!("Found cache entry for whole chunk"); full_auth_chain.extend(cached.iter().copied()); @@ -156,7 +156,7 @@ impl Service { } #[tracing::instrument(skip(self, room_id))] - async fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result> { + async fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result> { let mut todo = vec![Arc::from(event_id)]; let mut found = HashSet::new(); @@ -195,19 +195,19 @@ impl Service { } #[inline] - pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result> { + pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result> { self.db.get_cached_eventid_authchain(key).await } #[tracing::instrument(skip(self), level = "debug")] - pub fn cache_auth_chain(&self, key: Vec, auth_chain: &HashSet) { - let val = auth_chain.iter().copied().collect::>(); + pub fn cache_auth_chain(&self, key: Vec, auth_chain: &HashSet) { + let val = auth_chain.iter().copied().collect::>(); self.db.cache_auth_chain(key, val); } #[tracing::instrument(skip(self), level = "debug")] - pub fn cache_auth_chain_vec(&self, key: Vec, auth_chain: &Vec) { - let val = auth_chain.iter().copied().collect::>(); + pub fn cache_auth_chain_vec(&self, key: Vec, auth_chain: &Vec) { + let val = auth_chain.iter().copied().collect::>(); self.db.cache_auth_chain(key, val); } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index adebd3323..f76f817d3 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -35,7 +35,10 @@ use ruma::{ use crate::{ globals, rooms, - 
rooms::state_compressor::{CompressedStateEvent, HashSetCompressStateEvent}, + rooms::{ + state_compressor::{CompressedStateEvent, HashSetCompressStateEvent}, + timeline::RawPduId, + }, sending, server_keys, Dep, }; @@ -136,10 +139,10 @@ impl Service { pub async fn handle_incoming_pdu<'a>( &self, origin: &'a ServerName, room_id: &'a RoomId, event_id: &'a EventId, value: BTreeMap, is_timeline_event: bool, - ) -> Result>> { + ) -> Result> { // 1. Skip the PDU if we already have it as a timeline event if let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await { - return Ok(Some(pdu_id.to_vec())); + return Ok(Some(pdu_id)); } // 1.1 Check the server is in the room @@ -488,7 +491,7 @@ impl Service { pub async fn upgrade_outlier_to_timeline_pdu( &self, incoming_pdu: Arc, val: BTreeMap, create_event: &PduEvent, origin: &ServerName, room_id: &RoomId, - ) -> Result>> { + ) -> Result> { // Skip the PDU if we already have it as a timeline event if let Ok(pduid) = self .services @@ -496,7 +499,7 @@ impl Service { .get_pdu_id(&incoming_pdu.event_id) .await { - return Ok(Some(pduid.to_vec())); + return Ok(Some(pduid)); } if self diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index 51a43714b..3fc065915 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -2,15 +2,21 @@ use std::{mem::size_of, sync::Arc}; use conduit::{ result::LogErr, - utils, - utils::{stream::TryIgnore, ReadyExt}, + utils::{stream::TryIgnore, u64_from_u8, ReadyExt}, PduCount, PduEvent, }; use database::Map; use futures::{Stream, StreamExt}; use ruma::{api::Direction, EventId, RoomId, UserId}; -use crate::{rooms, Dep}; +use crate::{ + rooms, + rooms::{ + short::{ShortEventId, ShortRoomId}, + timeline::{PduId, RawPduId}, + }, + Dep, +}; pub(super) struct Data { tofrom_relation: Arc, @@ -46,35 +52,36 @@ impl Data { } pub(super) fn get_relations<'a>( - &'a self, user_id: &'a UserId, shortroomid: u64, target: u64, 
until: PduCount, dir: Direction, + &'a self, user_id: &'a UserId, shortroomid: ShortRoomId, target: ShortEventId, from: PduCount, dir: Direction, ) -> impl Stream + Send + '_ { - let prefix = target.to_be_bytes().to_vec(); - let mut current = prefix.clone(); - let count_raw = match until { - PduCount::Normal(x) => x.saturating_sub(1), - PduCount::Backfilled(x) => { - current.extend_from_slice(&0_u64.to_be_bytes()); - u64::MAX.saturating_sub(x).saturating_sub(1) - }, - }; - current.extend_from_slice(&count_raw.to_be_bytes()); + let current: RawPduId = PduId { + shortroomid, + shorteventid: from, + } + .into(); match dir { Direction::Forward => self.tofrom_relation.raw_keys_from(¤t).boxed(), Direction::Backward => self.tofrom_relation.rev_raw_keys_from(¤t).boxed(), } .ignore_err() - .ready_take_while(move |key| key.starts_with(&prefix)) - .map(|to_from| utils::u64_from_u8(&to_from[(size_of::())..])) - .filter_map(move |from| async move { - let mut pduid = shortroomid.to_be_bytes().to_vec(); - pduid.extend_from_slice(&from.to_be_bytes()); - let mut pdu = self.services.timeline.get_pdu_from_id(&pduid).await.ok()?; + .ready_take_while(move |key| key.starts_with(&target.to_be_bytes())) + .map(|to_from| u64_from_u8(&to_from[8..16])) + .map(PduCount::from_unsigned) + .filter_map(move |shorteventid| async move { + let pdu_id: RawPduId = PduId { + shortroomid, + shorteventid, + } + .into(); + + let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?; + if pdu.sender != user_id { pdu.remove_transaction_id().log_err().ok(); } - Some((PduCount::Normal(from), pdu)) + Some((shorteventid, pdu)) }) } diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index b1cf2049c..82d2ee35b 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -1,18 +1,9 @@ mod data; use std::sync::Arc; -use conduit::{ - at, - utils::{result::FlatOk, stream::ReadyExt, IterStream}, - PduCount, Result, -}; -use 
futures::{FutureExt, StreamExt}; -use ruma::{ - api::{client::relations::get_relating_events, Direction}, - events::{relation::RelationType, TimelineEventType}, - EventId, RoomId, UInt, UserId, -}; -use serde::Deserialize; +use conduit::{PduCount, Result}; +use futures::StreamExt; +use ruma::{api::Direction, EventId, RoomId, UserId}; use self::data::{Data, PdusIterItem}; use crate::{rooms, Dep}; @@ -24,26 +15,14 @@ pub struct Service { struct Services { short: Dep, - state_accessor: Dep, timeline: Dep, } -#[derive(Clone, Debug, Deserialize)] -struct ExtractRelType { - rel_type: RelationType, -} -#[derive(Clone, Debug, Deserialize)] -struct ExtractRelatesToEventId { - #[serde(rename = "m.relates_to")] - relates_to: ExtractRelType, -} - impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { services: Services { short: args.depend::("rooms::short"), - state_accessor: args.depend::("rooms::state_accessor"), timeline: args.depend::("rooms::timeline"), }, db: Data::new(&args), @@ -64,82 +43,9 @@ impl Service { } } - #[allow(clippy::too_many_arguments)] - pub async fn paginate_relations_with_filter( - &self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: Option, - filter_rel_type: Option, from: Option<&str>, to: Option<&str>, limit: Option, - recurse: bool, dir: Direction, - ) -> Result { - let from = from - .map(PduCount::try_from_string) - .transpose()? 
- .unwrap_or_else(|| match dir { - Direction::Forward => PduCount::min(), - Direction::Backward => PduCount::max(), - }); - - let to = to.map(PduCount::try_from_string).flat_ok(); - - // Use limit or else 30, with maximum 100 - let limit: usize = limit - .map(TryInto::try_into) - .flat_ok() - .unwrap_or(30) - .min(100); - - // Spec (v1.10) recommends depth of at least 3 - let depth: u8 = if recurse { - 3 - } else { - 1 - }; - - let events: Vec = self - .get_relations(sender_user, room_id, target, from, limit, depth, dir) - .await - .into_iter() - .filter(|(_, pdu)| { - filter_event_type - .as_ref() - .is_none_or(|kind| *kind == pdu.kind) - }) - .filter(|(_, pdu)| { - filter_rel_type.as_ref().is_none_or(|rel_type| { - pdu.get_content() - .map(|c: ExtractRelatesToEventId| c.relates_to.rel_type) - .is_ok_and(|r| r == *rel_type) - }) - }) - .stream() - .filter_map(|item| self.visibility_filter(sender_user, item)) - .ready_take_while(|(count, _)| Some(*count) != to) - .take(limit) - .collect() - .boxed() - .await; - - let next_batch = match dir { - Direction::Backward => events.first(), - Direction::Forward => events.last(), - } - .map(at!(0)) - .map(|t| t.stringify()); - - Ok(get_relating_events::v1::Response { - next_batch, - prev_batch: Some(from.stringify()), - recursion_depth: recurse.then_some(depth.into()), - chunk: events - .into_iter() - .map(at!(1)) - .map(|pdu| pdu.to_message_like_event()) - .collect(), - }) - } - #[allow(clippy::too_many_arguments)] pub async fn get_relations( - &self, user_id: &UserId, room_id: &RoomId, target: &EventId, until: PduCount, limit: usize, max_depth: u8, + &self, user_id: &UserId, room_id: &RoomId, target: &EventId, from: PduCount, limit: usize, max_depth: u8, dir: Direction, ) -> Vec { let room_id = self.services.short.get_or_create_shortroomid(room_id).await; @@ -152,7 +58,7 @@ impl Service { let mut pdus: Vec<_> = self .db - .get_relations(user_id, room_id, target, until, dir) + .get_relations(user_id, room_id, target, from, 
dir) .collect() .await; @@ -167,7 +73,7 @@ impl Service { let relations: Vec<_> = self .db - .get_relations(user_id, room_id, target, until, dir) + .get_relations(user_id, room_id, target, from, dir) .collect() .await; @@ -186,16 +92,6 @@ impl Service { pdus } - async fn visibility_filter(&self, sender_user: &UserId, item: PdusIterItem) -> Option { - let (_, pdu) = &item; - - self.services - .state_accessor - .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) - .await - .then_some(item) - } - #[inline] #[tracing::instrument(skip_all, level = "debug")] pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) { diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 70daded1e..1af37d9e5 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -1,10 +1,10 @@ -use std::{iter, sync::Arc}; +use std::sync::Arc; use arrayvec::ArrayVec; use conduit::{ implement, utils::{set, stream::TryIgnore, ArrayVecExt, IterStream, ReadyExt}, - PduEvent, Result, + PduCount, PduEvent, Result, }; use database::{keyval::Val, Map}; use futures::{Stream, StreamExt}; @@ -66,13 +66,13 @@ impl crate::Service for Service { } #[implement(Service)] -pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { +pub fn index_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_body: &str) { let batch = tokenize(message_body) .map(|word| { let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(word.as_bytes()); key.push(0xFF); - key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here + key.extend_from_slice(pdu_id.as_ref()); // TODO: currently we save the room id a second time here (key, Vec::::new()) }) .collect::>(); @@ -81,12 +81,12 @@ pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { } #[implement(Service)] -pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { +pub fn 
deindex_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_body: &str) { let batch = tokenize(message_body).map(|word| { let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(word.as_bytes()); key.push(0xFF); - key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here + key.extend_from_slice(pdu_id.as_ref()); // TODO: currently we save the room id a second time here key }); @@ -159,24 +159,24 @@ fn search_pdu_ids_query_words<'a>( &'a self, shortroomid: ShortRoomId, word: &'a str, ) -> impl Stream + Send + '_ { self.search_pdu_ids_query_word(shortroomid, word) - .ready_filter_map(move |key| { - key[prefix_len(word)..] - .chunks_exact(PduId::LEN) - .next() - .map(RawPduId::try_from) - .and_then(Result::ok) + .map(move |key| -> RawPduId { + let key = &key[prefix_len(word)..]; + key.into() }) } /// Iterate over raw database results for a word #[implement(Service)] fn search_pdu_ids_query_word(&self, shortroomid: ShortRoomId, word: &str) -> impl Stream> + Send + '_ { - const PDUID_LEN: usize = PduId::LEN; // rustc says const'ing this not yet stable - let end_id: ArrayVec = iter::repeat(u8::MAX).take(PduId::LEN).collect(); + let end_id: RawPduId = PduId { + shortroomid, + shorteventid: PduCount::max(), + } + .into(); // Newest pdus first - let end = make_tokenid(shortroomid, word, end_id.as_slice()); + let end = make_tokenid(shortroomid, word, &end_id); let prefix = make_prefix(shortroomid, word); self.db .tokenids @@ -196,11 +196,9 @@ fn tokenize(body: &str) -> impl Iterator + Send + '_ { .map(str::to_lowercase) } -fn make_tokenid(shortroomid: ShortRoomId, word: &str, pdu_id: &[u8]) -> TokenId { - debug_assert!(pdu_id.len() == PduId::LEN, "pdu_id size mismatch"); - +fn make_tokenid(shortroomid: ShortRoomId, word: &str, pdu_id: &RawPduId) -> TokenId { let mut key = make_prefix(shortroomid, word); - key.extend_from_slice(pdu_id); + key.extend_from_slice(pdu_id.as_ref()); key } diff --git 
a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index a903ef22a..9fddf099e 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,5 +1,6 @@ use std::{mem::size_of_val, sync::Arc}; +pub use conduit::pdu::{ShortEventId, ShortId, ShortRoomId}; use conduit::{err, implement, utils, Result}; use database::{Deserialized, Map}; use ruma::{events::StateEventType, EventId, RoomId}; @@ -26,9 +27,6 @@ struct Services { pub type ShortStateHash = ShortId; pub type ShortStateKey = ShortId; -pub type ShortEventId = ShortId; -pub type ShortRoomId = ShortId; -pub type ShortId = u64; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { @@ -52,7 +50,7 @@ impl crate::Service for Service { #[implement(Service)] pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> ShortEventId { - const BUFSIZE: usize = size_of::(); + const BUFSIZE: usize = size_of::(); if let Ok(shorteventid) = self .db @@ -88,7 +86,7 @@ pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> .map(|(i, result)| match result { Ok(ref short) => utils::u64_from_u8(short), Err(_) => { - const BUFSIZE: usize = size_of::(); + const BUFSIZE: usize = size_of::(); let short = self.services.globals.next_count().unwrap(); debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed"); diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 5aea5f6a0..37272dca8 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -33,7 +33,7 @@ use ruma::{ }; use tokio::sync::Mutex; -use crate::{rooms, sending, Dep}; +use crate::{rooms, rooms::short::ShortRoomId, sending, Dep}; pub struct CachedSpaceHierarchySummary { summary: SpaceHierarchyParentSummary, @@ -49,7 +49,7 @@ pub enum SummaryAccessibility { pub struct PaginationToken { /// Path down the hierarchy of the room to start the response at, /// excluding the root space. 
- pub short_room_ids: Vec, + pub short_room_ids: Vec, pub limit: UInt, pub max_depth: UInt, pub suggested_only: bool, @@ -448,7 +448,7 @@ impl Service { } pub async fn get_client_hierarchy( - &self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec, max_depth: u64, + &self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec, max_depth: u64, suggested_only: bool, ) -> Result { let mut parents = VecDeque::new(); diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 34fab0798..71a3900cd 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -95,7 +95,7 @@ impl Service { let event_ids = statediffnew.iter().stream().filter_map(|new| { self.services .state_compressor - .parse_compressed_state_event(new) + .parse_compressed_state_event(*new) .map_ok_or_else(|_| None, |(_, event_id)| Some(event_id)) }); @@ -428,7 +428,7 @@ impl Service { let Ok((shortstatekey, event_id)) = self .services .state_compressor - .parse_compressed_state_event(compressed) + .parse_compressed_state_event(*compressed) .await else { continue; diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index 9c96785f4..06cd648cf 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -53,7 +53,7 @@ impl Data { let parsed = self .services .state_compressor - .parse_compressed_state_event(compressed) + .parse_compressed_state_event(*compressed) .await?; result.insert(parsed.0, parsed.1); @@ -86,7 +86,7 @@ impl Data { let (_, eventid) = self .services .state_compressor - .parse_compressed_state_event(compressed) + .parse_compressed_state_event(*compressed) .await?; if let Ok(pdu) = self.services.timeline.get_pdu(&eventid).await { @@ -132,7 +132,7 @@ impl Data { self.services .state_compressor - .parse_compressed_state_event(compressed) + .parse_compressed_state_event(*compressed) .map_ok(|(_, id)| id) .map_err(|e| { 
err!(Database(error!( diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index a2cc27e85..d51da8af9 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -39,13 +39,17 @@ use ruma::{ use serde::Deserialize; use self::data::Data; -use crate::{rooms, rooms::state::RoomMutexGuard, Dep}; +use crate::{ + rooms, + rooms::{short::ShortStateHash, state::RoomMutexGuard}, + Dep, +}; pub struct Service { services: Services, db: Data, - pub server_visibility_cache: Mutex>, - pub user_visibility_cache: Mutex>, + pub server_visibility_cache: Mutex>, + pub user_visibility_cache: Mutex>, } struct Services { @@ -94,11 +98,13 @@ impl Service { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[tracing::instrument(skip(self), level = "debug")] - pub async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { + pub async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result>> { self.db.state_full_ids(shortstatehash).await } - pub async fn state_full(&self, shortstatehash: u64) -> Result>> { + pub async fn state_full( + &self, shortstatehash: ShortStateHash, + ) -> Result>> { self.db.state_full(shortstatehash).await } @@ -106,7 +112,7 @@ impl Service { /// `state_key`). #[tracing::instrument(skip(self), level = "debug")] pub async fn state_get_id( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, ) -> Result> { self.db .state_get_id(shortstatehash, event_type, state_key) @@ -117,7 +123,7 @@ impl Service { /// `state_key`). 
#[inline] pub async fn state_get( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, ) -> Result> { self.db .state_get(shortstatehash, event_type, state_key) @@ -126,7 +132,7 @@ impl Service { /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). pub async fn state_get_content( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, ) -> Result where T: for<'de> Deserialize<'de> + Send, @@ -137,7 +143,7 @@ impl Service { } /// Get membership for given user in state - async fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> MembershipState { + async fn user_membership(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> MembershipState { self.state_get_content(shortstatehash, &StateEventType::RoomMember, user_id.as_str()) .await .map_or(MembershipState::Leave, |c: RoomMemberEventContent| c.membership) @@ -145,14 +151,14 @@ impl Service { /// The user was a joined member at this state (potentially in the past) #[inline] - async fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool { + async fn user_was_joined(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> bool { self.user_membership(shortstatehash, user_id).await == MembershipState::Join } /// The user was an invited or joined room member at this state (potentially /// in the past) #[inline] - async fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool { + async fn user_was_invited(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> bool { let s = self.user_membership(shortstatehash, user_id).await; s == MembershipState::Join || s == MembershipState::Invite } @@ -285,7 +291,7 @@ impl Service { } /// Returns the state hash for this pdu. 
- pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result { + pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result { self.db.pdu_shortstatehash(event_id).await } diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index e213490ba..bf90d5c4d 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -34,25 +34,26 @@ struct Data { #[derive(Clone)] struct StateDiff { parent: Option, - added: Arc>, - removed: Arc>, + added: Arc, + removed: Arc, } #[derive(Clone, Default)] pub struct ShortStateInfo { pub shortstatehash: ShortStateHash, - pub full_state: Arc>, - pub added: Arc>, - pub removed: Arc>, + pub full_state: Arc, + pub added: Arc, + pub removed: Arc, } #[derive(Clone, Default)] pub struct HashSetCompressStateEvent { pub shortstatehash: ShortStateHash, - pub added: Arc>, - pub removed: Arc>, + pub added: Arc, + pub removed: Arc, } +pub(crate) type CompressedState = HashSet; pub(crate) type CompressedStateEvent = [u8; 2 * size_of::()]; type StateInfoLruCache = LruCache; type ShortStateInfoVec = Vec; @@ -105,7 +106,7 @@ impl Service { removed, } = self.get_statediff(shortstatehash).await?; - if let Some(parent) = parent { + let response = if let Some(parent) = parent { let mut response = Box::pin(self.load_shortstatehash_info(parent)).await?; let mut state = (*response.last().expect("at least one response").full_state).clone(); state.extend(added.iter().copied()); @@ -121,27 +122,22 @@ impl Service { removed: Arc::new(removed), }); - self.stateinfo_cache - .lock() - .expect("locked") - .insert(shortstatehash, response.clone()); - - Ok(response) + response } else { - let response = vec![ShortStateInfo { + vec![ShortStateInfo { shortstatehash, full_state: added.clone(), added, removed, - }]; + }] + }; - self.stateinfo_cache - .lock() - .expect("locked") - .insert(shortstatehash, response.clone()); + self.stateinfo_cache + .lock() + 
.expect("locked") + .insert(shortstatehash, response.clone()); - Ok(response) - } + Ok(response) } pub async fn compress_state_event(&self, shortstatekey: ShortStateKey, event_id: &EventId) -> CompressedStateEvent { @@ -161,7 +157,7 @@ impl Service { /// Returns shortstatekey, event id #[inline] pub async fn parse_compressed_state_event( - &self, compressed_event: &CompressedStateEvent, + &self, compressed_event: CompressedStateEvent, ) -> Result<(ShortStateKey, Arc)> { use utils::u64_from_u8; diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index f50b812ca..c26dabb40 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -1,17 +1,22 @@ -use std::{mem::size_of, sync::Arc}; +use std::sync::Arc; use conduit::{ - checked, result::LogErr, - utils, utils::{stream::TryIgnore, ReadyExt}, - PduEvent, Result, + PduCount, PduEvent, Result, }; use database::{Deserialized, Map}; use futures::{Stream, StreamExt}; use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; -use crate::{rooms, Dep}; +use crate::{ + rooms, + rooms::{ + short::ShortRoomId, + timeline::{PduId, RawPduId}, + }, + Dep, +}; pub(super) struct Data { threadid_userids: Arc, @@ -35,40 +40,39 @@ impl Data { } } + #[inline] pub(super) async fn threads_until<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, - ) -> Result + Send + 'a> { - let prefix = self - .services - .short - .get_shortroomid(room_id) - .await? 
- .to_be_bytes() - .to_vec(); + &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, _include: &'a IncludeThreads, + ) -> Result + Send + 'a> { + let shortroomid: ShortRoomId = self.services.short.get_shortroomid(room_id).await?; - let mut current = prefix.clone(); - current.extend_from_slice(&(checked!(until - 1)?).to_be_bytes()); + let current: RawPduId = PduId { + shortroomid, + shorteventid: until.saturating_sub(1), + } + .into(); let stream = self .threadid_userids .rev_raw_keys_from(¤t) .ignore_err() - .ready_take_while(move |key| key.starts_with(&prefix)) - .map(|pduid| (utils::u64_from_u8(&pduid[(size_of::())..]), pduid)) - .filter_map(move |(count, pduid)| async move { - let mut pdu = self.services.timeline.get_pdu_from_id(pduid).await.ok()?; + .map(RawPduId::from) + .ready_take_while(move |pdu_id| pdu_id.shortroomid() == shortroomid.to_be_bytes()) + .filter_map(move |pdu_id| async move { + let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?; + let pdu_id: PduId = pdu_id.into(); if pdu.sender != user_id { pdu.remove_transaction_id().log_err().ok(); } - Some((count, pdu)) + Some((pdu_id.shorteventid, pdu)) }); Ok(stream) } - pub(super) fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { + pub(super) fn update_participants(&self, root_id: &RawPduId, participants: &[OwnedUserId]) -> Result { let users = participants .iter() .map(|user| user.as_bytes()) @@ -80,7 +84,7 @@ impl Data { Ok(()) } - pub(super) async fn get_participants(&self, root_id: &[u8]) -> Result> { - self.threadid_userids.qry(root_id).await.deserialized() + pub(super) async fn get_participants(&self, root_id: &RawPduId) -> Result> { + self.threadid_userids.get(root_id).await.deserialized() } } diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index 2eafe5d52..025030307 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -2,7 +2,7 @@ mod data; use 
std::{collections::BTreeMap, sync::Arc}; -use conduit::{err, PduEvent, Result}; +use conduit::{err, PduCount, PduEvent, Result}; use data::Data; use futures::Stream; use ruma::{ @@ -37,8 +37,8 @@ impl crate::Service for Service { impl Service { pub async fn threads_until<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads, - ) -> Result + Send + 'a> { + &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, include: &'a IncludeThreads, + ) -> Result + Send + 'a> { self.db .threads_until(user_id, room_id, until, include) .await diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 5428a3b9d..19dc5325a 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -1,14 +1,13 @@ use std::{ collections::{hash_map, HashMap}, - mem::size_of, sync::Arc, }; use conduit::{ - err, expected, + at, err, result::{LogErr, NotFound}, utils, - utils::{future::TryExtExt, stream::TryIgnore, u64_from_u8, ReadyExt}, + utils::{future::TryExtExt, stream::TryIgnore, ReadyExt}, Err, PduCount, PduEvent, Result, }; use database::{Database, Deserialized, Json, KeyVal, Map}; @@ -16,7 +15,8 @@ use futures::{Stream, StreamExt}; use ruma::{CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; use tokio::sync::Mutex; -use crate::{rooms, Dep}; +use super::{PduId, RawPduId}; +use crate::{rooms, rooms::short::ShortRoomId, Dep}; pub(super) struct Data { eventid_outlierpdu: Arc, @@ -58,30 +58,25 @@ impl Data { .lasttimelinecount_cache .lock() .await - .entry(room_id.to_owned()) + .entry(room_id.into()) { - hash_map::Entry::Vacant(v) => { - if let Some(last_count) = self - .pdus_until(sender_user, room_id, PduCount::max()) - .await? 
- .next() - .await - { - Ok(*v.insert(last_count.0)) - } else { - Ok(PduCount::Normal(0)) - } - }, hash_map::Entry::Occupied(o) => Ok(*o.get()), + hash_map::Entry::Vacant(v) => Ok(self + .pdus_until(sender_user, room_id, PduCount::max()) + .await? + .next() + .await + .map(at!(0)) + .filter(|&count| matches!(count, PduCount::Normal(_))) + .map_or_else(PduCount::max, |count| *v.insert(count))), } } /// Returns the `count` of this pdu's id. pub(super) async fn get_pdu_count(&self, event_id: &EventId) -> Result { - self.eventid_pduid - .get(event_id) + self.get_pdu_id(event_id) .await - .map(|pdu_id| pdu_count(&pdu_id)) + .map(|pdu_id| pdu_id.pdu_count()) } /// Returns the json of a pdu. @@ -102,8 +97,11 @@ impl Data { /// Returns the pdu's id. #[inline] - pub(super) async fn get_pdu_id(&self, event_id: &EventId) -> Result> { - self.eventid_pduid.get(event_id).await + pub(super) async fn get_pdu_id(&self, event_id: &EventId) -> Result { + self.eventid_pduid + .get(event_id) + .await + .map(|handle| RawPduId::from(&*handle)) } /// Returns the pdu directly from `eventid_pduid` only. @@ -154,34 +152,40 @@ impl Data { /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub(super) async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result { + pub(super) async fn get_pdu_from_id(&self, pdu_id: &RawPduId) -> Result { self.pduid_pdu.get(pdu_id).await.deserialized() } /// Returns the pdu as a `BTreeMap`. 
- pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result { + pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &RawPduId) -> Result { self.pduid_pdu.get(pdu_id).await.deserialized() } - pub(super) async fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) { + pub(super) async fn append_pdu( + &self, pdu_id: &RawPduId, pdu: &PduEvent, json: &CanonicalJsonObject, count: PduCount, + ) { + debug_assert!(matches!(count, PduCount::Normal(_)), "PduCount not Normal"); + self.pduid_pdu.raw_put(pdu_id, Json(json)); self.lasttimelinecount_cache .lock() .await - .insert(pdu.room_id.clone(), PduCount::Normal(count)); + .insert(pdu.room_id.clone(), count); self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id); self.eventid_outlierpdu.remove(pdu.event_id.as_bytes()); } - pub(super) fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) { + pub(super) fn prepend_backfill_pdu(&self, pdu_id: &RawPduId, event_id: &EventId, json: &CanonicalJsonObject) { self.pduid_pdu.raw_put(pdu_id, Json(json)); self.eventid_pduid.insert(event_id, pdu_id); self.eventid_outlierpdu.remove(event_id); } /// Removes a pdu and creates a new one with the same id. 
- pub(super) async fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result { + pub(super) async fn replace_pdu( + &self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject, _pdu: &PduEvent, + ) -> Result { if self.pduid_pdu.get(pdu_id).await.is_not_found() { return Err!(Request(NotFound("PDU does not exist."))); } @@ -197,13 +201,14 @@ impl Data { pub(super) async fn pdus_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, ) -> Result + Send + 'a> { - let (prefix, current) = self.count_to_id(room_id, until, 1, true).await?; + let current = self.count_to_id(room_id, until, true).await?; + let prefix = current.shortroomid(); let stream = self .pduid_pdu .rev_raw_stream_from(¤t) .ignore_err() .ready_take_while(move |(key, _)| key.starts_with(&prefix)) - .map(move |item| Self::each_pdu(item, user_id)); + .map(|item| Self::each_pdu(item, user_id)); Ok(stream) } @@ -211,7 +216,8 @@ impl Data { pub(super) async fn pdus_after<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount, ) -> Result + Send + 'a> { - let (prefix, current) = self.count_to_id(room_id, from, 1, false).await?; + let current = self.count_to_id(room_id, from, false).await?; + let prefix = current.shortroomid(); let stream = self .pduid_pdu .raw_stream_from(¤t) @@ -223,6 +229,8 @@ impl Data { } fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: &UserId) -> PdusIterItem { + let pdu_id: RawPduId = pdu_id.into(); + let mut pdu = serde_json::from_slice::(pdu).expect("PduEvent in pduid_pdu database column is invalid JSON"); @@ -231,9 +239,8 @@ impl Data { } pdu.add_age().log_err().ok(); - let count = pdu_count(pdu_id); - (count, pdu) + (pdu_id.pdu_count(), pdu) } pub(super) fn increment_notification_counts( @@ -256,56 +263,25 @@ impl Data { } } - pub(super) async fn count_to_id( - &self, room_id: &RoomId, count: PduCount, offset: u64, subtract: bool, - ) -> Result<(Vec, Vec)> { - let prefix = self + async fn count_to_id(&self, 
room_id: &RoomId, count: PduCount, subtract: bool) -> Result { + let shortroomid: ShortRoomId = self .services .short .get_shortroomid(room_id) .await - .map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))? - .to_be_bytes() - .to_vec(); + .map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))?; - let mut pdu_id = prefix.clone(); // +1 so we don't send the base event - let count_raw = match count { - PduCount::Normal(x) => { - if subtract { - x.saturating_sub(offset) - } else { - x.saturating_add(offset) - } - }, - PduCount::Backfilled(x) => { - pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - let num = u64::MAX.saturating_sub(x); - if subtract { - num.saturating_sub(offset) - } else { - num.saturating_add(offset) - } + let pdu_id = PduId { + shortroomid, + shorteventid: if subtract { + count.checked_sub(1)? + } else { + count.checked_add(1)? }, }; - pdu_id.extend_from_slice(&count_raw.to_be_bytes()); - - Ok((prefix, pdu_id)) - } -} - -/// Returns the `count` of this pdu's id. 
-pub(super) fn pdu_count(pdu_id: &[u8]) -> PduCount { - const STRIDE: usize = size_of::(); - - let pdu_id_len = pdu_id.len(); - let last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - STRIDE)..]); - let second_last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - 2 * STRIDE)..expected!(pdu_id_len - STRIDE)]); - if second_last_u64 == 0 { - PduCount::Backfilled(u64::MAX.saturating_sub(last_u64)) - } else { - PduCount::Normal(last_u64) + Ok(pdu_id.into()) } } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index e45bf7e52..86a479195 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,5 +1,4 @@ mod data; -mod pduid; use std::{ cmp, @@ -15,6 +14,7 @@ use conduit::{ utils::{stream::TryIgnore, IterStream, MutexMap, MutexMapGuard, ReadyExt}, validated, warn, Err, Error, Result, Server, }; +pub use conduit::{PduId, RawPduId}; use futures::{future, future::ready, Future, FutureExt, Stream, StreamExt, TryStreamExt}; use ruma::{ api::federation, @@ -39,13 +39,13 @@ use serde::Deserialize; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use self::data::Data; -pub use self::{ - data::PdusIterItem, - pduid::{PduId, RawPduId}, -}; +pub use self::data::PdusIterItem; use crate::{ - account_data, admin, appservice, appservice::NamespaceRegex, globals, pusher, rooms, - rooms::state_compressor::CompressedStateEvent, sending, server_keys, users, Dep, + account_data, admin, appservice, + appservice::NamespaceRegex, + globals, pusher, rooms, + rooms::{short::ShortRoomId, state_compressor::CompressedStateEvent}, + sending, server_keys, users, Dep, }; // Update Relationships @@ -229,9 +229,7 @@ impl Service { /// Returns the pdu's id. #[inline] - pub async fn get_pdu_id(&self, event_id: &EventId) -> Result> { - self.db.get_pdu_id(event_id).await - } + pub async fn get_pdu_id(&self, event_id: &EventId) -> Result { self.db.get_pdu_id(event_id).await } /// Returns the pdu. 
/// @@ -256,16 +254,16 @@ impl Service { /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result { self.db.get_pdu_from_id(pdu_id).await } + pub async fn get_pdu_from_id(&self, pdu_id: &RawPduId) -> Result { self.db.get_pdu_from_id(pdu_id).await } /// Returns the pdu as a `BTreeMap`. - pub async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result { + pub async fn get_pdu_json_from_id(&self, pdu_id: &RawPduId) -> Result { self.db.get_pdu_json_from_id(pdu_id).await } /// Removes a pdu and creates a new one with the same id. #[tracing::instrument(skip(self), level = "debug")] - pub async fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { + pub async fn replace_pdu(&self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { self.db.replace_pdu(pdu_id, pdu_json, pdu).await } @@ -282,7 +280,7 @@ impl Service { mut pdu_json: CanonicalJsonObject, leaves: Vec, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result> { + ) -> Result { // Coalesce database writes for the remainder of this scope. 
let _cork = self.db.db.cork_and_flush(); @@ -359,9 +357,12 @@ impl Service { .user .reset_notification_counts(&pdu.sender, &pdu.room_id); - let count2 = self.services.globals.next_count().unwrap(); - let mut pdu_id = shortroomid.to_be_bytes().to_vec(); - pdu_id.extend_from_slice(&count2.to_be_bytes()); + let count2 = PduCount::Normal(self.services.globals.next_count().unwrap()); + let pdu_id: RawPduId = PduId { + shortroomid, + shorteventid: count2, + } + .into(); // Insert pdu self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2).await; @@ -544,7 +545,7 @@ impl Service { if let Ok(related_pducount) = self.get_pdu_count(&content.relates_to.event_id).await { self.services .pdu_metadata - .add_relation(PduCount::Normal(count2), related_pducount); + .add_relation(count2, related_pducount); } } @@ -558,7 +559,7 @@ impl Service { if let Ok(related_pducount) = self.get_pdu_count(&in_reply_to.event_id).await { self.services .pdu_metadata - .add_relation(PduCount::Normal(count2), related_pducount); + .add_relation(count2, related_pducount); } }, Relation::Thread(thread) => { @@ -580,7 +581,7 @@ impl Service { { self.services .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?; continue; } @@ -596,7 +597,7 @@ impl Service { if state_key_uid == appservice_uid { self.services .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?; continue; } } @@ -623,7 +624,7 @@ impl Service { { self.services .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?; } } @@ -935,7 +936,7 @@ impl Service { state_ids_compressed: Arc>, soft_fail: bool, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result>> { + ) -> Result> { // We append to state 
before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't // fail. @@ -993,7 +994,7 @@ impl Service { /// Replace a PDU with the redacted form. #[tracing::instrument(skip(self, reason))] - pub async fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: u64) -> Result<()> { + pub async fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: ShortRoomId) -> Result { // TODO: Don't reserialize, keep original json let Ok(pdu_id) = self.get_pdu_id(event_id).await else { // If event does not exist, just noop @@ -1133,7 +1134,6 @@ impl Service { // Skip the PDU if we already have it as a timeline event if let Ok(pdu_id) = self.get_pdu_id(&event_id).await { - let pdu_id = pdu_id.to_vec(); debug!("We already know {event_id} at {pdu_id:?}"); return Ok(()); } @@ -1158,11 +1158,13 @@ impl Service { let insert_lock = self.mutex_insert.lock(&room_id).await; - let max = u64::MAX; - let count = self.services.globals.next_count().unwrap(); - let mut pdu_id = shortroomid.to_be_bytes().to_vec(); - pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - pdu_id.extend_from_slice(&(validated!(max - count)).to_be_bytes()); + let count: i64 = self.services.globals.next_count().unwrap().try_into()?; + + let pdu_id: RawPduId = PduId { + shortroomid, + shorteventid: PduCount::Backfilled(validated!(0 - count)), + } + .into(); // Insert pdu self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value); @@ -1246,16 +1248,3 @@ async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Res Ok(()) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn comparisons() { - assert!(PduCount::Normal(1) < PduCount::Normal(2)); - assert!(PduCount::Backfilled(2) < PduCount::Backfilled(1)); - assert!(PduCount::Normal(1) > PduCount::Backfilled(1)); - assert!(PduCount::Backfilled(1) < PduCount::Normal(1)); - } -} diff --git a/src/service/rooms/timeline/pduid.rs 
b/src/service/rooms/timeline/pduid.rs deleted file mode 100644 index b43c382cf..000000000 --- a/src/service/rooms/timeline/pduid.rs +++ /dev/null @@ -1,13 +0,0 @@ -use crate::rooms::short::{ShortEventId, ShortRoomId}; - -#[derive(Clone, Copy)] -pub struct PduId { - _room_id: ShortRoomId, - _event_id: ShortEventId, -} - -pub type RawPduId = [u8; PduId::LEN]; - -impl PduId { - pub const LEN: usize = size_of::() + size_of::(); -} diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index e484203d5..995871342 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -5,7 +5,7 @@ use database::{Deserialized, Map}; use futures::{pin_mut, Stream, StreamExt}; use ruma::{RoomId, UserId}; -use crate::{globals, rooms, Dep}; +use crate::{globals, rooms, rooms::short::ShortStateHash, Dep}; pub struct Service { db: Data, @@ -93,7 +93,7 @@ pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) - } #[implement(Service)] -pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) { +pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: ShortStateHash) { let shortroomid = self .services .short @@ -108,7 +108,7 @@ pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, } #[implement(Service)] -pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result { +pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result { let shortroomid = self.services.short.get_shortroomid(room_id).await?; let key: &[u64] = &[shortroomid, token]; diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index f75a212c7..cd25776a5 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -115,10 +115,10 @@ impl Data { let mut keys = Vec::new(); for (event, destination) in requests { let mut key = destination.get_prefix(); - if let 
SendingEvent::Pdu(value) = &event { - key.extend_from_slice(value); + if let SendingEvent::Pdu(value) = event { + key.extend(value.as_ref()); } else { - key.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes()); + key.extend(&self.services.globals.next_count().unwrap().to_be_bytes()); } let value = if let SendingEvent::Edu(value) = &event { &**value @@ -175,7 +175,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se ( Destination::Appservice(server), if value.is_empty() { - SendingEvent::Pdu(event.to_vec()) + SendingEvent::Pdu(event.into()) } else { SendingEvent::Edu(value.to_vec()) }, @@ -202,7 +202,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se ( Destination::Push(user_id, pushkey_string), if value.is_empty() { - SendingEvent::Pdu(event.to_vec()) + SendingEvent::Pdu(event.into()) } else { // I'm pretty sure this should never be called SendingEvent::Edu(value.to_vec()) @@ -225,7 +225,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se .map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?, ), if value.is_empty() { - SendingEvent::Pdu(event.to_vec()) + SendingEvent::Pdu(event.into()) } else { SendingEvent::Edu(value.to_vec()) }, diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index ea2668837..77997f697 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -24,7 +24,10 @@ pub use self::{ dest::Destination, sender::{EDU_LIMIT, PDU_LIMIT}, }; -use crate::{account_data, client, globals, presence, pusher, resolver, rooms, server_keys, users, Dep}; +use crate::{ + account_data, client, globals, presence, pusher, resolver, rooms, rooms::timeline::RawPduId, server_keys, users, + Dep, +}; pub struct Service { server: Arc, @@ -61,9 +64,9 @@ struct Msg { #[allow(clippy::module_name_repetitions)] #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum SendingEvent { - 
Pdu(Vec), // pduid - Edu(Vec), // pdu json - Flush, // none + Pdu(RawPduId), // pduid + Edu(Vec), // pdu json + Flush, // none } #[async_trait] @@ -110,9 +113,9 @@ impl crate::Service for Service { impl Service { #[tracing::instrument(skip(self, pdu_id, user, pushkey), level = "debug")] - pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { + pub fn send_pdu_push(&self, pdu_id: &RawPduId, user: &UserId, pushkey: String) -> Result { let dest = Destination::Push(user.to_owned(), pushkey); - let event = SendingEvent::Pdu(pdu_id.to_owned()); + let event = SendingEvent::Pdu(*pdu_id); let _cork = self.db.db.cork(); let keys = self.db.queue_requests(&[(&event, &dest)]); self.dispatch(Msg { @@ -123,7 +126,7 @@ impl Service { } #[tracing::instrument(skip(self), level = "debug")] - pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { + pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: RawPduId) -> Result { let dest = Destination::Appservice(appservice_id); let event = SendingEvent::Pdu(pdu_id); let _cork = self.db.db.cork(); @@ -136,7 +139,7 @@ impl Service { } #[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")] - pub async fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { + pub async fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &RawPduId) -> Result { let servers = self .services .state_cache @@ -147,13 +150,13 @@ impl Service { } #[tracing::instrument(skip(self, servers, pdu_id), level = "debug")] - pub async fn send_pdu_servers<'a, S>(&self, servers: S, pdu_id: &[u8]) -> Result<()> + pub async fn send_pdu_servers<'a, S>(&self, servers: S, pdu_id: &RawPduId) -> Result where S: Stream + Send + 'a, { let _cork = self.db.db.cork(); let requests = servers - .map(|server| (Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.into()))) + .map(|server| (Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.to_owned()))) .collect::>() 
.await; diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index d9087d443..464d186b7 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -536,7 +536,8 @@ impl Service { &events .iter() .map(|e| match e { - SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, + SendingEvent::Edu(b) => &**b, + SendingEvent::Pdu(b) => b.as_ref(), SendingEvent::Flush => &[], }) .collect::>(), @@ -660,7 +661,8 @@ impl Service { &events .iter() .map(|e| match e { - SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, + SendingEvent::Edu(b) => &**b, + SendingEvent::Pdu(b) => b.as_ref(), SendingEvent::Flush => &[], }) .collect::>(), From 137e3008ea04d36f9562eeadc61b276032fd2ddf Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 6 Nov 2024 21:02:23 +0000 Subject: [PATCH 177/245] merge rooms threads data and service Signed-off-by: Jason Volk --- src/api/client/relations.rs | 12 ++-- src/api/client/threads.rs | 10 ++- src/service/rooms/pdu_metadata/data.rs | 15 ++--- src/service/rooms/threads/data.rs | 90 -------------------------- src/service/rooms/threads/mod.rs | 88 +++++++++++++++++++------ 5 files changed, 91 insertions(+), 124 deletions(-) delete mode 100644 src/service/rooms/threads/data.rs diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index ef7035e2f..b5d1485bd 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -97,7 +97,7 @@ async fn paginate_relations_with_filter( filter_event_type: Option, filter_rel_type: Option, from: Option<&str>, to: Option<&str>, limit: Option, recurse: bool, dir: Direction, ) -> Result { - let from: PduCount = from + let start: PduCount = from .map(str::parse) .transpose()? 
.unwrap_or_else(|| match dir { @@ -124,7 +124,7 @@ async fn paginate_relations_with_filter( let events: Vec = services .rooms .pdu_metadata - .get_relations(sender_user, room_id, target, from, limit, depth, dir) + .get_relations(sender_user, room_id, target, start, limit, depth, dir) .await .into_iter() .filter(|(_, pdu)| { @@ -146,16 +146,20 @@ async fn paginate_relations_with_filter( .await; let next_batch = match dir { - Direction::Backward => events.first(), Direction::Forward => events.last(), + Direction::Backward => events.first(), } .map(at!(0)) + .map(|count| match dir { + Direction::Forward => count.saturating_add(1), + Direction::Backward => count.saturating_sub(1), + }) .as_ref() .map(ToString::to_string); Ok(get_relating_events::v1::Response { next_batch, - prev_batch: Some(from.to_string()), + prev_batch: from.map(Into::into), recursion_depth: recurse.then_some(depth.into()), chunk: events .into_iter() diff --git a/src/api/client/threads.rs b/src/api/client/threads.rs index 02cf79926..8d4e399bb 100644 --- a/src/api/client/threads.rs +++ b/src/api/client/threads.rs @@ -1,5 +1,5 @@ use axum::extract::State; -use conduit::{PduCount, PduEvent}; +use conduit::{at, PduCount, PduEvent}; use futures::StreamExt; use ruma::{api::client::threads::get_threads, uint}; @@ -44,12 +44,16 @@ pub(crate) async fn get_threads_route( Ok(get_threads::v1::Response { next_batch: threads .last() - .map(|(count, _)| count) + .filter(|_| threads.len() >= limit) + .map(at!(0)) + .map(|count| count.saturating_sub(1)) + .as_ref() .map(ToString::to_string), chunk: threads .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) + .map(at!(1)) + .map(|pdu| pdu.to_room_event()) .collect(), }) } diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index 3fc065915..f3e1ced8b 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -1,5 +1,6 @@ use std::{mem::size_of, sync::Arc}; +use arrayvec::ArrayVec; use 
conduit::{ result::LogErr, utils::{stream::TryIgnore, u64_from_u8, ReadyExt}, @@ -54,15 +55,13 @@ impl Data { pub(super) fn get_relations<'a>( &'a self, user_id: &'a UserId, shortroomid: ShortRoomId, target: ShortEventId, from: PduCount, dir: Direction, ) -> impl Stream + Send + '_ { - let current: RawPduId = PduId { - shortroomid, - shorteventid: from, - } - .into(); - + let mut current = ArrayVec::::new(); + current.extend(target.to_be_bytes()); + current.extend(from.into_unsigned().to_be_bytes()); + let current = current.as_slice(); match dir { - Direction::Forward => self.tofrom_relation.raw_keys_from(¤t).boxed(), - Direction::Backward => self.tofrom_relation.rev_raw_keys_from(¤t).boxed(), + Direction::Forward => self.tofrom_relation.raw_keys_from(current).boxed(), + Direction::Backward => self.tofrom_relation.rev_raw_keys_from(current).boxed(), } .ignore_err() .ready_take_while(move |key| key.starts_with(&target.to_be_bytes())) diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs deleted file mode 100644 index c26dabb40..000000000 --- a/src/service/rooms/threads/data.rs +++ /dev/null @@ -1,90 +0,0 @@ -use std::sync::Arc; - -use conduit::{ - result::LogErr, - utils::{stream::TryIgnore, ReadyExt}, - PduCount, PduEvent, Result, -}; -use database::{Deserialized, Map}; -use futures::{Stream, StreamExt}; -use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; - -use crate::{ - rooms, - rooms::{ - short::ShortRoomId, - timeline::{PduId, RawPduId}, - }, - Dep, -}; - -pub(super) struct Data { - threadid_userids: Arc, - services: Services, -} - -struct Services { - short: Dep, - timeline: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - threadid_userids: db["threadid_userids"].clone(), - services: Services { - short: args.depend::("rooms::short"), - timeline: args.depend::("rooms::timeline"), - }, - } - } - - #[inline] - pub(super) async fn 
threads_until<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, _include: &'a IncludeThreads, - ) -> Result + Send + 'a> { - let shortroomid: ShortRoomId = self.services.short.get_shortroomid(room_id).await?; - - let current: RawPduId = PduId { - shortroomid, - shorteventid: until.saturating_sub(1), - } - .into(); - - let stream = self - .threadid_userids - .rev_raw_keys_from(¤t) - .ignore_err() - .map(RawPduId::from) - .ready_take_while(move |pdu_id| pdu_id.shortroomid() == shortroomid.to_be_bytes()) - .filter_map(move |pdu_id| async move { - let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?; - let pdu_id: PduId = pdu_id.into(); - - if pdu.sender != user_id { - pdu.remove_transaction_id().log_err().ok(); - } - - Some((pdu_id.shorteventid, pdu)) - }); - - Ok(stream) - } - - pub(super) fn update_participants(&self, root_id: &RawPduId, participants: &[OwnedUserId]) -> Result { - let users = participants - .iter() - .map(|user| user.as_bytes()) - .collect::>() - .join(&[0xFF][..]); - - self.threadid_userids.insert(root_id, &users); - - Ok(()) - } - - pub(super) async fn get_participants(&self, root_id: &RawPduId) -> Result> { - self.threadid_userids.get(root_id).await.deserialized() - } -} diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index 025030307..fcc629e1c 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -1,34 +1,44 @@ -mod data; - use std::{collections::BTreeMap, sync::Arc}; -use conduit::{err, PduCount, PduEvent, Result}; -use data::Data; -use futures::Stream; +use conduit::{ + err, + utils::{stream::TryIgnore, ReadyExt}, + PduCount, PduEvent, PduId, RawPduId, Result, +}; +use database::{Deserialized, Map}; +use futures::{Stream, StreamExt}; use ruma::{ api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, uint, CanonicalJsonValue, - EventId, RoomId, UserId, + EventId, OwnedUserId, RoomId, UserId, }; use 
serde_json::json; -use crate::{rooms, Dep}; +use crate::{rooms, rooms::short::ShortRoomId, Dep}; pub struct Service { - services: Services, db: Data, + services: Services, } struct Services { + short: Dep, timeline: Dep, } +pub(super) struct Data { + threadid_userids: Arc, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + db: Data { + threadid_userids: args.db["threadid_userids"].clone(), + }, services: Services { + short: args.depend::("rooms::short"), timeline: args.depend::("rooms::timeline"), }, - db: Data::new(&args), })) } @@ -36,14 +46,6 @@ impl crate::Service for Service { } impl Service { - pub async fn threads_until<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, include: &'a IncludeThreads, - ) -> Result + Send + 'a> { - self.db - .threads_until(user_id, room_id, until, include) - .await - } - pub async fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { let root_id = self .services @@ -113,13 +115,61 @@ impl Service { } let mut users = Vec::new(); - if let Ok(userids) = self.db.get_participants(&root_id).await { + if let Ok(userids) = self.get_participants(&root_id).await { users.extend_from_slice(&userids); } else { users.push(root_pdu.sender); } users.push(pdu.sender.clone()); - self.db.update_participants(&root_id, &users) + self.update_participants(&root_id, &users) + } + + pub async fn threads_until<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, shorteventid: PduCount, _inc: &'a IncludeThreads, + ) -> Result + Send + 'a> { + let shortroomid: ShortRoomId = self.services.short.get_shortroomid(room_id).await?; + + let current: RawPduId = PduId { + shortroomid, + shorteventid, + } + .into(); + + let stream = self + .db + .threadid_userids + .rev_raw_keys_from(¤t) + .ignore_err() + .map(RawPduId::from) + .ready_take_while(move |pdu_id| pdu_id.shortroomid() == shortroomid.to_be_bytes()) + .filter_map(move |pdu_id| async move { + let 
mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?; + let pdu_id: PduId = pdu_id.into(); + + if pdu.sender != user_id { + pdu.remove_transaction_id().ok(); + } + + Some((pdu_id.shorteventid, pdu)) + }); + + Ok(stream) + } + + pub(super) fn update_participants(&self, root_id: &RawPduId, participants: &[OwnedUserId]) -> Result { + let users = participants + .iter() + .map(|user| user.as_bytes()) + .collect::>() + .join(&[0xFF][..]); + + self.db.threadid_userids.insert(root_id, &users); + + Ok(()) + } + + pub(super) async fn get_participants(&self, root_id: &RawPduId) -> Result> { + self.db.threadid_userids.get(root_id).await.deserialized() } } From 26c890d5ac18adf98109b4663c9eecdc289badef Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 4 Nov 2024 22:38:12 +0000 Subject: [PATCH 178/245] skip redundant receipts on syncs Signed-off-by: Jason Volk --- src/api/client/sync/v3.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index 080489026..2ac0bfea8 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -28,7 +28,7 @@ use ruma::{ events::{ presence::PresenceEvent, room::member::{MembershipState, RoomMemberEventContent}, - AnyRawAccountDataEvent, StateEventType, + AnyRawAccountDataEvent, AnySyncEphemeralRoomEvent, StateEventType, TimelineEventType::*, }, serde::Raw, @@ -983,20 +983,22 @@ async fn load_joined_room( .collect() .await; - let mut edus: Vec<_> = services + let edus: HashMap> = services .rooms .read_receipt .readreceipts_since(room_id, since) - .filter_map(|(read_user, _, v)| async move { - (!services + .filter_map(|(read_user, _, edu)| async move { + services .users .user_is_ignored(&read_user, sender_user) - .await) - .then_some(v) + .await + .or_some((read_user, edu)) }) .collect() .await; + let mut edus: Vec> = edus.into_values().collect(); + if services.rooms.typing.last_typing_update(room_id).await? 
> since { edus.push( serde_json::from_str( From 3ed2c17f980497a3ea7bdf2d438b5da7984572fd Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 6 Nov 2024 01:24:44 +0000 Subject: [PATCH 179/245] move sync watcher from globals service to sync service Signed-off-by: Jason Volk --- src/api/client/sync/v3.rs | 2 +- src/api/client/sync/v4.rs | 2 +- src/service/globals/data.rs | 140 +----------------------------------- src/service/globals/mod.rs | 6 +- src/service/sync/mod.rs | 51 ++++++++++++- src/service/sync/watch.rs | 117 ++++++++++++++++++++++++++++++ 6 files changed, 170 insertions(+), 148 deletions(-) create mode 100644 src/service/sync/watch.rs diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index 2ac0bfea8..00976c78c 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -93,7 +93,7 @@ pub(crate) async fn sync_events_route( } // Setup watchers, so if there's no response, we can wait for them - let watcher = services.globals.watch(&sender_user, &sender_device); + let watcher = services.sync.watch(&sender_user, &sender_device); let next_batch = services.globals.current_count()?; let next_batchcount = PduCount::Normal(next_batch); diff --git a/src/api/client/sync/v4.rs b/src/api/client/sync/v4.rs index 11e3830cc..91abd24e9 100644 --- a/src/api/client/sync/v4.rs +++ b/src/api/client/sync/v4.rs @@ -51,7 +51,7 @@ pub(crate) async fn sync_events_v4_route( let sender_device = body.sender_device.expect("user is authenticated"); let mut body = body.body; // Setup watchers, so if there's no response, we can wait for them - let watcher = services.globals.watch(sender_user, &sender_device); + let watcher = services.sync.watch(sender_user, &sender_device); let next_batch = services.globals.next_count()?; diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index eea7597a0..bcfe101ef 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -1,35 +1,12 @@ use std::sync::{Arc, RwLock}; -use 
conduit::{trace, utils, Result, Server}; +use conduit::{utils, Result}; use database::{Database, Deserialized, Map}; -use futures::{pin_mut, stream::FuturesUnordered, FutureExt, StreamExt}; -use ruma::{DeviceId, UserId}; - -use crate::{rooms, Dep}; pub struct Data { global: Arc, - todeviceid_events: Arc, - userroomid_joined: Arc, - userroomid_invitestate: Arc, - userroomid_leftstate: Arc, - userroomid_notificationcount: Arc, - userroomid_highlightcount: Arc, - pduid_pdu: Arc, - keychangeid_userid: Arc, - roomusertype_roomuserdataid: Arc, - readreceiptid_readreceipt: Arc, - userid_lastonetimekeyupdate: Arc, counter: RwLock, pub(super) db: Arc, - services: Services, -} - -struct Services { - server: Arc, - short: Dep, - state_cache: Dep, - typing: Dep, } const COUNTER: &[u8] = b"c"; @@ -39,25 +16,8 @@ impl Data { let db = &args.db; Self { global: db["global"].clone(), - todeviceid_events: db["todeviceid_events"].clone(), - userroomid_joined: db["userroomid_joined"].clone(), - userroomid_invitestate: db["userroomid_invitestate"].clone(), - userroomid_leftstate: db["userroomid_leftstate"].clone(), - userroomid_notificationcount: db["userroomid_notificationcount"].clone(), - userroomid_highlightcount: db["userroomid_highlightcount"].clone(), - pduid_pdu: db["pduid_pdu"].clone(), - keychangeid_userid: db["keychangeid_userid"].clone(), - roomusertype_roomuserdataid: db["roomusertype_roomuserdataid"].clone(), - readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(), - userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(), counter: RwLock::new(Self::stored_count(&db["global"]).expect("initialized global counter")), db: args.db.clone(), - services: Services { - server: args.server.clone(), - short: args.depend::("rooms::short"), - state_cache: args.depend::("rooms::state_cache"), - typing: args.depend::("rooms::typing"), - }, } } @@ -98,104 +58,6 @@ impl Data { .map_or(Ok(0_u64), utils::u64_from_bytes) } - #[tracing::instrument(skip(self), level 
= "debug")] - pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - let userid_bytes = user_id.as_bytes().to_vec(); - let mut userid_prefix = userid_bytes.clone(); - userid_prefix.push(0xFF); - - let mut userdeviceid_prefix = userid_prefix.clone(); - userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); - userdeviceid_prefix.push(0xFF); - - let mut futures = FuturesUnordered::new(); - - // Return when *any* user changed their key - // TODO: only send for user they share a room with - futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix)); - - futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); - futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix)); - futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); - futures.push( - self.userroomid_notificationcount - .watch_prefix(&userid_prefix), - ); - futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); - - // Events for rooms we are in - let rooms_joined = self.services.state_cache.rooms_joined(user_id); - - pin_mut!(rooms_joined); - while let Some(room_id) = rooms_joined.next().await { - let Ok(short_roomid) = self.services.short.get_shortroomid(room_id).await else { - continue; - }; - - let roomid_bytes = room_id.as_bytes().to_vec(); - let mut roomid_prefix = roomid_bytes.clone(); - roomid_prefix.push(0xFF); - - // Key changes - futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); - - // Room account data - let mut roomuser_prefix = roomid_prefix.clone(); - roomuser_prefix.extend_from_slice(&userid_prefix); - - futures.push( - self.roomusertype_roomuserdataid - .watch_prefix(&roomuser_prefix), - ); - - // PDUs - let short_roomid = short_roomid.to_be_bytes().to_vec(); - futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); - - // EDUs - let typing_room_id = room_id.to_owned(); - let typing_wait_for_update = async move { - 
self.services.typing.wait_for_update(&typing_room_id).await; - }; - - futures.push(typing_wait_for_update.boxed()); - futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); - } - - let mut globaluserdata_prefix = vec![0xFF]; - globaluserdata_prefix.extend_from_slice(&userid_prefix); - - futures.push( - self.roomusertype_roomuserdataid - .watch_prefix(&globaluserdata_prefix), - ); - - // More key changes (used when user is not joined to any rooms) - futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); - - // One time keys - futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); - - // Server shutdown - let server_shutdown = async move { - while self.services.server.running() { - self.services.server.signal.subscribe().recv().await.ok(); - } - }; - - futures.push(server_shutdown.boxed()); - if !self.services.server.running() { - return Ok(()); - } - - // Wait until one of them finds something - trace!(futures = futures.len(), "watch started"); - futures.next().await; - trace!(futures = futures.len(), "watch finished"); - - Ok(()) - } - pub async fn database_version(&self) -> u64 { self.global .get(b"version") diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index bd9569642..55dd10aab 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -12,7 +12,7 @@ use data::Data; use ipaddress::IPAddress; use regex::RegexSet; use ruma::{ - api::client::discovery::discover_support::ContactRole, DeviceId, OwnedEventId, OwnedRoomAliasId, OwnedServerName, + api::client::discovery::discover_support::ContactRole, OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedUserId, RoomAliasId, RoomVersionId, ServerName, UserId, }; use tokio::sync::Mutex; @@ -163,10 +163,6 @@ impl Service { #[inline] pub fn current_count(&self) -> Result { Ok(self.db.current_count()) } - pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - self.db.watch(user_id, device_id).await - } 
- #[inline] pub fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() } diff --git a/src/service/sync/mod.rs b/src/service/sync/mod.rs index 1bf4610ff..f1a6ae75e 100644 --- a/src/service/sync/mod.rs +++ b/src/service/sync/mod.rs @@ -1,9 +1,12 @@ +mod watch; + use std::{ collections::{BTreeMap, BTreeSet}, sync::{Arc, Mutex, Mutex as StdMutex}, }; -use conduit::Result; +use conduit::{Result, Server}; +use database::Map; use ruma::{ api::client::sync::sync_events::{ self, @@ -12,10 +15,35 @@ use ruma::{ OwnedDeviceId, OwnedRoomId, OwnedUserId, }; +use crate::{rooms, Dep}; + pub struct Service { + db: Data, + services: Services, connections: DbConnections, } +pub struct Data { + todeviceid_events: Arc, + userroomid_joined: Arc, + userroomid_invitestate: Arc, + userroomid_leftstate: Arc, + userroomid_notificationcount: Arc, + userroomid_highlightcount: Arc, + pduid_pdu: Arc, + keychangeid_userid: Arc, + roomusertype_roomuserdataid: Arc, + readreceiptid_readreceipt: Arc, + userid_lastonetimekeyupdate: Arc, +} + +struct Services { + server: Arc, + short: Dep, + state_cache: Dep, + typing: Dep, +} + struct SlidingSyncCache { lists: BTreeMap, subscriptions: BTreeMap, @@ -28,8 +56,27 @@ type DbConnectionsKey = (OwnedUserId, OwnedDeviceId, String); type DbConnectionsVal = Arc>; impl crate::Service for Service { - fn build(_args: crate::Args<'_>) -> Result> { + fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + db: Data { + todeviceid_events: args.db["todeviceid_events"].clone(), + userroomid_joined: args.db["userroomid_joined"].clone(), + userroomid_invitestate: args.db["userroomid_invitestate"].clone(), + userroomid_leftstate: args.db["userroomid_leftstate"].clone(), + userroomid_notificationcount: args.db["userroomid_notificationcount"].clone(), + userroomid_highlightcount: args.db["userroomid_highlightcount"].clone(), + pduid_pdu: args.db["pduid_pdu"].clone(), + keychangeid_userid: args.db["keychangeid_userid"].clone(), + 
roomusertype_roomuserdataid: args.db["roomusertype_roomuserdataid"].clone(), + readreceiptid_readreceipt: args.db["readreceiptid_readreceipt"].clone(), + userid_lastonetimekeyupdate: args.db["userid_lastonetimekeyupdate"].clone(), + }, + services: Services { + server: args.server.clone(), + short: args.depend::("rooms::short"), + state_cache: args.depend::("rooms::state_cache"), + typing: args.depend::("rooms::typing"), + }, connections: StdMutex::new(BTreeMap::new()), })) } diff --git a/src/service/sync/watch.rs b/src/service/sync/watch.rs new file mode 100644 index 000000000..3eb663c12 --- /dev/null +++ b/src/service/sync/watch.rs @@ -0,0 +1,117 @@ +use conduit::{implement, trace, Result}; +use futures::{pin_mut, stream::FuturesUnordered, FutureExt, StreamExt}; +use ruma::{DeviceId, UserId}; + +#[implement(super::Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result { + let userid_bytes = user_id.as_bytes().to_vec(); + let mut userid_prefix = userid_bytes.clone(); + userid_prefix.push(0xFF); + + let mut userdeviceid_prefix = userid_prefix.clone(); + userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); + userdeviceid_prefix.push(0xFF); + + let mut futures = FuturesUnordered::new(); + + // Return when *any* user changed their key + // TODO: only send for user they share a room with + futures.push(self.db.todeviceid_events.watch_prefix(&userdeviceid_prefix)); + + futures.push(self.db.userroomid_joined.watch_prefix(&userid_prefix)); + futures.push(self.db.userroomid_invitestate.watch_prefix(&userid_prefix)); + futures.push(self.db.userroomid_leftstate.watch_prefix(&userid_prefix)); + futures.push( + self.db + .userroomid_notificationcount + .watch_prefix(&userid_prefix), + ); + futures.push( + self.db + .userroomid_highlightcount + .watch_prefix(&userid_prefix), + ); + + // Events for rooms we are in + let rooms_joined = self.services.state_cache.rooms_joined(user_id); + + 
pin_mut!(rooms_joined); + while let Some(room_id) = rooms_joined.next().await { + let Ok(short_roomid) = self.services.short.get_shortroomid(room_id).await else { + continue; + }; + + let roomid_bytes = room_id.as_bytes().to_vec(); + let mut roomid_prefix = roomid_bytes.clone(); + roomid_prefix.push(0xFF); + + // Key changes + futures.push(self.db.keychangeid_userid.watch_prefix(&roomid_prefix)); + + // Room account data + let mut roomuser_prefix = roomid_prefix.clone(); + roomuser_prefix.extend_from_slice(&userid_prefix); + + futures.push( + self.db + .roomusertype_roomuserdataid + .watch_prefix(&roomuser_prefix), + ); + + // PDUs + let short_roomid = short_roomid.to_be_bytes().to_vec(); + futures.push(self.db.pduid_pdu.watch_prefix(&short_roomid)); + + // EDUs + let typing_room_id = room_id.to_owned(); + let typing_wait_for_update = async move { + self.services.typing.wait_for_update(&typing_room_id).await; + }; + + futures.push(typing_wait_for_update.boxed()); + futures.push( + self.db + .readreceiptid_readreceipt + .watch_prefix(&roomid_prefix), + ); + } + + let mut globaluserdata_prefix = vec![0xFF]; + globaluserdata_prefix.extend_from_slice(&userid_prefix); + + futures.push( + self.db + .roomusertype_roomuserdataid + .watch_prefix(&globaluserdata_prefix), + ); + + // More key changes (used when user is not joined to any rooms) + futures.push(self.db.keychangeid_userid.watch_prefix(&userid_prefix)); + + // One time keys + futures.push( + self.db + .userid_lastonetimekeyupdate + .watch_prefix(&userid_bytes), + ); + + // Server shutdown + let server_shutdown = async move { + while self.services.server.running() { + self.services.server.signal.subscribe().recv().await.ok(); + } + }; + + futures.push(server_shutdown.boxed()); + if !self.services.server.running() { + return Ok(()); + } + + // Wait until one of them finds something + trace!(futures = futures.len(), "watch started"); + futures.next().await; + trace!(futures = futures.len(), "watch finished"); + + 
Ok(()) +} From 7450c654ae37e8caa9af80f40dc674d9a65893b7 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 6 Nov 2024 06:20:39 +0000 Subject: [PATCH 180/245] add get_pdu_owned sans Arc; improve client/room/event handler Signed-off-by: Jason Volk --- src/admin/debug/commands.rs | 3 +- src/api/client/membership.rs | 4 ++- src/api/client/room.rs | 48 +++++++++++++----------------- src/api/mod.rs | 2 ++ src/api/server/send_join.rs | 8 +++-- src/service/rooms/timeline/data.rs | 15 ++++++---- src/service/rooms/timeline/mod.rs | 8 ++++- 7 files changed, 50 insertions(+), 38 deletions(-) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 754c98408..f9d4a521f 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -6,7 +6,7 @@ use std::{ }; use conduit::{debug_error, err, info, trace, utils, utils::string::EMPTY, warn, Error, PduEvent, Result}; -use futures::StreamExt; +use futures::{FutureExt, StreamExt}; use ruma::{ api::{client::error::ErrorKind, federation::event::get_room_state}, events::room::message::RoomMessageEventContent, @@ -246,6 +246,7 @@ pub(super) async fn get_remote_pdu( .rooms .timeline .backfill_pdu(&server, response.pdu) + .boxed() .await?; let json_text = serde_json::to_string_pretty(&json).expect("canonical json is valid json"); diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index fa71c0c85..bf8e5c33b 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -374,7 +374,9 @@ pub(crate) async fn invite_user_route( return Ok(invite_user::v3::Response {}); } - invite_helper(&services, sender_user, user_id, &body.room_id, body.reason.clone(), false).await?; + invite_helper(&services, sender_user, user_id, &body.room_id, body.reason.clone(), false) + .boxed() + .await?; Ok(invite_user::v3::Response {}) } else { Err(Error::BadRequest(ErrorKind::NotFound, "User not found.")) diff --git a/src/api/client/room.rs b/src/api/client/room.rs index 4224d3fa7..b6683ef4d 
100644 --- a/src/api/client/room.rs +++ b/src/api/client/room.rs @@ -2,7 +2,7 @@ use std::{cmp::max, collections::BTreeMap}; use axum::extract::State; use conduit::{debug_info, debug_warn, err, Err}; -use futures::{FutureExt, StreamExt}; +use futures::{FutureExt, StreamExt, TryFutureExt}; use ruma::{ api::client::{ error::ErrorKind, @@ -486,34 +486,28 @@ pub(crate) async fn create_room_route( /// - You have to currently be joined to the room (TODO: Respect history /// visibility) pub(crate) async fn get_room_event_route( - State(services): State, body: Ruma, + State(services): State, ref body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - let event = services - .rooms - .timeline - .get_pdu(&body.event_id) - .await - .map_err(|_| err!(Request(NotFound("Event {} not found.", &body.event_id))))?; - - if !services - .rooms - .state_accessor - .user_can_see_event(sender_user, &event.room_id, &body.event_id) - .await - { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this event.", - )); - } - - let mut event = (*event).clone(); - event.add_age()?; - Ok(get_room_event::v3::Response { - event: event.to_room_event(), + event: services + .rooms + .timeline + .get_pdu_owned(&body.event_id) + .map_err(|_| err!(Request(NotFound("Event {} not found.", &body.event_id)))) + .and_then(|event| async move { + services + .rooms + .state_accessor + .user_can_see_event(body.sender_user(), &event.room_id, &body.event_id) + .await + .then_some(event) + .ok_or_else(|| err!(Request(Forbidden("You don't have permission to view this event.")))) + }) + .map_ok(|mut event| { + event.add_age().ok(); + event.to_room_event() + }) + .await?, }) } diff --git a/src/api/mod.rs b/src/api/mod.rs index ed8aacf23..fc68af5b5 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,3 +1,5 @@ +#![allow(clippy::toplevel_ref_arg)] + pub mod client; pub mod router; pub mod server; diff --git 
a/src/api/server/send_join.rs b/src/api/server/send_join.rs index f2ede9d0a..60ec8c1f4 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -253,7 +253,9 @@ pub(crate) async fn create_join_event_v1_route( } } - let room_state = create_join_event(&services, body.origin(), &body.room_id, &body.pdu).await?; + let room_state = create_join_event(&services, body.origin(), &body.room_id, &body.pdu) + .boxed() + .await?; Ok(create_join_event::v1::Response { room_state, @@ -296,7 +298,9 @@ pub(crate) async fn create_join_event_v2_route( auth_chain, state, event, - } = create_join_event(&services, body.origin(), &body.room_id, &body.pdu).await?; + } = create_join_event(&services, body.origin(), &body.room_id, &body.pdu) + .boxed() + .await?; let room_state = create_join_event::v2::RoomState { members_omitted: false, auth_chain, diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 19dc5325a..f062e7e49 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -123,15 +123,18 @@ impl Data { /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. pub(super) async fn get_pdu(&self, event_id: &EventId) -> Result> { + self.get_pdu_owned(event_id).await.map(Arc::new) + } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. 
+ pub(super) async fn get_pdu_owned(&self, event_id: &EventId) -> Result { if let Ok(pdu) = self.get_non_outlier_pdu(event_id).await { - return Ok(Arc::new(pdu)); + return Ok(pdu); } - self.eventid_outlierpdu - .get(event_id) - .await - .deserialized() - .map(Arc::new) + self.eventid_outlierpdu.get(event_id).await.deserialized() } /// Like get_non_outlier_pdu(), but without the expense of fetching and diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 86a479195..8255be7df 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -244,6 +244,11 @@ impl Service { /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. pub async fn get_pdu(&self, event_id: &EventId) -> Result> { self.db.get_pdu(event_id).await } + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub async fn get_pdu_owned(&self, event_id: &EventId) -> Result { self.db.get_pdu_owned(event_id).await } + /// Checks if pdu exists /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. 
@@ -885,6 +890,7 @@ impl Service { vec![(*pdu.event_id).to_owned()], state_lock, ) + .boxed() .await?; // We set the room state after inserting the pdu, so that we never have a moment @@ -1104,7 +1110,7 @@ impl Service { match response { Ok(response) => { for pdu in response.pdus { - if let Err(e) = self.backfill_pdu(backfill_server, pdu).await { + if let Err(e) = self.backfill_pdu(backfill_server, pdu).boxed().await { warn!("Failed to add backfilled pdu in room {room_id}: {e}"); } } From f36757027eacc27f47f6415d998be6cf61cc4f0a Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 6 Nov 2024 18:27:40 +0000 Subject: [PATCH 181/245] split api/client/room Signed-off-by: Jason Volk --- src/api/client/room/aliases.rs | 40 +++ src/api/client/{room.rs => room/create.rs} | 359 +-------------------- src/api/client/room/event.rs | 38 +++ src/api/client/room/mod.rs | 9 + src/api/client/room/upgrade.rs | 294 +++++++++++++++++ 5 files changed, 388 insertions(+), 352 deletions(-) create mode 100644 src/api/client/room/aliases.rs rename src/api/client/{room.rs => room/create.rs} (65%) create mode 100644 src/api/client/room/event.rs create mode 100644 src/api/client/room/mod.rs create mode 100644 src/api/client/room/upgrade.rs diff --git a/src/api/client/room/aliases.rs b/src/api/client/room/aliases.rs new file mode 100644 index 000000000..e530b2602 --- /dev/null +++ b/src/api/client/room/aliases.rs @@ -0,0 +1,40 @@ +use axum::extract::State; +use conduit::{Error, Result}; +use futures::StreamExt; +use ruma::api::client::{error::ErrorKind, room::aliases}; + +use crate::Ruma; + +/// # `GET /_matrix/client/r0/rooms/{roomId}/aliases` +/// +/// Lists all aliases of the room. 
+/// +/// - Only users joined to the room are allowed to call this, or if +/// `history_visibility` is world readable in the room +pub(crate) async fn get_room_aliases_route( + State(services): State, body: Ruma, +) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + if !services + .rooms + .state_accessor + .user_can_see_state_events(sender_user, &body.room_id) + .await + { + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "You don't have permission to view this room.", + )); + } + + Ok(aliases::v3::Response { + aliases: services + .rooms + .alias + .local_aliases_for_room(&body.room_id) + .map(ToOwned::to_owned) + .collect() + .await, + }) +} diff --git a/src/api/client/room.rs b/src/api/client/room/create.rs similarity index 65% rename from src/api/client/room.rs rename to src/api/client/room/create.rs index b6683ef4d..2ccb1c87a 100644 --- a/src/api/client/room.rs +++ b/src/api/client/room/create.rs @@ -1,12 +1,12 @@ -use std::{cmp::max, collections::BTreeMap}; +use std::collections::BTreeMap; use axum::extract::State; -use conduit::{debug_info, debug_warn, err, Err}; -use futures::{FutureExt, StreamExt, TryFutureExt}; +use conduit::{debug_info, debug_warn, error, info, pdu::PduBuilder, warn, Err, Error, Result}; +use futures::FutureExt; use ruma::{ api::client::{ error::ErrorKind, - room::{self, aliases, create_room, get_room_event, upgrade_room}, + room::{self, create_room}, }, events::{ room::{ @@ -18,36 +18,18 @@ use ruma::{ member::{MembershipState, RoomMemberEventContent}, name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent, - tombstone::RoomTombstoneEventContent, topic::RoomTopicEventContent, }, - StateEventType, TimelineEventType, + TimelineEventType, }, int, serde::{JsonObject, Raw}, CanonicalJsonObject, Int, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, RoomVersionId, }; use serde_json::{json, value::to_raw_value}; -use tracing::{error, info, warn}; +use 
service::{appservice::RegistrationInfo, Services}; -use super::invite_helper; -use crate::{ - service::{appservice::RegistrationInfo, pdu::PduBuilder, Services}, - Error, Result, Ruma, -}; - -/// Recommended transferable state events list from the spec -const TRANSFERABLE_STATE_EVENTS: &[StateEventType; 9] = &[ - StateEventType::RoomServerAcl, - StateEventType::RoomEncryption, - StateEventType::RoomName, - StateEventType::RoomAvatar, - StateEventType::RoomTopic, - StateEventType::RoomGuestAccess, - StateEventType::RoomHistoryVisibility, - StateEventType::RoomJoinRules, - StateEventType::RoomPowerLevels, -]; +use crate::{client::invite_helper, Ruma}; /// # `POST /_matrix/client/v3/createRoom` /// @@ -479,333 +461,6 @@ pub(crate) async fn create_room_route( Ok(create_room::v3::Response::new(room_id)) } -/// # `GET /_matrix/client/r0/rooms/{roomId}/event/{eventId}` -/// -/// Gets a single event. -/// -/// - You have to currently be joined to the room (TODO: Respect history -/// visibility) -pub(crate) async fn get_room_event_route( - State(services): State, ref body: Ruma, -) -> Result { - Ok(get_room_event::v3::Response { - event: services - .rooms - .timeline - .get_pdu_owned(&body.event_id) - .map_err(|_| err!(Request(NotFound("Event {} not found.", &body.event_id)))) - .and_then(|event| async move { - services - .rooms - .state_accessor - .user_can_see_event(body.sender_user(), &event.room_id, &body.event_id) - .await - .then_some(event) - .ok_or_else(|| err!(Request(Forbidden("You don't have permission to view this event.")))) - }) - .map_ok(|mut event| { - event.add_age().ok(); - event.to_room_event() - }) - .await?, - }) -} - -/// # `GET /_matrix/client/r0/rooms/{roomId}/aliases` -/// -/// Lists all aliases of the room. 
-/// -/// - Only users joined to the room are allowed to call this, or if -/// `history_visibility` is world readable in the room -pub(crate) async fn get_room_aliases_route( - State(services): State, body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - if !services - .rooms - .state_accessor - .user_can_see_state_events(sender_user, &body.room_id) - .await - { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); - } - - Ok(aliases::v3::Response { - aliases: services - .rooms - .alias - .local_aliases_for_room(&body.room_id) - .map(ToOwned::to_owned) - .collect() - .await, - }) -} - -/// # `POST /_matrix/client/r0/rooms/{roomId}/upgrade` -/// -/// Upgrades the room. -/// -/// - Creates a replacement room -/// - Sends a tombstone event into the current room -/// - Sender user joins the room -/// - Transfers some state events -/// - Moves local aliases -/// - Modifies old room power levels to prevent users from speaking -pub(crate) async fn upgrade_room_route( - State(services): State, body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - if !services - .globals - .supported_room_versions() - .contains(&body.new_version) - { - return Err(Error::BadRequest( - ErrorKind::UnsupportedRoomVersion, - "This server does not support that room version.", - )); - } - - // Create a replacement room - let replacement_room = RoomId::new(services.globals.server_name()); - - let _short_id = services - .rooms - .short - .get_or_create_shortroomid(&replacement_room) - .await; - - let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; - - // Send a m.room.tombstone event to the old room to indicate that it is not - // intended to be used any further Fail if the sender does not have the required - // permissions - let tombstone_event_id = services - .rooms - .timeline - .build_and_append_pdu( - 
PduBuilder::state( - String::new(), - &RoomTombstoneEventContent { - body: "This room has been replaced".to_owned(), - replacement_room: replacement_room.clone(), - }, - ), - sender_user, - &body.room_id, - &state_lock, - ) - .await?; - - // Change lock to replacement room - drop(state_lock); - let state_lock = services.rooms.state.mutex.lock(&replacement_room).await; - - // Get the old room creation event - let mut create_event_content: CanonicalJsonObject = services - .rooms - .state_accessor - .room_state_get_content(&body.room_id, &StateEventType::RoomCreate, "") - .await - .map_err(|_| err!(Database("Found room without m.room.create event.")))?; - - // Use the m.room.tombstone event as the predecessor - let predecessor = Some(ruma::events::room::create::PreviousRoom::new( - body.room_id.clone(), - (*tombstone_event_id).to_owned(), - )); - - // Send a m.room.create event containing a predecessor field and the applicable - // room_version - { - use RoomVersionId::*; - match body.new_version { - V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { - create_event_content.insert( - "creator".into(), - json!(&sender_user).try_into().map_err(|e| { - info!("Error forming creation event: {e}"); - Error::BadRequest(ErrorKind::BadJson, "Error forming creation event") - })?, - ); - }, - _ => { - // "creator" key no longer exists in V11+ rooms - create_event_content.remove("creator"); - }, - } - } - - create_event_content.insert( - "room_version".into(), - json!(&body.new_version) - .try_into() - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, - ); - create_event_content.insert( - "predecessor".into(), - json!(predecessor) - .try_into() - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, - ); - - // Validate creation event content - if serde_json::from_str::( - to_raw_value(&create_event_content) - .expect("Error forming creation event") - .get(), - ) - .is_err() - { - return 
Err(Error::BadRequest(ErrorKind::BadJson, "Error forming creation event")); - } - - services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomCreate, - content: to_raw_value(&create_event_content).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, - sender_user, - &replacement_room, - &state_lock, - ) - .await?; - - // Join the new room - services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services.users.displayname(sender_user).await.ok(), - avatar_url: services.users.avatar_url(sender_user).await.ok(), - is_direct: None, - third_party_invite: None, - blurhash: services.users.blurhash(sender_user).await.ok(), - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(sender_user.to_string()), - redacts: None, - timestamp: None, - }, - sender_user, - &replacement_room, - &state_lock, - ) - .await?; - - // Replicate transferable state events to the new room - for event_type in TRANSFERABLE_STATE_EVENTS { - let event_content = match services - .rooms - .state_accessor - .room_state_get(&body.room_id, event_type, "") - .await - { - Ok(v) => v.content.clone(), - Err(_) => continue, // Skipping missing events. 
- }; - - services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: event_type.to_string().into(), - content: event_content, - state_key: Some(String::new()), - ..Default::default() - }, - sender_user, - &replacement_room, - &state_lock, - ) - .await?; - } - - // Moves any local aliases to the new room - let mut local_aliases = services - .rooms - .alias - .local_aliases_for_room(&body.room_id) - .boxed(); - - while let Some(alias) = local_aliases.next().await { - services - .rooms - .alias - .remove_alias(alias, sender_user) - .await?; - - services - .rooms - .alias - .set_alias(alias, &replacement_room, sender_user)?; - } - - // Get the old room power levels - let power_levels_event_content: RoomPowerLevelsEventContent = services - .rooms - .state_accessor - .room_state_get_content(&body.room_id, &StateEventType::RoomPowerLevels, "") - .await - .map_err(|_| err!(Database("Found room without m.room.power_levels event.")))?; - - // Setting events_default and invite to the greater of 50 and users_default + 1 - let new_level = max( - int!(50), - power_levels_event_content - .users_default - .checked_add(int!(1)) - .ok_or_else(|| err!(Request(BadJson("users_default power levels event content is not valid"))))?, - ); - - // Modify the power levels in the old room to prevent sending of events and - // inviting new users - services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder::state( - String::new(), - &RoomPowerLevelsEventContent { - events_default: new_level, - invite: new_level, - ..power_levels_event_content - }, - ), - sender_user, - &body.room_id, - &state_lock, - ) - .await?; - - drop(state_lock); - - // Return the replacement room id - Ok(upgrade_room::v3::Response { - replacement_room, - }) -} - /// creates the power_levels_content for the PDU builder fn default_power_levels_content( power_level_content_override: Option<&Raw>, visibility: &room::Visibility, diff --git a/src/api/client/room/event.rs 
b/src/api/client/room/event.rs new file mode 100644 index 000000000..0f44f25d2 --- /dev/null +++ b/src/api/client/room/event.rs @@ -0,0 +1,38 @@ +use axum::extract::State; +use conduit::{err, Result}; +use futures::TryFutureExt; +use ruma::api::client::room::get_room_event; + +use crate::Ruma; + +/// # `GET /_matrix/client/r0/rooms/{roomId}/event/{eventId}` +/// +/// Gets a single event. +/// +/// - You have to currently be joined to the room (TODO: Respect history +/// visibility) +pub(crate) async fn get_room_event_route( + State(services): State, ref body: Ruma, +) -> Result { + Ok(get_room_event::v3::Response { + event: services + .rooms + .timeline + .get_pdu_owned(&body.event_id) + .map_err(|_| err!(Request(NotFound("Event {} not found.", &body.event_id)))) + .and_then(|event| async move { + services + .rooms + .state_accessor + .user_can_see_event(body.sender_user(), &event.room_id, &body.event_id) + .await + .then_some(event) + .ok_or_else(|| err!(Request(Forbidden("You don't have permission to view this event.")))) + }) + .map_ok(|mut event| { + event.add_age().ok(); + event.to_room_event() + }) + .await?, + }) +} diff --git a/src/api/client/room/mod.rs b/src/api/client/room/mod.rs new file mode 100644 index 000000000..fa2d168f0 --- /dev/null +++ b/src/api/client/room/mod.rs @@ -0,0 +1,9 @@ +mod aliases; +mod create; +mod event; +mod upgrade; + +pub(crate) use self::{ + aliases::get_room_aliases_route, create::create_room_route, event::get_room_event_route, + upgrade::upgrade_room_route, +}; diff --git a/src/api/client/room/upgrade.rs b/src/api/client/room/upgrade.rs new file mode 100644 index 000000000..ad5c356e8 --- /dev/null +++ b/src/api/client/room/upgrade.rs @@ -0,0 +1,294 @@ +use std::cmp::max; + +use axum::extract::State; +use conduit::{err, info, pdu::PduBuilder, Error, Result}; +use futures::StreamExt; +use ruma::{ + api::client::{error::ErrorKind, room::upgrade_room}, + events::{ + room::{ + member::{MembershipState, RoomMemberEventContent}, + 
power_levels::RoomPowerLevelsEventContent, + tombstone::RoomTombstoneEventContent, + }, + StateEventType, TimelineEventType, + }, + int, CanonicalJsonObject, RoomId, RoomVersionId, +}; +use serde_json::{json, value::to_raw_value}; + +use crate::Ruma; + +/// Recommended transferable state events list from the spec +const TRANSFERABLE_STATE_EVENTS: &[StateEventType; 9] = &[ + StateEventType::RoomServerAcl, + StateEventType::RoomEncryption, + StateEventType::RoomName, + StateEventType::RoomAvatar, + StateEventType::RoomTopic, + StateEventType::RoomGuestAccess, + StateEventType::RoomHistoryVisibility, + StateEventType::RoomJoinRules, + StateEventType::RoomPowerLevels, +]; + +/// # `POST /_matrix/client/r0/rooms/{roomId}/upgrade` +/// +/// Upgrades the room. +/// +/// - Creates a replacement room +/// - Sends a tombstone event into the current room +/// - Sender user joins the room +/// - Transfers some state events +/// - Moves local aliases +/// - Modifies old room power levels to prevent users from speaking +pub(crate) async fn upgrade_room_route( + State(services): State, body: Ruma, +) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + if !services + .globals + .supported_room_versions() + .contains(&body.new_version) + { + return Err(Error::BadRequest( + ErrorKind::UnsupportedRoomVersion, + "This server does not support that room version.", + )); + } + + // Create a replacement room + let replacement_room = RoomId::new(services.globals.server_name()); + + let _short_id = services + .rooms + .short + .get_or_create_shortroomid(&replacement_room) + .await; + + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; + + // Send a m.room.tombstone event to the old room to indicate that it is not + // intended to be used any further Fail if the sender does not have the required + // permissions + let tombstone_event_id = services + .rooms + .timeline + .build_and_append_pdu( + PduBuilder::state( + 
String::new(), + &RoomTombstoneEventContent { + body: "This room has been replaced".to_owned(), + replacement_room: replacement_room.clone(), + }, + ), + sender_user, + &body.room_id, + &state_lock, + ) + .await?; + + // Change lock to replacement room + drop(state_lock); + let state_lock = services.rooms.state.mutex.lock(&replacement_room).await; + + // Get the old room creation event + let mut create_event_content: CanonicalJsonObject = services + .rooms + .state_accessor + .room_state_get_content(&body.room_id, &StateEventType::RoomCreate, "") + .await + .map_err(|_| err!(Database("Found room without m.room.create event.")))?; + + // Use the m.room.tombstone event as the predecessor + let predecessor = Some(ruma::events::room::create::PreviousRoom::new( + body.room_id.clone(), + (*tombstone_event_id).to_owned(), + )); + + // Send a m.room.create event containing a predecessor field and the applicable + // room_version + { + use RoomVersionId::*; + match body.new_version { + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { + create_event_content.insert( + "creator".into(), + json!(&sender_user).try_into().map_err(|e| { + info!("Error forming creation event: {e}"); + Error::BadRequest(ErrorKind::BadJson, "Error forming creation event") + })?, + ); + }, + _ => { + // "creator" key no longer exists in V11+ rooms + create_event_content.remove("creator"); + }, + } + } + + create_event_content.insert( + "room_version".into(), + json!(&body.new_version) + .try_into() + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, + ); + create_event_content.insert( + "predecessor".into(), + json!(predecessor) + .try_into() + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, + ); + + // Validate creation event content + if serde_json::from_str::( + to_raw_value(&create_event_content) + .expect("Error forming creation event") + .get(), + ) + .is_err() + { + return Err(Error::BadRequest(ErrorKind::BadJson, 
"Error forming creation event")); + } + + services + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomCreate, + content: to_raw_value(&create_event_content).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(String::new()), + redacts: None, + timestamp: None, + }, + sender_user, + &replacement_room, + &state_lock, + ) + .await?; + + // Join the new room + services + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Join, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), + is_direct: None, + third_party_invite: None, + blurhash: services.users.blurhash(sender_user).await.ok(), + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(sender_user.to_string()), + redacts: None, + timestamp: None, + }, + sender_user, + &replacement_room, + &state_lock, + ) + .await?; + + // Replicate transferable state events to the new room + for event_type in TRANSFERABLE_STATE_EVENTS { + let event_content = match services + .rooms + .state_accessor + .room_state_get(&body.room_id, event_type, "") + .await + { + Ok(v) => v.content.clone(), + Err(_) => continue, // Skipping missing events. 
+ }; + + services + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: event_type.to_string().into(), + content: event_content, + state_key: Some(String::new()), + ..Default::default() + }, + sender_user, + &replacement_room, + &state_lock, + ) + .await?; + } + + // Moves any local aliases to the new room + let mut local_aliases = services + .rooms + .alias + .local_aliases_for_room(&body.room_id) + .boxed(); + + while let Some(alias) = local_aliases.next().await { + services + .rooms + .alias + .remove_alias(alias, sender_user) + .await?; + + services + .rooms + .alias + .set_alias(alias, &replacement_room, sender_user)?; + } + + // Get the old room power levels + let power_levels_event_content: RoomPowerLevelsEventContent = services + .rooms + .state_accessor + .room_state_get_content(&body.room_id, &StateEventType::RoomPowerLevels, "") + .await + .map_err(|_| err!(Database("Found room without m.room.power_levels event.")))?; + + // Setting events_default and invite to the greater of 50 and users_default + 1 + let new_level = max( + int!(50), + power_levels_event_content + .users_default + .checked_add(int!(1)) + .ok_or_else(|| err!(Request(BadJson("users_default power levels event content is not valid"))))?, + ); + + // Modify the power levels in the old room to prevent sending of events and + // inviting new users + services + .rooms + .timeline + .build_and_append_pdu( + PduBuilder::state( + String::new(), + &RoomPowerLevelsEventContent { + events_default: new_level, + invite: new_level, + ..power_levels_event_content + }, + ), + sender_user, + &body.room_id, + &state_lock, + ) + .await?; + + drop(state_lock); + + // Return the replacement room id + Ok(upgrade_room::v3::Response { + replacement_room, + }) +} From e507c3130673099692143a59adc30a414ef6ca54 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 6 Nov 2024 22:21:51 +0000 Subject: [PATCH 182/245] make pdu batch tokens zeroith-indexed Signed-off-by: Jason Volk --- 
src/api/client/context.rs | 24 +++++++++++++++--------- src/api/client/message.rs | 15 ++++++++++----- src/api/client/relations.rs | 5 +---- src/api/client/sync/mod.rs | 2 +- src/api/client/sync/v3.rs | 15 +++++---------- src/api/server/backfill.rs | 2 +- src/core/pdu/count.rs | 19 +++++++++++++++++++ src/service/rooms/timeline/data.rs | 18 +++++++----------- src/service/rooms/timeline/mod.rs | 19 ++++++++----------- 9 files changed, 67 insertions(+), 52 deletions(-) diff --git a/src/api/client/context.rs b/src/api/client/context.rs index 5b492cb19..d07f6ac1d 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -2,7 +2,7 @@ use std::iter::once; use axum::extract::State; use conduit::{ - err, error, + at, err, error, utils::{future::TryExtExt, stream::ReadyExt, IterStream}, Err, Result, }; @@ -82,7 +82,7 @@ pub(crate) async fn get_context_route( let events_before: Vec<_> = services .rooms .timeline - .pdus_until(sender_user, room_id, base_token) + .pdus_rev(sender_user, room_id, base_token.saturating_sub(1)) .await? .ready_filter_map(|item| event_filter(item, filter)) .filter_map(|item| ignored_filter(&services, item, sender_user)) @@ -94,7 +94,7 @@ pub(crate) async fn get_context_route( let events_after: Vec<_> = services .rooms .timeline - .pdus_after(sender_user, room_id, base_token) + .pdus(sender_user, room_id, base_token.saturating_add(1)) .await? 
.ready_filter_map(|item| event_filter(item, filter)) .filter_map(|item| ignored_filter(&services, item, sender_user)) @@ -168,22 +168,28 @@ pub(crate) async fn get_context_route( start: events_before .last() - .map_or_else(|| base_token.to_string(), |(count, _)| count.to_string()) - .into(), + .map(at!(0)) + .map(|count| count.saturating_sub(1)) + .as_ref() + .map(ToString::to_string), end: events_after .last() - .map_or_else(|| base_token.to_string(), |(count, _)| count.to_string()) - .into(), + .map(at!(0)) + .map(|count| count.saturating_add(1)) + .as_ref() + .map(ToString::to_string), events_before: events_before .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) + .map(at!(1)) + .map(|pdu| pdu.to_room_event()) .collect(), events_after: events_after .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) + .map(at!(1)) + .map(|pdu| pdu.to_room_event()) .collect(), state, diff --git a/src/api/client/message.rs b/src/api/client/message.rs index cb261a7f2..e76325aa2 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -100,14 +100,14 @@ pub(crate) async fn get_message_events_route( Direction::Forward => services .rooms .timeline - .pdus_after(sender_user, room_id, from) + .pdus(sender_user, room_id, from) .await? .boxed(), Direction::Backward => services .rooms .timeline - .pdus_until(sender_user, room_id, from) + .pdus_rev(sender_user, room_id, from) .await? 
.boxed(), }; @@ -136,7 +136,12 @@ pub(crate) async fn get_message_events_route( .collect() .await; - let next_token = events.last().map(|(count, _)| count).copied(); + let start_token = events.first().map(at!(0)).unwrap_or(from); + + let next_token = events + .last() + .map(at!(0)) + .map(|count| count.saturating_inc(body.dir)); if !cfg!(feature = "element_hacks") { if let Some(next_token) = next_token { @@ -154,8 +159,8 @@ pub(crate) async fn get_message_events_route( .collect(); Ok(get_message_events::v3::Response { - start: from.to_string(), - end: next_token.as_ref().map(PduCount::to_string), + start: start_token.to_string(), + end: next_token.as_ref().map(ToString::to_string), chunk, state, }) diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index b5d1485bd..ee62dbfc9 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -150,10 +150,7 @@ async fn paginate_relations_with_filter( Direction::Backward => events.first(), } .map(at!(0)) - .map(|count| match dir { - Direction::Forward => count.saturating_add(1), - Direction::Backward => count.saturating_sub(1), - }) + .map(|count| count.saturating_inc(dir)) .as_ref() .map(ToString::to_string); diff --git a/src/api/client/sync/mod.rs b/src/api/client/sync/mod.rs index 7aec7186f..f047d1761 100644 --- a/src/api/client/sync/mod.rs +++ b/src/api/client/sync/mod.rs @@ -24,7 +24,7 @@ async fn load_timeline( let mut non_timeline_pdus = services .rooms .timeline - .pdus_until(sender_user, room_id, PduCount::max()) + .pdus_rev(sender_user, room_id, PduCount::max()) .await? 
.ready_take_while(|(pducount, _)| *pducount > roomsincecount); diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index 00976c78c..ea487d8e2 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -6,7 +6,7 @@ use std::{ use axum::extract::State; use conduit::{ - err, error, extract_variant, is_equal_to, + at, err, error, extract_variant, is_equal_to, result::FlatOk, utils::{math::ruma_from_u64, BoolExt, IterStream, ReadyExt, TryFutureExtExt}, PduCount, @@ -945,15 +945,10 @@ async fn load_joined_room( let prev_batch = timeline_pdus .first() - .map_or(Ok::<_, Error>(None), |(pdu_count, _)| { - Ok(Some(match pdu_count { - PduCount::Backfilled(_) => { - error!("timeline in backfill state?!"); - "0".to_owned() - }, - PduCount::Normal(c) => c.to_string(), - })) - })?; + .map(at!(0)) + .map(|count| count.saturating_sub(1)) + .as_ref() + .map(ToString::to_string); let room_events: Vec<_> = timeline_pdus .iter() diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index 281bf2a23..47f02841b 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -51,7 +51,7 @@ pub(crate) async fn get_backfill_route( let pdus = services .rooms .timeline - .pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until) + .pdus_rev(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until) .await? 
.take(limit) .filter_map(|(_, pdu)| async move { diff --git a/src/core/pdu/count.rs b/src/core/pdu/count.rs index 90e552e89..aceec1e8d 100644 --- a/src/core/pdu/count.rs +++ b/src/core/pdu/count.rs @@ -2,6 +2,8 @@ use std::{cmp::Ordering, fmt, fmt::Display, str::FromStr}; +use ruma::api::Direction; + use crate::{err, Error, Result}; #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] @@ -54,6 +56,14 @@ impl PduCount { } } + #[inline] + pub fn checked_inc(self, dir: Direction) -> Result { + match dir { + Direction::Forward => self.checked_add(1), + Direction::Backward => self.checked_sub(1), + } + } + #[inline] pub fn checked_add(self, add: u64) -> Result { Ok(match self { @@ -82,6 +92,15 @@ impl PduCount { }) } + #[inline] + #[must_use] + pub fn saturating_inc(self, dir: Direction) -> Self { + match dir { + Direction::Forward => self.saturating_add(1), + Direction::Backward => self.saturating_sub(1), + } + } + #[inline] #[must_use] pub fn saturating_add(self, add: u64) -> Self { diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index f062e7e49..f320e6a0b 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -62,7 +62,7 @@ impl Data { { hash_map::Entry::Occupied(o) => Ok(*o.get()), hash_map::Entry::Vacant(v) => Ok(self - .pdus_until(sender_user, room_id, PduCount::max()) + .pdus_rev(sender_user, room_id, PduCount::max()) .await? .next() .await @@ -201,10 +201,10 @@ impl Data { /// Returns an iterator over all events and their tokens in a room that /// happened before the event with id `until` in reverse-chronological /// order. 
- pub(super) async fn pdus_until<'a>( + pub(super) async fn pdus_rev<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, ) -> Result + Send + 'a> { - let current = self.count_to_id(room_id, until, true).await?; + let current = self.count_to_id(room_id, until).await?; let prefix = current.shortroomid(); let stream = self .pduid_pdu @@ -216,10 +216,10 @@ impl Data { Ok(stream) } - pub(super) async fn pdus_after<'a>( + pub(super) async fn pdus<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount, ) -> Result + Send + 'a> { - let current = self.count_to_id(room_id, from, false).await?; + let current = self.count_to_id(room_id, from).await?; let prefix = current.shortroomid(); let stream = self .pduid_pdu @@ -266,7 +266,7 @@ impl Data { } } - async fn count_to_id(&self, room_id: &RoomId, count: PduCount, subtract: bool) -> Result { + async fn count_to_id(&self, room_id: &RoomId, shorteventid: PduCount) -> Result { let shortroomid: ShortRoomId = self .services .short @@ -277,11 +277,7 @@ impl Data { // +1 so we don't send the base event let pdu_id = PduId { shortroomid, - shorteventid: if subtract { - count.checked_sub(1)? - } else { - count.checked_add(1)? - }, + shorteventid, }; Ok(pdu_id.into()) diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 8255be7df..81d372d7a 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -177,7 +177,7 @@ impl Service { #[tracing::instrument(skip(self), level = "debug")] pub async fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result> { - self.pdus_until(user_id!("@placeholder:conduwuit.placeholder"), room_id, PduCount::max()) + self.pdus_rev(user_id!("@placeholder:conduwuit.placeholder"), room_id, PduCount::max()) .await? 
.next() .await @@ -976,26 +976,23 @@ impl Service { pub async fn all_pdus<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, ) -> Result + Send + 'a> { - self.pdus_after(user_id, room_id, PduCount::min()).await + self.pdus(user_id, room_id, PduCount::min()).await } - /// Returns an iterator over all events and their tokens in a room that - /// happened before the event with id `until` in reverse-chronological - /// order. + /// Reverse iteration starting at from. #[tracing::instrument(skip(self), level = "debug")] - pub async fn pdus_until<'a>( + pub async fn pdus_rev<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, ) -> Result + Send + 'a> { - self.db.pdus_until(user_id, room_id, until).await + self.db.pdus_rev(user_id, room_id, until).await } - /// Returns an iterator over all events and their token in a room that - /// happened after the event with id `from` in chronological order. + /// Forward iteration starting at from. #[tracing::instrument(skip(self), level = "debug")] - pub async fn pdus_after<'a>( + pub async fn pdus<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount, ) -> Result + Send + 'a> { - self.db.pdus_after(user_id, room_id, from).await + self.db.pdus(user_id, room_id, from).await } /// Replace a PDU with the redacted form. 
From 79c6b518605da12a65a7c0ae3a769931c6eed93b Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 7 Nov 2024 03:30:47 +0000 Subject: [PATCH 183/245] renames for core pdu Signed-off-by: Jason Volk --- src/core/pdu/content.rs | 4 ++-- src/core/pdu/count.rs | 36 +++++++++++++++++++------------ src/core/pdu/event.rs | 4 ++-- src/core/pdu/filter.rs | 10 ++++----- src/core/pdu/id.rs | 12 +++++------ src/core/pdu/mod.rs | 15 +++++++------ src/core/pdu/raw_id.rs | 46 ++++++++++++++++++---------------------- src/core/pdu/redact.rs | 10 ++++----- src/core/pdu/relation.rs | 2 +- src/core/pdu/strip.rs | 20 ++++++++--------- src/core/pdu/tests.rs | 10 ++++----- src/core/pdu/unsigned.rs | 43 ++++++++++++++++++++++++++++++------- 12 files changed, 123 insertions(+), 89 deletions(-) diff --git a/src/core/pdu/content.rs b/src/core/pdu/content.rs index a6d86554b..fa724cb2d 100644 --- a/src/core/pdu/content.rs +++ b/src/core/pdu/content.rs @@ -4,13 +4,13 @@ use serde_json::value::Value as JsonValue; use crate::{err, implement, Result}; #[must_use] -#[implement(super::PduEvent)] +#[implement(super::Pdu)] pub fn get_content_as_value(&self) -> JsonValue { self.get_content() .expect("pdu content must be a valid JSON value") } -#[implement(super::PduEvent)] +#[implement(super::Pdu)] pub fn get_content(&self) -> Result where T: for<'de> Deserialize<'de>, diff --git a/src/core/pdu/count.rs b/src/core/pdu/count.rs index aceec1e8d..852223825 100644 --- a/src/core/pdu/count.rs +++ b/src/core/pdu/count.rs @@ -7,12 +7,12 @@ use ruma::api::Direction; use crate::{err, Error, Result}; #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] -pub enum PduCount { +pub enum Count { Normal(u64), Backfilled(i64), } -impl PduCount { +impl Count { #[inline] #[must_use] pub fn from_unsigned(unsigned: u64) -> Self { Self::from_signed(unsigned as i64) } @@ -69,11 +69,11 @@ impl PduCount { Ok(match self { Self::Normal(i) => Self::Normal( i.checked_add(add) - .ok_or_else(|| err!(Arithmetic("PduCount::Normal 
overflow")))?, + .ok_or_else(|| err!(Arithmetic("Count::Normal overflow")))?, ), Self::Backfilled(i) => Self::Backfilled( i.checked_add(add as i64) - .ok_or_else(|| err!(Arithmetic("PduCount::Backfilled overflow")))?, + .ok_or_else(|| err!(Arithmetic("Count::Backfilled overflow")))?, ), }) } @@ -83,11 +83,11 @@ impl PduCount { Ok(match self { Self::Normal(i) => Self::Normal( i.checked_sub(sub) - .ok_or_else(|| err!(Arithmetic("PduCount::Normal underflow")))?, + .ok_or_else(|| err!(Arithmetic("Count::Normal underflow")))?, ), Self::Backfilled(i) => Self::Backfilled( i.checked_sub(sub as i64) - .ok_or_else(|| err!(Arithmetic("PduCount::Backfilled underflow")))?, + .ok_or_else(|| err!(Arithmetic("Count::Backfilled underflow")))?, ), }) } @@ -121,11 +121,11 @@ impl PduCount { #[inline] #[must_use] - pub fn min() -> Self { Self::Backfilled(i64::MIN) } + pub const fn min() -> Self { Self::Backfilled(i64::MIN) } #[inline] #[must_use] - pub fn max() -> Self { Self::Normal(i64::MAX as u64) } + pub const fn max() -> Self { Self::Normal(i64::MAX as u64) } #[inline] pub(crate) fn debug_assert_valid(&self) { @@ -135,7 +135,7 @@ impl PduCount { } } -impl Display for PduCount { +impl Display for Count { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { self.debug_assert_valid(); match self { @@ -145,20 +145,30 @@ impl Display for PduCount { } } -impl FromStr for PduCount { +impl From for Count { + #[inline] + fn from(signed: i64) -> Self { Self::from_signed(signed) } +} + +impl From for Count { + #[inline] + fn from(unsigned: u64) -> Self { Self::from_unsigned(unsigned) } +} + +impl FromStr for Count { type Err = Error; fn from_str(token: &str) -> Result { Ok(Self::from_signed(token.parse()?)) } } -impl PartialOrd for PduCount { +impl PartialOrd for Count { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Ord for PduCount { +impl Ord for Count { fn cmp(&self, other: &Self) -> Ordering { 
self.into_signed().cmp(&other.into_signed()) } } -impl Default for PduCount { +impl Default for Count { fn default() -> Self { Self::Normal(0) } } diff --git a/src/core/pdu/event.rs b/src/core/pdu/event.rs index 15117f925..96a1e4ba3 100644 --- a/src/core/pdu/event.rs +++ b/src/core/pdu/event.rs @@ -4,9 +4,9 @@ pub use ruma::state_res::Event; use ruma::{events::TimelineEventType, EventId, MilliSecondsSinceUnixEpoch, RoomId, UserId}; use serde_json::value::RawValue as RawJsonValue; -use super::PduEvent; +use super::Pdu; -impl Event for PduEvent { +impl Event for Pdu { type Id = Arc; fn event_id(&self) -> &Self::Id { &self.event_id } diff --git a/src/core/pdu/filter.rs b/src/core/pdu/filter.rs index bd232ebd8..c7c7316d1 100644 --- a/src/core/pdu/filter.rs +++ b/src/core/pdu/filter.rs @@ -3,7 +3,7 @@ use serde_json::Value; use crate::{implement, is_equal_to}; -#[implement(super::PduEvent)] +#[implement(super::Pdu)] #[must_use] pub fn matches(&self, filter: &RoomEventFilter) -> bool { if !self.matches_sender(filter) { @@ -25,7 +25,7 @@ pub fn matches(&self, filter: &RoomEventFilter) -> bool { true } -#[implement(super::PduEvent)] +#[implement(super::Pdu)] fn matches_room(&self, filter: &RoomEventFilter) -> bool { if filter.not_rooms.contains(&self.room_id) { return false; @@ -40,7 +40,7 @@ fn matches_room(&self, filter: &RoomEventFilter) -> bool { true } -#[implement(super::PduEvent)] +#[implement(super::Pdu)] fn matches_sender(&self, filter: &RoomEventFilter) -> bool { if filter.not_senders.contains(&self.sender) { return false; @@ -55,7 +55,7 @@ fn matches_sender(&self, filter: &RoomEventFilter) -> bool { true } -#[implement(super::PduEvent)] +#[implement(super::Pdu)] fn matches_type(&self, filter: &RoomEventFilter) -> bool { let event_type = &self.kind.to_cow_str(); if filter.not_types.iter().any(is_equal_to!(event_type)) { @@ -71,7 +71,7 @@ fn matches_type(&self, filter: &RoomEventFilter) -> bool { true } -#[implement(super::PduEvent)] +#[implement(super::Pdu)] fn 
matches_url(&self, filter: &RoomEventFilter) -> bool { let Some(url_filter) = filter.url_filter.as_ref() else { return true; diff --git a/src/core/pdu/id.rs b/src/core/pdu/id.rs index 05d11904c..0b23a29f8 100644 --- a/src/core/pdu/id.rs +++ b/src/core/pdu/id.rs @@ -1,4 +1,4 @@ -use super::{PduCount, RawPduId}; +use super::{Count, RawId}; use crate::utils::u64_from_u8x8; pub type ShortRoomId = ShortId; @@ -6,17 +6,17 @@ pub type ShortEventId = ShortId; pub type ShortId = u64; #[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub struct PduId { +pub struct Id { pub shortroomid: ShortRoomId, - pub shorteventid: PduCount, + pub shorteventid: Count, } -impl From for PduId { +impl From for Id { #[inline] - fn from(raw: RawPduId) -> Self { + fn from(raw: RawId) -> Self { Self { shortroomid: u64_from_u8x8(raw.shortroomid()), - shorteventid: PduCount::from_unsigned(u64_from_u8x8(raw.shorteventid())), + shorteventid: Count::from_unsigned(u64_from_u8x8(raw.shorteventid())), } } } diff --git a/src/core/pdu/mod.rs b/src/core/pdu/mod.rs index c785c99ea..2aa60ed1e 100644 --- a/src/core/pdu/mod.rs +++ b/src/core/pdu/mod.rs @@ -22,17 +22,18 @@ use serde_json::value::RawValue as RawJsonValue; pub use self::{ builder::{Builder, Builder as PduBuilder}, - count::PduCount, + count::Count, event::Event, event_id::*, id::*, raw_id::*, + Count as PduCount, Id as PduId, Pdu as PduEvent, RawId as RawPduId, }; use crate::Result; /// Persistent Data Unit (Event) #[derive(Clone, Deserialize, Serialize, Debug)] -pub struct PduEvent { +pub struct Pdu { pub event_id: Arc, pub room_id: OwnedRoomId, pub sender: OwnedUserId, @@ -64,7 +65,7 @@ pub struct EventHash { pub sha256: String, } -impl PduEvent { +impl Pdu { pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result { let event_id = CanonicalJsonValue::String(event_id.into()); json.insert("event_id".into(), event_id); @@ -75,19 +76,19 @@ impl PduEvent { } /// Prevent derived equality which wouldn't limit itself to event_id 
-impl Eq for PduEvent {} +impl Eq for Pdu {} /// Equality determined by the Pdu's ID, not the memory representations. -impl PartialEq for PduEvent { +impl PartialEq for Pdu { fn eq(&self, other: &Self) -> bool { self.event_id == other.event_id } } /// Ordering determined by the Pdu's ID, not the memory representations. -impl PartialOrd for PduEvent { +impl PartialOrd for Pdu { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } /// Ordering determined by the Pdu's ID, not the memory representations. -impl Ord for PduEvent { +impl Ord for Pdu { fn cmp(&self, other: &Self) -> Ordering { self.event_id.cmp(&other.event_id) } } diff --git a/src/core/pdu/raw_id.rs b/src/core/pdu/raw_id.rs index faba1cbf1..ef8502f68 100644 --- a/src/core/pdu/raw_id.rs +++ b/src/core/pdu/raw_id.rs @@ -1,27 +1,27 @@ use arrayvec::ArrayVec; -use super::{PduCount, PduId, ShortEventId, ShortId, ShortRoomId}; +use super::{Count, Id, ShortEventId, ShortId, ShortRoomId}; #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] -pub enum RawPduId { - Normal(RawPduIdNormal), - Backfilled(RawPduIdBackfilled), +pub enum RawId { + Normal(RawIdNormal), + Backfilled(RawIdBackfilled), } -type RawPduIdNormal = [u8; RawPduId::NORMAL_LEN]; -type RawPduIdBackfilled = [u8; RawPduId::BACKFILLED_LEN]; +type RawIdNormal = [u8; RawId::NORMAL_LEN]; +type RawIdBackfilled = [u8; RawId::BACKFILLED_LEN]; const INT_LEN: usize = size_of::(); -impl RawPduId { +impl RawId { const BACKFILLED_LEN: usize = size_of::() + INT_LEN + size_of::(); const MAX_LEN: usize = Self::BACKFILLED_LEN; const NORMAL_LEN: usize = size_of::() + size_of::(); #[inline] #[must_use] - pub fn pdu_count(&self) -> PduCount { - let id: PduId = (*self).into(); + pub fn pdu_count(&self) -> Count { + let id: Id = (*self).into(); id.shorteventid } @@ -61,55 +61,51 @@ impl RawPduId { } } -impl AsRef<[u8]> for RawPduId { +impl AsRef<[u8]> for RawId { #[inline] fn as_ref(&self) -> &[u8] { self.as_bytes() } } -impl From<&[u8]> for RawPduId { 
+impl From<&[u8]> for RawId { #[inline] fn from(id: &[u8]) -> Self { match id.len() { Self::NORMAL_LEN => Self::Normal( id[0..Self::NORMAL_LEN] .try_into() - .expect("normal RawPduId from [u8]"), + .expect("normal RawId from [u8]"), ), Self::BACKFILLED_LEN => Self::Backfilled( id[0..Self::BACKFILLED_LEN] .try_into() - .expect("backfilled RawPduId from [u8]"), + .expect("backfilled RawId from [u8]"), ), - _ => unimplemented!("unrecognized RawPduId length"), + _ => unimplemented!("unrecognized RawId length"), } } } -impl From for RawPduId { +impl From for RawId { #[inline] - fn from(id: PduId) -> Self { - const MAX_LEN: usize = RawPduId::MAX_LEN; + fn from(id: Id) -> Self { + const MAX_LEN: usize = RawId::MAX_LEN; type RawVec = ArrayVec; let mut vec = RawVec::new(); vec.extend(id.shortroomid.to_be_bytes()); id.shorteventid.debug_assert_valid(); match id.shorteventid { - PduCount::Normal(shorteventid) => { + Count::Normal(shorteventid) => { vec.extend(shorteventid.to_be_bytes()); - Self::Normal( - vec.as_ref() - .try_into() - .expect("RawVec into RawPduId::Normal"), - ) + Self::Normal(vec.as_ref().try_into().expect("RawVec into RawId::Normal")) }, - PduCount::Backfilled(shorteventid) => { + Count::Backfilled(shorteventid) => { vec.extend(0_u64.to_be_bytes()); vec.extend(shorteventid.to_be_bytes()); Self::Backfilled( vec.as_ref() .try_into() - .expect("RawVec into RawPduId::Backfilled"), + .expect("RawVec into RawId::Backfilled"), ) }, } diff --git a/src/core/pdu/redact.rs b/src/core/pdu/redact.rs index 647f54c0f..e116e563d 100644 --- a/src/core/pdu/redact.rs +++ b/src/core/pdu/redact.rs @@ -18,9 +18,9 @@ struct ExtractRedactedBecause { redacted_because: Option, } -#[implement(super::PduEvent)] +#[implement(super::Pdu)] #[tracing::instrument(skip(self), level = "debug")] -pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> Result<()> { +pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> Result { self.unsigned = None; let 
mut content = @@ -31,7 +31,7 @@ pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> Resul self.unsigned = Some( to_raw_value(&json!({ - "redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works") + "redacted_because": serde_json::to_value(reason).expect("to_value(Pdu) always works") })) .expect("to string always works"), ); @@ -41,7 +41,7 @@ pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> Resul Ok(()) } -#[implement(super::PduEvent)] +#[implement(super::Pdu)] #[must_use] pub fn is_redacted(&self) -> bool { let Some(unsigned) = &self.unsigned else { @@ -72,7 +72,7 @@ pub fn is_redacted(&self) -> bool { /// > to the content of m.room.redaction events in older room versions when /// > serving /// > such events over the Client-Server API. -#[implement(super::PduEvent)] +#[implement(super::Pdu)] #[must_use] pub fn copy_redacts(&self) -> (Option>, Box) { if self.kind == TimelineEventType::RoomRedaction { diff --git a/src/core/pdu/relation.rs b/src/core/pdu/relation.rs index ae156a3de..2968171e3 100644 --- a/src/core/pdu/relation.rs +++ b/src/core/pdu/relation.rs @@ -13,7 +13,7 @@ struct ExtractRelatesToEventId { relates_to: ExtractRelType, } -#[implement(super::PduEvent)] +#[implement(super::Pdu)] #[must_use] pub fn relation_type_equal(&self, rel_type: &RelationType) -> bool { self.get_content() diff --git a/src/core/pdu/strip.rs b/src/core/pdu/strip.rs index 8d20d9828..30fee863c 100644 --- a/src/core/pdu/strip.rs +++ b/src/core/pdu/strip.rs @@ -10,7 +10,7 @@ use serde_json::{json, value::Value as JsonValue}; use crate::{implement, warn}; -#[implement(super::PduEvent)] +#[implement(super::Pdu)] #[tracing::instrument(skip(self), level = "debug")] pub fn to_sync_room_event(&self) -> Raw { let (redacts, content) = self.copy_redacts(); @@ -36,7 +36,7 @@ pub fn to_sync_room_event(&self) -> Raw { } /// This only works for events that are also AnyRoomEvents. 
-#[implement(super::PduEvent)] +#[implement(super::Pdu)] #[tracing::instrument(skip(self), level = "debug")] pub fn to_any_event(&self) -> Raw { let (redacts, content) = self.copy_redacts(); @@ -62,7 +62,7 @@ pub fn to_any_event(&self) -> Raw { serde_json::from_value(json).expect("Raw::from_value always works") } -#[implement(super::PduEvent)] +#[implement(super::Pdu)] #[tracing::instrument(skip(self), level = "debug")] pub fn to_room_event(&self) -> Raw { let (redacts, content) = self.copy_redacts(); @@ -88,7 +88,7 @@ pub fn to_room_event(&self) -> Raw { serde_json::from_value(json).expect("Raw::from_value always works") } -#[implement(super::PduEvent)] +#[implement(super::Pdu)] #[tracing::instrument(skip(self), level = "debug")] pub fn to_message_like_event(&self) -> Raw { let (redacts, content) = self.copy_redacts(); @@ -114,7 +114,7 @@ pub fn to_message_like_event(&self) -> Raw { serde_json::from_value(json).expect("Raw::from_value always works") } -#[implement(super::PduEvent)] +#[implement(super::Pdu)] #[must_use] pub fn to_state_event_value(&self) -> JsonValue { let mut json = json!({ @@ -134,13 +134,13 @@ pub fn to_state_event_value(&self) -> JsonValue { json } -#[implement(super::PduEvent)] +#[implement(super::Pdu)] #[tracing::instrument(skip(self), level = "debug")] pub fn to_state_event(&self) -> Raw { serde_json::from_value(self.to_state_event_value()).expect("Raw::from_value always works") } -#[implement(super::PduEvent)] +#[implement(super::Pdu)] #[tracing::instrument(skip(self), level = "debug")] pub fn to_sync_state_event(&self) -> Raw { let mut json = json!({ @@ -159,7 +159,7 @@ pub fn to_sync_state_event(&self) -> Raw { serde_json::from_value(json).expect("Raw::from_value always works") } -#[implement(super::PduEvent)] +#[implement(super::Pdu)] #[tracing::instrument(skip(self), level = "debug")] pub fn to_stripped_state_event(&self) -> Raw { let json = json!({ @@ -172,7 +172,7 @@ pub fn to_stripped_state_event(&self) -> Raw { 
serde_json::from_value(json).expect("Raw::from_value always works") } -#[implement(super::PduEvent)] +#[implement(super::Pdu)] #[tracing::instrument(skip(self), level = "debug")] pub fn to_stripped_spacechild_state_event(&self) -> Raw { let json = json!({ @@ -186,7 +186,7 @@ pub fn to_stripped_spacechild_state_event(&self) -> Raw Raw> { let mut json = json!({ diff --git a/src/core/pdu/tests.rs b/src/core/pdu/tests.rs index 30ec23ba7..ae3b1dd6d 100644 --- a/src/core/pdu/tests.rs +++ b/src/core/pdu/tests.rs @@ -1,19 +1,19 @@ #![cfg(test)] -use super::PduCount; +use super::Count; #[test] fn backfilled_parse() { - let count: PduCount = "-987654".parse().expect("parse() failed"); - let backfilled = matches!(count, PduCount::Backfilled(_)); + let count: Count = "-987654".parse().expect("parse() failed"); + let backfilled = matches!(count, Count::Backfilled(_)); assert!(backfilled, "not backfilled variant"); } #[test] fn normal_parse() { - let count: PduCount = "987654".parse().expect("parse() failed"); - let backfilled = matches!(count, PduCount::Backfilled(_)); + let count: Count = "987654".parse().expect("parse() failed"); + let backfilled = matches!(count, Count::Backfilled(_)); assert!(!backfilled, "backfilled variant"); } diff --git a/src/core/pdu/unsigned.rs b/src/core/pdu/unsigned.rs index 1c47e8263..6f3e44016 100644 --- a/src/core/pdu/unsigned.rs +++ b/src/core/pdu/unsigned.rs @@ -4,10 +4,11 @@ use ruma::MilliSecondsSinceUnixEpoch; use serde::Deserialize; use serde_json::value::{to_raw_value, RawValue as RawJsonValue, Value as JsonValue}; +use super::Pdu; use crate::{err, implement, is_true, Result}; -#[implement(super::PduEvent)] -pub fn remove_transaction_id(&mut self) -> Result<()> { +#[implement(Pdu)] +pub fn remove_transaction_id(&mut self) -> Result { let Some(unsigned) = &self.unsigned else { return Ok(()); }; @@ -23,8 +24,8 @@ pub fn remove_transaction_id(&mut self) -> Result<()> { Ok(()) } -#[implement(super::PduEvent)] -pub fn add_age(&mut self) -> 
Result<()> { +#[implement(Pdu)] +pub fn add_age(&mut self) -> Result { let mut unsigned: BTreeMap> = self .unsigned .as_ref() @@ -44,7 +45,33 @@ pub fn add_age(&mut self) -> Result<()> { Ok(()) } -#[implement(super::PduEvent)] +#[implement(Pdu)] +pub fn add_relation(&mut self, name: &str, pdu: &Pdu) -> Result { + let mut unsigned: BTreeMap = self + .unsigned + .as_ref() + .map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get())) + .map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; + + let relations: &mut JsonValue = unsigned.entry("m.relations".into()).or_default(); + if relations.as_object_mut().is_none() { + let mut object = serde_json::Map::::new(); + _ = relations.as_object_mut().insert(&mut object); + } + + relations + .as_object_mut() + .expect("we just created it") + .insert(name.to_owned(), serde_json::to_value(pdu)?); + + self.unsigned = to_raw_value(&unsigned) + .map(Some) + .expect("unsigned is valid"); + + Ok(()) +} + +#[implement(Pdu)] pub fn contains_unsigned_property(&self, property: &str, is_type: F) -> bool where F: FnOnce(&JsonValue) -> bool, @@ -55,7 +82,7 @@ where .is_some_and(is_true!()) } -#[implement(super::PduEvent)] +#[implement(Pdu)] pub fn get_unsigned_property(&self, property: &str) -> Result where T: for<'de> Deserialize<'de>, @@ -68,11 +95,11 @@ where .map_err(|e| err!(Database("Failed to deserialize unsigned.{property} into type: {e}"))) } -#[implement(super::PduEvent)] +#[implement(Pdu)] #[must_use] pub fn get_unsigned_as_value(&self) -> JsonValue { self.get_unsigned::().unwrap_or_default() } -#[implement(super::PduEvent)] +#[implement(Pdu)] pub fn get_unsigned(&self) -> Result { self.unsigned .as_ref() From 27966221f106ef2c4c4e88cc9381d9f1e2d0468e Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 5 Nov 2024 04:37:08 +0000 Subject: [PATCH 184/245] add ready_try_fold to utils Signed-off-by: Jason Volk --- src/core/utils/result/inspect_log.rs | 2 ++ src/core/utils/stream/ready.rs | 15 ++++++++++++ 
src/core/utils/stream/try_ready.rs | 36 +++++++++++++++++++++++++++- 3 files changed, 52 insertions(+), 1 deletion(-) diff --git a/src/core/utils/result/inspect_log.rs b/src/core/utils/result/inspect_log.rs index 577761c5c..e9f32663c 100644 --- a/src/core/utils/result/inspect_log.rs +++ b/src/core/utils/result/inspect_log.rs @@ -11,6 +11,7 @@ where { fn log_err(self, level: Level) -> Self; + #[inline] fn err_log(self) -> Self where Self: Sized, @@ -25,6 +26,7 @@ where { fn log_err_debug(self, level: Level) -> Self; + #[inline] fn err_debug_log(self) -> Self where Self: Sized, diff --git a/src/core/utils/stream/ready.rs b/src/core/utils/stream/ready.rs index c16d12465..f4eec7d1b 100644 --- a/src/core/utils/stream/ready.rs +++ b/src/core/utils/stream/ready.rs @@ -32,6 +32,11 @@ where where F: Fn(T, Item) -> T; + fn ready_fold_default(self, f: F) -> Fold, T, impl FnMut(T, Item) -> Ready> + where + F: Fn(T, Item) -> T, + T: Default; + fn ready_for_each(self, f: F) -> ForEach, impl FnMut(Item) -> Ready<()>> where F: FnMut(Item); @@ -93,6 +98,15 @@ where self.fold(init, move |a, t| ready(f(a, t))) } + #[inline] + fn ready_fold_default(self, f: F) -> Fold, T, impl FnMut(T, Item) -> Ready> + where + F: Fn(T, Item) -> T, + T: Default, + { + self.ready_fold(T::default(), f) + } + #[inline] #[allow(clippy::unit_arg)] fn ready_for_each(self, mut f: F) -> ForEach, impl FnMut(Item) -> Ready<()>> @@ -120,6 +134,7 @@ where self.scan(init, move |s, t| ready(f(s, t))) } + #[inline] fn ready_scan_each( self, init: T, f: F, ) -> Scan>, impl FnMut(&mut T, Item) -> Ready>> diff --git a/src/core/utils/stream/try_ready.rs b/src/core/utils/stream/try_ready.rs index 3fbcbc454..0daed26e4 100644 --- a/src/core/utils/stream/try_ready.rs +++ b/src/core/utils/stream/try_ready.rs @@ -3,7 +3,7 @@ use futures::{ future::{ready, Ready}, - stream::{AndThen, TryForEach, TryStream, TryStreamExt}, + stream::{AndThen, TryFold, TryForEach, TryStream, TryStreamExt}, }; use crate::Result; @@ -25,6 +25,19 @@ 
where ) -> TryForEach>, impl FnMut(S::Ok) -> Ready>> where F: FnMut(S::Ok) -> Result<(), E>; + + fn ready_try_fold( + self, init: U, f: F, + ) -> TryFold>, U, impl FnMut(U, S::Ok) -> Ready>> + where + F: Fn(U, S::Ok) -> Result; + + fn ready_try_fold_default( + self, f: F, + ) -> TryFold>, U, impl FnMut(U, S::Ok) -> Ready>> + where + F: Fn(U, S::Ok) -> Result, + U: Default; } impl TryReadyExt for S @@ -49,4 +62,25 @@ where { self.try_for_each(move |t| ready(f(t))) } + + #[inline] + fn ready_try_fold( + self, init: U, f: F, + ) -> TryFold>, U, impl FnMut(U, S::Ok) -> Ready>> + where + F: Fn(U, S::Ok) -> Result, + { + self.try_fold(init, move |a, t| ready(f(a, t))) + } + + #[inline] + fn ready_try_fold_default( + self, f: F, + ) -> TryFold>, U, impl FnMut(U, S::Ok) -> Ready>> + where + F: Fn(U, S::Ok) -> Result, + U: Default, + { + self.ready_try_fold(U::default(), f) + } } From 13ef6dcbcf17e04f28ad8beaab64920e63c2aa31 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 7 Nov 2024 03:59:08 +0000 Subject: [PATCH 185/245] add standalone getters for shortid service Signed-off-by: Jason Volk --- src/service/rooms/short/mod.rs | 35 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 9fddf099e..e8b00d9bd 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -52,13 +52,7 @@ impl crate::Service for Service { pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> ShortEventId { const BUFSIZE: usize = size_of::(); - if let Ok(shorteventid) = self - .db - .eventid_shorteventid - .get(event_id) - .await - .deserialized() - { + if let Ok(shorteventid) = self.get_shorteventid(event_id).await { return shorteventid; } @@ -105,11 +99,10 @@ pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> } #[implement(Service)] -pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) 
-> Result { - let key = (event_type, state_key); +pub async fn get_shorteventid(&self, event_id: &EventId) -> Result { self.db - .statekey_shortstatekey - .qry(&key) + .eventid_shorteventid + .get(event_id) .await .deserialized() } @@ -118,17 +111,11 @@ pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &s pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> ShortStateKey { const BUFSIZE: usize = size_of::(); - let key = (event_type, state_key); - if let Ok(shortstatekey) = self - .db - .statekey_shortstatekey - .qry(&key) - .await - .deserialized() - { + if let Ok(shortstatekey) = self.get_shortstatekey(event_type, state_key).await { return shortstatekey; } + let key = (event_type, state_key); let shortstatekey = self.services.globals.next_count().unwrap(); debug_assert!(size_of_val(&shortstatekey) == BUFSIZE, "buffer requirement changed"); @@ -143,6 +130,16 @@ pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, sta shortstatekey } +#[implement(Service)] +pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { + let key = (event_type, state_key); + self.db + .statekey_shortstatekey + .qry(&key) + .await + .deserialized() +} + #[implement(Service)] pub async fn get_eventid_from_short(&self, shorteventid: ShortEventId) -> Result> { const BUFSIZE: usize = size_of::(); From 1f2e939fd56319b85426457a9eb469228e287406 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 7 Nov 2024 04:49:01 +0000 Subject: [PATCH 186/245] optional arguments for timeline pdus iterations Signed-off-by: Jason Volk --- src/api/client/context.rs | 4 ++-- src/api/client/message.rs | 4 ++-- src/api/client/sync/mod.rs | 4 ++-- src/api/server/backfill.rs | 4 ++-- src/service/rooms/timeline/data.rs | 13 +++++++------ src/service/rooms/timeline/mod.rs | 18 +++++++++++------- 6 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/api/client/context.rs 
b/src/api/client/context.rs index d07f6ac1d..f5f981ba0 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -82,7 +82,7 @@ pub(crate) async fn get_context_route( let events_before: Vec<_> = services .rooms .timeline - .pdus_rev(sender_user, room_id, base_token.saturating_sub(1)) + .pdus_rev(Some(sender_user), room_id, Some(base_token.saturating_sub(1))) .await? .ready_filter_map(|item| event_filter(item, filter)) .filter_map(|item| ignored_filter(&services, item, sender_user)) @@ -94,7 +94,7 @@ pub(crate) async fn get_context_route( let events_after: Vec<_> = services .rooms .timeline - .pdus(sender_user, room_id, base_token.saturating_add(1)) + .pdus(Some(sender_user), room_id, Some(base_token.saturating_add(1))) .await? .ready_filter_map(|item| event_filter(item, filter)) .filter_map(|item| ignored_filter(&services, item, sender_user)) diff --git a/src/api/client/message.rs b/src/api/client/message.rs index e76325aa2..e8306de9f 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -100,14 +100,14 @@ pub(crate) async fn get_message_events_route( Direction::Forward => services .rooms .timeline - .pdus(sender_user, room_id, from) + .pdus(Some(sender_user), room_id, Some(from)) .await? .boxed(), Direction::Backward => services .rooms .timeline - .pdus_rev(sender_user, room_id, from) + .pdus_rev(Some(sender_user), room_id, Some(from)) .await? 
.boxed(), }; diff --git a/src/api/client/sync/mod.rs b/src/api/client/sync/mod.rs index f047d1761..3201b8276 100644 --- a/src/api/client/sync/mod.rs +++ b/src/api/client/sync/mod.rs @@ -14,7 +14,7 @@ async fn load_timeline( let last_timeline_count = services .rooms .timeline - .last_timeline_count(sender_user, room_id) + .last_timeline_count(Some(sender_user), room_id) .await?; if last_timeline_count <= roomsincecount { @@ -24,7 +24,7 @@ async fn load_timeline( let mut non_timeline_pdus = services .rooms .timeline - .pdus_rev(sender_user, room_id, PduCount::max()) + .pdus_rev(Some(sender_user), room_id, None) .await? .ready_take_while(|(pducount, _)| *pducount > roomsincecount); diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index 47f02841b..be770ee8a 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -6,7 +6,7 @@ use conduit::{ PduCount, Result, }; use futures::{FutureExt, StreamExt}; -use ruma::{api::federation::backfill::get_backfill, uint, user_id, MilliSecondsSinceUnixEpoch}; +use ruma::{api::federation::backfill::get_backfill, uint, MilliSecondsSinceUnixEpoch}; use super::AccessCheck; use crate::Ruma; @@ -51,7 +51,7 @@ pub(crate) async fn get_backfill_route( let pdus = services .rooms .timeline - .pdus_rev(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until) + .pdus_rev(None, &body.room_id, Some(until)) .await? 
.take(limit) .filter_map(|(_, pdu)| async move { diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index f320e6a0b..7f1873ab0 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Borrow, collections::{hash_map, HashMap}, sync::Arc, }; @@ -53,7 +54,7 @@ impl Data { } } - pub(super) async fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + pub(super) async fn last_timeline_count(&self, sender_user: Option<&UserId>, room_id: &RoomId) -> Result { match self .lasttimelinecount_cache .lock() @@ -202,7 +203,7 @@ impl Data { /// happened before the event with id `until` in reverse-chronological /// order. pub(super) async fn pdus_rev<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, + &'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, until: PduCount, ) -> Result + Send + 'a> { let current = self.count_to_id(room_id, until).await?; let prefix = current.shortroomid(); @@ -211,13 +212,13 @@ impl Data { .rev_raw_stream_from(¤t) .ignore_err() .ready_take_while(move |(key, _)| key.starts_with(&prefix)) - .map(|item| Self::each_pdu(item, user_id)); + .map(move |item| Self::each_pdu(item, user_id)); Ok(stream) } pub(super) async fn pdus<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount, + &'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, from: PduCount, ) -> Result + Send + 'a> { let current = self.count_to_id(room_id, from).await?; let prefix = current.shortroomid(); @@ -231,13 +232,13 @@ impl Data { Ok(stream) } - fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: &UserId) -> PdusIterItem { + fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: Option<&UserId>) -> PdusIterItem { let pdu_id: RawPduId = pdu_id.into(); let mut pdu = serde_json::from_slice::(pdu).expect("PduEvent in pduid_pdu database column is invalid JSON"); - if pdu.sender != user_id { + if 
Some(pdu.sender.borrow()) != user_id { pdu.remove_transaction_id().log_err().ok(); } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 81d372d7a..281879d2f 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -177,7 +177,7 @@ impl Service { #[tracing::instrument(skip(self), level = "debug")] pub async fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result> { - self.pdus_rev(user_id!("@placeholder:conduwuit.placeholder"), room_id, PduCount::max()) + self.pdus_rev(None, room_id, None) .await? .next() .await @@ -186,7 +186,7 @@ impl Service { } #[tracing::instrument(skip(self), level = "debug")] - pub async fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + pub async fn last_timeline_count(&self, sender_user: Option<&UserId>, room_id: &RoomId) -> Result { self.db.last_timeline_count(sender_user, room_id).await } @@ -976,23 +976,27 @@ impl Service { pub async fn all_pdus<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, ) -> Result + Send + 'a> { - self.pdus(user_id, room_id, PduCount::min()).await + self.pdus(Some(user_id), room_id, None).await } /// Reverse iteration starting at from. #[tracing::instrument(skip(self), level = "debug")] pub async fn pdus_rev<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, + &'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, until: Option, ) -> Result + Send + 'a> { - self.db.pdus_rev(user_id, room_id, until).await + self.db + .pdus_rev(user_id, room_id, until.unwrap_or_else(PduCount::max)) + .await } /// Forward iteration starting at from. 
#[tracing::instrument(skip(self), level = "debug")] pub async fn pdus<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount, + &'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, from: Option, ) -> Result + Send + 'a> { - self.db.pdus(user_id, room_id, from).await + self.db + .pdus(user_id, room_id, from.unwrap_or_else(PduCount::min)) + .await } /// Replace a PDU with the redacted form. From f59e8af73474aad18dd68300b245fd0ce2b8ab92 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 8 Nov 2024 05:49:28 +0000 Subject: [PATCH 187/245] slight cleanup/simplifications to backfil Signed-off-by: Jason Volk --- src/api/server/backfill.rs | 76 +++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 39 deletions(-) diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index be770ee8a..2858d9fda 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -16,7 +16,7 @@ use crate::Ruma; /// Retrieves events from before the sender joined the room, if the room's /// history visibility allows. pub(crate) async fn get_backfill_route( - State(services): State, body: Ruma, + State(services): State, ref body: Ruma, ) -> Result { AccessCheck { services: &services, @@ -27,57 +27,55 @@ pub(crate) async fn get_backfill_route( .check() .await?; - let until = body - .v - .iter() - .stream() - .filter_map(|event_id| { - services - .rooms - .timeline - .get_pdu_count(event_id) - .map(Result::ok) - }) - .ready_fold(PduCount::Backfilled(0), cmp::max) - .await; - let limit = body .limit .min(uint!(100)) .try_into() .expect("UInt could not be converted to usize"); - let origin = body.origin(); - let pdus = services - .rooms - .timeline - .pdus_rev(None, &body.room_id, Some(until)) - .await? 
- .take(limit) - .filter_map(|(_, pdu)| async move { - if !services - .rooms - .state_accessor - .server_can_see_event(origin, &pdu.room_id, &pdu.event_id) - .await - { - return None; - } - + let from = body + .v + .iter() + .stream() + .filter_map(|event_id| { services .rooms .timeline - .get_pdu_json(&pdu.event_id) - .await - .ok() + .get_pdu_count(event_id) + .map(Result::ok) }) - .then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect() + .ready_fold(PduCount::min(), cmp::max) .await; Ok(get_backfill::v1::Response { - origin: services.globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdus, + + origin: services.globals.server_name().to_owned(), + + pdus: services + .rooms + .timeline + .pdus_rev(None, &body.room_id, Some(from)) + .await? + .take(limit) + .filter_map(|(_, pdu)| async move { + services + .rooms + .state_accessor + .server_can_see_event(body.origin(), &pdu.room_id, &pdu.event_id) + .await + .then_some(pdu) + }) + .filter_map(|pdu| async move { + services + .rooms + .timeline + .get_pdu_json(&pdu.event_id) + .await + .ok() + }) + .then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) + .collect() + .await, }) } From 6eba36d7883439539b8ca0b65f04d2935e41ad05 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 8 Nov 2024 08:21:19 +0000 Subject: [PATCH 188/245] split make_body template Signed-off-by: Jason Volk --- src/api/router/args.rs | 47 ++++++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/src/api/router/args.rs b/src/api/router/args.rs index 38236db34..4c0aff4c6 100644 --- a/src/api/router/args.rs +++ b/src/api/router/args.rs @@ -1,8 +1,8 @@ use std::{mem, ops::Deref}; use axum::{async_trait, body::Body, extract::FromRequest}; -use bytes::{BufMut, BytesMut}; -use conduit::{debug, err, trace, utils::string::EMPTY, Error, Result}; +use bytes::{BufMut, Bytes, BytesMut}; +use conduit::{debug, err, 
utils::string::EMPTY, Error, Result}; use ruma::{ api::IncomingRequest, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName, OwnedUserId, ServerName, UserId, }; @@ -103,7 +103,32 @@ fn make_body( where T: IncomingRequest + Send + Sync + 'static, { - let body = if let Some(CanonicalJsonValue::Object(json_body)) = json_body { + let body = take_body(services, request, json_body, auth); + let http_request = into_http_request(request, body); + T::try_from_http_request(http_request, &request.path).map_err(|e| err!(Request(BadJson(debug_warn!("{e}"))))) +} + +fn into_http_request(request: &Request, body: Bytes) -> hyper::Request { + let mut http_request = hyper::Request::builder() + .uri(request.parts.uri.clone()) + .method(request.parts.method.clone()); + + *http_request.headers_mut().expect("mutable http headers") = request.parts.headers.clone(); + + let http_request = http_request.body(body).expect("http request body"); + + let headers = http_request.headers(); + let method = http_request.method(); + let uri = http_request.uri(); + debug!("{method:?} {uri:?} {headers:?}"); + + http_request +} + +fn take_body( + services: &Services, request: &mut Request, json_body: &mut Option, auth: &Auth, +) -> Bytes { + if let Some(CanonicalJsonValue::Object(json_body)) = json_body { let user_id = auth.sender_user.clone().unwrap_or_else(|| { let server_name = services.globals.server_name(); UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id") @@ -131,19 +156,5 @@ where buf.into_inner().freeze() } else { mem::take(&mut request.body) - }; - - let mut http_request = hyper::Request::builder() - .uri(request.parts.uri.clone()) - .method(request.parts.method.clone()); - *http_request.headers_mut().expect("mutable http headers") = request.parts.headers.clone(); - let http_request = http_request.body(body).expect("http request body"); - - let headers = http_request.headers(); - let method = http_request.method(); - let uri = http_request.uri(); - 
debug!("{method:?} {uri:?} {headers:?}"); - trace!("{method:?} {uri:?} {json_body:?}"); - - T::try_from_http_request(http_request, &request.path).map_err(|e| err!(Request(BadJson(debug_warn!("{e}"))))) + } } From 1ce3db727fdd298ba94dd472d017c6fe7e8a92c2 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 8 Nov 2024 07:30:52 +0000 Subject: [PATCH 189/245] split event_handler service Signed-off-by: Jason Volk --- src/service/rooms/event_handler/acl_check.rs | 35 + .../fetch_and_handle_outliers.rs | 181 +++ src/service/rooms/event_handler/fetch_prev.rs | 104 ++ .../rooms/event_handler/fetch_state.rs | 84 ++ .../event_handler/handle_incoming_pdu.rs | 172 +++ .../rooms/event_handler/handle_outlier_pdu.rs | 164 ++ .../rooms/event_handler/handle_prev_pdu.rs | 82 + src/service/rooms/event_handler/mod.rs | 1328 +---------------- .../rooms/event_handler/parse_incoming_pdu.rs | 41 +- .../rooms/event_handler/resolve_state.rs | 101 ++ .../rooms/event_handler/state_at_incoming.rs | 178 +++ .../event_handler/upgrade_outlier_pdu.rs | 298 ++++ 12 files changed, 1437 insertions(+), 1331 deletions(-) create mode 100644 src/service/rooms/event_handler/acl_check.rs create mode 100644 src/service/rooms/event_handler/fetch_and_handle_outliers.rs create mode 100644 src/service/rooms/event_handler/fetch_prev.rs create mode 100644 src/service/rooms/event_handler/fetch_state.rs create mode 100644 src/service/rooms/event_handler/handle_incoming_pdu.rs create mode 100644 src/service/rooms/event_handler/handle_outlier_pdu.rs create mode 100644 src/service/rooms/event_handler/handle_prev_pdu.rs create mode 100644 src/service/rooms/event_handler/resolve_state.rs create mode 100644 src/service/rooms/event_handler/state_at_incoming.rs create mode 100644 src/service/rooms/event_handler/upgrade_outlier_pdu.rs diff --git a/src/service/rooms/event_handler/acl_check.rs b/src/service/rooms/event_handler/acl_check.rs new file mode 100644 index 000000000..f2ff1b003 --- /dev/null +++ 
b/src/service/rooms/event_handler/acl_check.rs @@ -0,0 +1,35 @@ +use conduit::{debug, implement, trace, warn, Err, Result}; +use ruma::{ + events::{room::server_acl::RoomServerAclEventContent, StateEventType}, + RoomId, ServerName, +}; + +/// Returns Ok if the acl allows the server +#[implement(super::Service)] +#[tracing::instrument(skip_all)] +pub async fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result { + let Ok(acl_event_content) = self + .services + .state_accessor + .room_state_get_content(room_id, &StateEventType::RoomServerAcl, "") + .await + .map(|c: RoomServerAclEventContent| c) + .inspect(|acl| trace!("ACL content found: {acl:?}")) + .inspect_err(|e| trace!("No ACL content found: {e:?}")) + else { + return Ok(()); + }; + + if acl_event_content.allow.is_empty() { + warn!("Ignoring broken ACL event (allow key is empty)"); + return Ok(()); + } + + if acl_event_content.is_allowed(server_name) { + trace!("server {server_name} is allowed by ACL"); + Ok(()) + } else { + debug!("Server {server_name} was denied by room ACL in {room_id}"); + Err!(Request(Forbidden("Server was denied by room ACL"))) + } +} diff --git a/src/service/rooms/event_handler/fetch_and_handle_outliers.rs b/src/service/rooms/event_handler/fetch_and_handle_outliers.rs new file mode 100644 index 000000000..677b78f21 --- /dev/null +++ b/src/service/rooms/event_handler/fetch_and_handle_outliers.rs @@ -0,0 +1,181 @@ +use std::{ + collections::{hash_map, BTreeMap, HashSet}, + sync::Arc, + time::Instant, +}; + +use conduit::{ + debug, debug_error, implement, info, pdu, trace, utils::math::continue_exponential_backoff_secs, warn, PduEvent, +}; +use ruma::{api::federation::event::get_event, CanonicalJsonValue, EventId, RoomId, RoomVersionId, ServerName}; + +/// Find the event and auth it. Once the event is validated (steps 1 - 8) +/// it is appended to the outliers Tree. +/// +/// Returns pdu and if we fetched it over federation the raw json. +/// +/// a. 
Look in the main timeline (pduid_pdu tree) +/// b. Look at outlier pdu tree +/// c. Ask origin server over federation +/// d. TODO: Ask other servers over federation? +#[implement(super::Service)] +pub(super) async fn fetch_and_handle_outliers<'a>( + &self, origin: &'a ServerName, events: &'a [Arc], create_event: &'a PduEvent, room_id: &'a RoomId, + room_version_id: &'a RoomVersionId, +) -> Vec<(Arc, Option>)> { + let back_off = |id| match self + .services + .globals + .bad_event_ratelimiter + .write() + .expect("locked") + .entry(id) + { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)), + }; + + let mut events_with_auth_events = Vec::with_capacity(events.len()); + for id in events { + // a. Look in the main timeline (pduid_pdu tree) + // b. Look at outlier pdu tree + // (get_pdu_json checks both) + if let Ok(local_pdu) = self.services.timeline.get_pdu(id).await { + trace!("Found {id} in db"); + events_with_auth_events.push((id, Some(local_pdu), vec![])); + continue; + } + + // c. Ask origin server over federation + // We also handle its auth chain here so we don't get a stack overflow in + // handle_outlier_pdu. 
+ let mut todo_auth_events = vec![Arc::clone(id)]; + let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len()); + let mut events_all = HashSet::with_capacity(todo_auth_events.len()); + while let Some(next_id) = todo_auth_events.pop() { + if let Some((time, tries)) = self + .services + .globals + .bad_event_ratelimiter + .read() + .expect("locked") + .get(&*next_id) + { + // Exponential backoff + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { + info!("Backing off from {next_id}"); + continue; + } + } + + if events_all.contains(&next_id) { + continue; + } + + if self.services.timeline.pdu_exists(&next_id).await { + trace!("Found {next_id} in db"); + continue; + } + + debug!("Fetching {next_id} over federation."); + match self + .services + .sending + .send_federation_request( + origin, + get_event::v1::Request { + event_id: (*next_id).to_owned(), + include_unredacted_content: None, + }, + ) + .await + { + Ok(res) => { + debug!("Got {next_id} over federation"); + let Ok((calculated_event_id, value)) = pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) + else { + back_off((*next_id).to_owned()); + continue; + }; + + if calculated_event_id != *next_id { + warn!( + "Server didn't return event id we requested: requested: {next_id}, we got \ + {calculated_event_id}. 
Event: {:?}", + &res.pdu + ); + } + + if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array()) { + for auth_event in auth_events { + if let Ok(auth_event) = serde_json::from_value(auth_event.clone().into()) { + let a: Arc = auth_event; + todo_auth_events.push(a); + } else { + warn!("Auth event id is not valid"); + } + } + } else { + warn!("Auth event list invalid"); + } + + events_in_reverse_order.push((next_id.clone(), value)); + events_all.insert(next_id); + }, + Err(e) => { + debug_error!("Failed to fetch event {next_id}: {e}"); + back_off((*next_id).to_owned()); + }, + } + } + events_with_auth_events.push((id, None, events_in_reverse_order)); + } + + let mut pdus = Vec::with_capacity(events_with_auth_events.len()); + for (id, local_pdu, events_in_reverse_order) in events_with_auth_events { + // a. Look in the main timeline (pduid_pdu tree) + // b. Look at outlier pdu tree + // (get_pdu_json checks both) + if let Some(local_pdu) = local_pdu { + trace!("Found {id} in db"); + pdus.push((local_pdu.clone(), None)); + } + + for (next_id, value) in events_in_reverse_order.into_iter().rev() { + if let Some((time, tries)) = self + .services + .globals + .bad_event_ratelimiter + .read() + .expect("locked") + .get(&*next_id) + { + // Exponential backoff + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { + debug!("Backing off from {next_id}"); + continue; + } + } + + match Box::pin(self.handle_outlier_pdu(origin, create_event, &next_id, room_id, value.clone(), true)).await + { + Ok((pdu, json)) => { + if next_id == *id { + pdus.push((pdu, Some(json))); + } + }, + Err(e) => { + warn!("Authentication of event {next_id} failed: {e:?}"); + back_off(next_id.into()); + }, + } + } + } + pdus +} diff --git a/src/service/rooms/event_handler/fetch_prev.rs b/src/service/rooms/event_handler/fetch_prev.rs new file mode 100644 index 
000000000..4acdba1dc --- /dev/null +++ b/src/service/rooms/event_handler/fetch_prev.rs @@ -0,0 +1,104 @@ +use std::{ + collections::{BTreeMap, HashMap, HashSet}, + sync::Arc, +}; + +use conduit::{debug_warn, err, implement, PduEvent, Result}; +use futures::{future, FutureExt}; +use ruma::{ + int, + state_res::{self}, + uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId, ServerName, +}; + +use super::check_room_id; + +#[implement(super::Service)] +#[allow(clippy::type_complexity)] +#[tracing::instrument(skip_all)] +pub(super) async fn fetch_prev( + &self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId, + initial_set: Vec>, +) -> Result<( + Vec>, + HashMap, (Arc, BTreeMap)>, +)> { + let mut graph: HashMap, _> = HashMap::with_capacity(initial_set.len()); + let mut eventid_info = HashMap::new(); + let mut todo_outlier_stack: Vec> = initial_set; + + let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?; + + let mut amount = 0; + + while let Some(prev_event_id) = todo_outlier_stack.pop() { + self.services.server.check_running()?; + + if let Some((pdu, mut json_opt)) = self + .fetch_and_handle_outliers(origin, &[prev_event_id.clone()], create_event, room_id, room_version_id) + .boxed() + .await + .pop() + { + check_room_id(room_id, &pdu)?; + + let limit = self.services.server.config.max_fetch_prev_events; + if amount > limit { + debug_warn!("Max prev event limit reached! 
Limit: {limit}"); + graph.insert(prev_event_id.clone(), HashSet::new()); + continue; + } + + if json_opt.is_none() { + json_opt = self + .services + .outlier + .get_outlier_pdu_json(&prev_event_id) + .await + .ok(); + } + + if let Some(json) = json_opt { + if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { + amount = amount.saturating_add(1); + for prev_prev in &pdu.prev_events { + if !graph.contains_key(prev_prev) { + todo_outlier_stack.push(prev_prev.clone()); + } + } + + graph.insert(prev_event_id.clone(), pdu.prev_events.iter().cloned().collect()); + } else { + // Time based check failed + graph.insert(prev_event_id.clone(), HashSet::new()); + } + + eventid_info.insert(prev_event_id.clone(), (pdu, json)); + } else { + // Get json failed, so this was not fetched over federation + graph.insert(prev_event_id.clone(), HashSet::new()); + } + } else { + // Fetch and handle failed + graph.insert(prev_event_id.clone(), HashSet::new()); + } + } + + let event_fetch = |event_id| { + let origin_server_ts = eventid_info + .get(&event_id) + .cloned() + .map_or_else(|| uint!(0), |info| info.0.origin_server_ts); + + // This return value is the key used for sorting events, + // events are then sorted by power level, time, + // and lexically by event_id. 
+ future::ok((int!(0), MilliSecondsSinceUnixEpoch(origin_server_ts))) + }; + + let sorted = state_res::lexicographical_topological_sort(&graph, &event_fetch) + .await + .map_err(|e| err!(Database(error!("Error sorting prev events: {e}"))))?; + + Ok((sorted, eventid_info)) +} diff --git a/src/service/rooms/event_handler/fetch_state.rs b/src/service/rooms/event_handler/fetch_state.rs new file mode 100644 index 000000000..74b0bb32a --- /dev/null +++ b/src/service/rooms/event_handler/fetch_state.rs @@ -0,0 +1,84 @@ +use std::{ + collections::{hash_map, HashMap}, + sync::Arc, +}; + +use conduit::{debug, implement, warn, Err, Error, PduEvent, Result}; +use futures::FutureExt; +use ruma::{ + api::federation::event::get_room_state_ids, events::StateEventType, EventId, RoomId, RoomVersionId, ServerName, +}; + +/// Call /state_ids to find out what the state at this pdu is. We trust the +/// server's response to some extend (sic), but we still do a lot of checks +/// on the events +#[implement(super::Service)] +#[tracing::instrument(skip(self, create_event, room_version_id))] +pub(super) async fn fetch_state( + &self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId, + event_id: &EventId, +) -> Result>>> { + debug!("Fetching state ids"); + let res = self + .services + .sending + .send_synapse_request( + origin, + get_room_state_ids::v1::Request { + room_id: room_id.to_owned(), + event_id: (*event_id).to_owned(), + }, + ) + .await + .inspect_err(|e| warn!("Fetching state for event failed: {e}"))?; + + debug!("Fetching state events"); + let collect = res + .pdu_ids + .iter() + .map(|x| Arc::from(&**x)) + .collect::>(); + + let state_vec = self + .fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id) + .boxed() + .await; + + let mut state: HashMap<_, Arc> = HashMap::with_capacity(state_vec.len()); + for (pdu, _) in state_vec { + let state_key = pdu + .state_key + .clone() + .ok_or_else(|| 
Error::bad_database("Found non-state pdu in state events."))?; + + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key) + .await; + + match state.entry(shortstatekey) { + hash_map::Entry::Vacant(v) => { + v.insert(Arc::from(&*pdu.event_id)); + }, + hash_map::Entry::Occupied(_) => { + return Err(Error::bad_database( + "State event's type and state_key combination exists multiple times.", + )) + }, + } + } + + // The original create event must still be in the state + let create_shortstatekey = self + .services + .short + .get_shortstatekey(&StateEventType::RoomCreate, "") + .await?; + + if state.get(&create_shortstatekey).map(AsRef::as_ref) != Some(&create_event.event_id) { + return Err!(Database("Incoming event refers to wrong create event.")); + } + + Ok(Some(state)) +} diff --git a/src/service/rooms/event_handler/handle_incoming_pdu.rs b/src/service/rooms/event_handler/handle_incoming_pdu.rs new file mode 100644 index 000000000..4d2d75d5f --- /dev/null +++ b/src/service/rooms/event_handler/handle_incoming_pdu.rs @@ -0,0 +1,172 @@ +use std::{ + collections::{hash_map, BTreeMap}, + time::Instant, +}; + +use conduit::{debug, err, implement, warn, Error, Result}; +use futures::FutureExt; +use ruma::{ + api::client::error::ErrorKind, events::StateEventType, CanonicalJsonValue, EventId, RoomId, ServerName, UserId, +}; + +use super::{check_room_id, get_room_version_id}; +use crate::rooms::timeline::RawPduId; + +/// When receiving an event one needs to: +/// 0. Check the server is in the room +/// 1. Skip the PDU if we already know about it +/// 1.1. Remove unsigned field +/// 2. Check signatures, otherwise drop +/// 3. Check content hash, redact if doesn't match +/// 4. Fetch any missing auth events doing all checks listed here starting at 1. +/// These are not timeline events +/// 5. 
Reject "due to auth events" if can't get all the auth events or some of +/// the auth events are also rejected "due to auth events" +/// 6. Reject "due to auth events" if the event doesn't pass auth based on the +/// auth events +/// 7. Persist this event as an outlier +/// 8. If not timeline event: stop +/// 9. Fetch any missing prev events doing all checks listed here starting at 1. +/// These are timeline events +/// 10. Fetch missing state and auth chain events by calling `/state_ids` at +/// backwards extremities doing all the checks in this list starting at +/// 1. These are not timeline events +/// 11. Check the auth of the event passes based on the state of the event +/// 12. Ensure that the state is derived from the previous current state (i.e. +/// we calculated by doing state res where one of the inputs was a +/// previously trusted set of state, don't just trust a set of state we got +/// from a remote) +/// 13. Use state resolution to find new room state +/// 14. Check if the event passes auth based on the "current state" of the room, +/// if not soft fail it +#[implement(super::Service)] +#[tracing::instrument(skip(self, origin, value, is_timeline_event), name = "pdu")] +pub async fn handle_incoming_pdu<'a>( + &self, origin: &'a ServerName, room_id: &'a RoomId, event_id: &'a EventId, + value: BTreeMap, is_timeline_event: bool, +) -> Result> { + // 1. 
Skip the PDU if we already have it as a timeline event + if let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await { + return Ok(Some(pdu_id)); + } + + // 1.1 Check the server is in the room + if !self.services.metadata.exists(room_id).await { + return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); + } + + // 1.2 Check if the room is disabled + if self.services.metadata.is_disabled(room_id).await { + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Federation of this room is currently disabled on this server.", + )); + } + + // 1.3.1 Check room ACL on origin field/server + self.acl_check(origin, room_id).await?; + + // 1.3.2 Check room ACL on sender's server name + let sender: &UserId = value + .get("sender") + .try_into() + .map_err(|e| err!(Request(InvalidParam("PDU does not have a valid sender key: {e}"))))?; + + self.acl_check(sender.server_name(), room_id).await?; + + // Fetch create event + let create_event = self + .services + .state_accessor + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await?; + + // Procure the room version + let room_version_id = get_room_version_id(&create_event)?; + + let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?; + + let (incoming_pdu, val) = self + .handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false) + .boxed() + .await?; + + check_room_id(room_id, &incoming_pdu)?; + + // 8. if not timeline event: stop + if !is_timeline_event { + return Ok(None); + } + // Skip old events + if incoming_pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { + return Ok(None); + } + + // 9. Fetch any missing prev events doing all checks listed here starting at 1. 
+ // These are timeline events + let (sorted_prev_events, mut eventid_info) = self + .fetch_prev( + origin, + &create_event, + room_id, + &room_version_id, + incoming_pdu.prev_events.clone(), + ) + .await?; + + debug!(events = ?sorted_prev_events, "Got previous events"); + for prev_id in sorted_prev_events { + self.services.server.check_running()?; + if let Err(e) = self + .handle_prev_pdu( + origin, + event_id, + room_id, + &mut eventid_info, + &create_event, + &first_pdu_in_room, + &prev_id, + ) + .await + { + use hash_map::Entry; + + let now = Instant::now(); + warn!("Prev event {prev_id} failed: {e}"); + + match self + .services + .globals + .bad_event_ratelimiter + .write() + .expect("locked") + .entry(prev_id.into()) + { + Entry::Vacant(e) => { + e.insert((now, 1)); + }, + Entry::Occupied(mut e) => { + *e.get_mut() = (now, e.get().1.saturating_add(1)); + }, + }; + } + } + + // Done with prev events, now handling the incoming event + let start_time = Instant::now(); + self.federation_handletime + .write() + .expect("locked") + .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); + + let r = self + .upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id) + .await; + + self.federation_handletime + .write() + .expect("locked") + .remove(&room_id.to_owned()); + + r +} diff --git a/src/service/rooms/event_handler/handle_outlier_pdu.rs b/src/service/rooms/event_handler/handle_outlier_pdu.rs new file mode 100644 index 000000000..2d95ff637 --- /dev/null +++ b/src/service/rooms/event_handler/handle_outlier_pdu.rs @@ -0,0 +1,164 @@ +use std::{ + collections::{hash_map, BTreeMap, HashMap}, + sync::Arc, +}; + +use conduit::{debug, debug_info, err, implement, trace, warn, Err, Error, PduEvent, Result}; +use futures::future::ready; +use ruma::{ + api::client::error::ErrorKind, + events::StateEventType, + state_res::{self, EventTypeExt}, + CanonicalJsonObject, CanonicalJsonValue, EventId, RoomId, ServerName, +}; + +use 
super::{check_room_id, get_room_version_id, to_room_version}; + +#[implement(super::Service)] +#[allow(clippy::too_many_arguments)] +pub(super) async fn handle_outlier_pdu<'a>( + &self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId, + mut value: CanonicalJsonObject, auth_events_known: bool, +) -> Result<(Arc, BTreeMap)> { + // 1. Remove unsigned field + value.remove("unsigned"); + + // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json + + // 2. Check signatures, otherwise drop + // 3. check content hash, redact if doesn't match + let room_version_id = get_room_version_id(create_event)?; + let mut val = match self + .services + .server_keys + .verify_event(&value, Some(&room_version_id)) + .await + { + Ok(ruma::signatures::Verified::All) => value, + Ok(ruma::signatures::Verified::Signatures) => { + // Redact + debug_info!("Calculated hash does not match (redaction): {event_id}"); + let Ok(obj) = ruma::canonical_json::redact(value, &room_version_id, None) else { + return Err!(Request(InvalidParam("Redaction failed"))); + }; + + // Skip the PDU if it is redacted and we already have it as an outlier event + if self.services.timeline.pdu_exists(event_id).await { + return Err!(Request(InvalidParam("Event was redacted and we already knew about it"))); + } + + obj + }, + Err(e) => { + return Err!(Request(InvalidParam(debug_error!( + "Signature verification failed for {event_id}: {e}" + )))) + }, + }; + + // Now that we have checked the signature and hashes we can add the eventID and + // convert to our PduEvent type + val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); + let incoming_pdu = + serde_json::from_value::(serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue")) + .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; + + check_room_id(room_id, &incoming_pdu)?; + + 
if !auth_events_known { + // 4. fetch any missing auth events doing all checks listed here starting at 1. + // These are not timeline events + // 5. Reject "due to auth events" if can't get all the auth events or some of + // the auth events are also rejected "due to auth events" + // NOTE: Step 5 is not applied anymore because it failed too often + debug!("Fetching auth events"); + Box::pin( + self.fetch_and_handle_outliers( + origin, + &incoming_pdu + .auth_events + .iter() + .map(|x| Arc::from(&**x)) + .collect::>>(), + create_event, + room_id, + &room_version_id, + ), + ) + .await; + } + + // 6. Reject "due to auth events" if the event doesn't pass auth based on the + // auth events + debug!("Checking based on auth events"); + // Build map of auth events + let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len()); + for id in &incoming_pdu.auth_events { + let Ok(auth_event) = self.services.timeline.get_pdu(id).await else { + warn!("Could not find auth event {id}"); + continue; + }; + + check_room_id(room_id, &auth_event)?; + + match auth_events.entry(( + auth_event.kind.to_string().into(), + auth_event + .state_key + .clone() + .expect("all auth events have state keys"), + )) { + hash_map::Entry::Vacant(v) => { + v.insert(auth_event); + }, + hash_map::Entry::Occupied(_) => { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Auth event's type and state_key combination exists multiple times.", + )); + }, + } + } + + // The original create event must be in the auth events + if !matches!( + auth_events + .get(&(StateEventType::RoomCreate, String::new())) + .map(AsRef::as_ref), + Some(_) | None + ) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Incoming event refers to wrong create event.", + )); + } + + let state_fetch = |ty: &'static StateEventType, sk: &str| { + let key = ty.with_state_key(sk); + ready(auth_events.get(&key)) + }; + + let auth_check = state_res::event_auth::auth_check( + 
&to_room_version(&room_version_id), + &incoming_pdu, + None, // TODO: third party invite + state_fetch, + ) + .await + .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; + + if !auth_check { + return Err!(Request(Forbidden("Auth check failed"))); + } + + trace!("Validation successful."); + + // 7. Persist the event as an outlier. + self.services + .outlier + .add_pdu_outlier(&incoming_pdu.event_id, &val); + + trace!("Added pdu as outlier."); + + Ok((Arc::new(incoming_pdu), val)) +} diff --git a/src/service/rooms/event_handler/handle_prev_pdu.rs b/src/service/rooms/event_handler/handle_prev_pdu.rs new file mode 100644 index 000000000..90ff7f06b --- /dev/null +++ b/src/service/rooms/event_handler/handle_prev_pdu.rs @@ -0,0 +1,82 @@ +use std::{ + collections::{BTreeMap, HashMap}, + sync::Arc, + time::Instant, +}; + +use conduit::{debug, implement, utils::math::continue_exponential_backoff_secs, Error, PduEvent, Result}; +use ruma::{api::client::error::ErrorKind, CanonicalJsonValue, EventId, RoomId, ServerName}; + +#[implement(super::Service)] +#[allow(clippy::type_complexity)] +#[allow(clippy::too_many_arguments)] +#[tracing::instrument( + skip(self, origin, event_id, room_id, eventid_info, create_event, first_pdu_in_room), + name = "prev" +)] +pub(super) async fn handle_prev_pdu<'a>( + &self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId, + eventid_info: &mut HashMap, (Arc, BTreeMap)>, + create_event: &Arc, first_pdu_in_room: &Arc, prev_id: &EventId, +) -> Result { + // Check for disabled again because it might have changed + if self.services.metadata.is_disabled(room_id).await { + debug!( + "Federaton of room {room_id} is currently disabled on this server. 
Request by origin {origin} and event \ + ID {event_id}" + ); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Federation of this room is currently disabled on this server.", + )); + } + + if let Some((time, tries)) = self + .services + .globals + .bad_event_ratelimiter + .read() + .expect("locked") + .get(prev_id) + { + // Exponential backoff + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { + debug!( + ?tries, + duration = ?time.elapsed(), + "Backing off from prev_event" + ); + return Ok(()); + } + } + + if let Some((pdu, json)) = eventid_info.remove(prev_id) { + // Skip old events + if pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { + return Ok(()); + } + + let start_time = Instant::now(); + self.federation_handletime + .write() + .expect("locked") + .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); + + self.upgrade_outlier_to_timeline_pdu(pdu, json, create_event, origin, room_id) + .await?; + + self.federation_handletime + .write() + .expect("locked") + .remove(&room_id.to_owned()); + + debug!( + elapsed = ?start_time.elapsed(), + "Handled prev_event", + ); + } + + Ok(()) +} diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index f76f817d3..f6440fe93 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1,51 +1,34 @@ +mod acl_check; +mod fetch_and_handle_outliers; +mod fetch_prev; +mod fetch_state; +mod handle_incoming_pdu; +mod handle_outlier_pdu; +mod handle_prev_pdu; mod parse_incoming_pdu; +mod resolve_state; +mod state_at_incoming; +mod upgrade_outlier_pdu; use std::{ - borrow::Borrow, - collections::{hash_map, BTreeMap, HashMap, HashSet}, + collections::HashMap, fmt::Write, sync::{Arc, RwLock as StdRwLock}, time::Instant, }; -use conduit::{ - debug, debug_error, debug_info, debug_warn, err, info, pdu, - 
result::LogErr, - trace, - utils::{math::continue_exponential_backoff_secs, IterStream, MutexMap}, - warn, Err, Error, PduEvent, Result, Server, -}; -use futures::{future, future::ready, FutureExt, StreamExt, TryFutureExt}; +use conduit::{utils::MutexMap, Err, PduEvent, Result, Server}; use ruma::{ - api::{ - client::error::ErrorKind, - federation::event::{get_event, get_room_state_ids}, - }, - events::{ - room::{ - create::RoomCreateEventContent, redaction::RoomRedactionEventContent, server_acl::RoomServerAclEventContent, - }, - StateEventType, TimelineEventType, - }, - int, - state_res::{self, EventTypeExt, RoomVersion, StateMap}, - uint, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, - RoomId, RoomVersionId, ServerName, UserId, + events::room::create::RoomCreateEventContent, state_res::RoomVersion, EventId, OwnedEventId, OwnedRoomId, RoomId, + RoomVersionId, }; -use crate::{ - globals, rooms, - rooms::{ - state_compressor::{CompressedStateEvent, HashSetCompressStateEvent}, - timeline::RawPduId, - }, - sending, server_keys, Dep, -}; +use crate::{globals, rooms, sending, server_keys, Dep}; pub struct Service { - services: Services, - pub federation_handletime: StdRwLock, pub mutex_federation: RoomMutexMap, + pub federation_handletime: StdRwLock, + services: Services, } struct Services { @@ -70,6 +53,8 @@ type HandleTimeMap = HashMap; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + mutex_federation: RoomMutexMap::new(), + federation_handletime: HandleTimeMap::new().into(), services: Services { globals: args.depend::("globals"), sending: args.depend::("sending"), @@ -85,8 +70,6 @@ impl crate::Service for Service { timeline: args.depend::("rooms::timeline"), server: args.server.clone(), }, - federation_handletime: HandleTimeMap::new().into(), - mutex_federation: RoomMutexMap::new(), })) } @@ -108,1279 +91,6 @@ impl crate::Service for Service { } impl Service 
{ - /// When receiving an event one needs to: - /// 0. Check the server is in the room - /// 1. Skip the PDU if we already know about it - /// 1.1. Remove unsigned field - /// 2. Check signatures, otherwise drop - /// 3. Check content hash, redact if doesn't match - /// 4. Fetch any missing auth events doing all checks listed here starting - /// at 1. These are not timeline events - /// 5. Reject "due to auth events" if can't get all the auth events or some - /// of the auth events are also rejected "due to auth events" - /// 6. Reject "due to auth events" if the event doesn't pass auth based on - /// the auth events - /// 7. Persist this event as an outlier - /// 8. If not timeline event: stop - /// 9. Fetch any missing prev events doing all checks listed here starting - /// at 1. These are timeline events - /// 10. Fetch missing state and auth chain events by calling `/state_ids` at - /// backwards extremities doing all the checks in this list starting at - /// 1. These are not timeline events - /// 11. Check the auth of the event passes based on the state of the event - /// 12. Ensure that the state is derived from the previous current state - /// (i.e. we calculated by doing state res where one of the inputs was a - /// previously trusted set of state, don't just trust a set of state we - /// got from a remote) - /// 13. Use state resolution to find new room state - /// 14. Check if the event passes auth based on the "current state" of the - /// room, if not soft fail it - #[tracing::instrument(skip(self, origin, value, is_timeline_event), name = "pdu")] - pub async fn handle_incoming_pdu<'a>( - &self, origin: &'a ServerName, room_id: &'a RoomId, event_id: &'a EventId, - value: BTreeMap, is_timeline_event: bool, - ) -> Result> { - // 1. 
Skip the PDU if we already have it as a timeline event - if let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await { - return Ok(Some(pdu_id)); - } - - // 1.1 Check the server is in the room - if !self.services.metadata.exists(room_id).await { - return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); - } - - // 1.2 Check if the room is disabled - if self.services.metadata.is_disabled(room_id).await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Federation of this room is currently disabled on this server.", - )); - } - - // 1.3.1 Check room ACL on origin field/server - self.acl_check(origin, room_id).await?; - - // 1.3.2 Check room ACL on sender's server name - let sender: &UserId = value - .get("sender") - .try_into() - .map_err(|e| err!(Request(InvalidParam("PDU does not have a valid sender key: {e}"))))?; - - self.acl_check(sender.server_name(), room_id).await?; - - // Fetch create event - let create_event = self - .services - .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "") - .await?; - - // Procure the room version - let room_version_id = get_room_version_id(&create_event)?; - - let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?; - - let (incoming_pdu, val) = self - .handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false) - .boxed() - .await?; - - check_room_id(room_id, &incoming_pdu)?; - - // 8. if not timeline event: stop - if !is_timeline_event { - return Ok(None); - } - // Skip old events - if incoming_pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { - return Ok(None); - } - - // 9. Fetch any missing prev events doing all checks listed here starting at 1. 
- // These are timeline events - let (sorted_prev_events, mut eventid_info) = self - .fetch_prev( - origin, - &create_event, - room_id, - &room_version_id, - incoming_pdu.prev_events.clone(), - ) - .await?; - - debug!(events = ?sorted_prev_events, "Got previous events"); - for prev_id in sorted_prev_events { - self.services.server.check_running()?; - if let Err(e) = self - .handle_prev_pdu( - origin, - event_id, - room_id, - &mut eventid_info, - &create_event, - &first_pdu_in_room, - &prev_id, - ) - .await - { - use hash_map::Entry; - - let now = Instant::now(); - warn!("Prev event {prev_id} failed: {e}"); - - match self - .services - .globals - .bad_event_ratelimiter - .write() - .expect("locked") - .entry(prev_id.into()) - { - Entry::Vacant(e) => { - e.insert((now, 1)); - }, - Entry::Occupied(mut e) => { - *e.get_mut() = (now, e.get().1.saturating_add(1)); - }, - }; - } - } - - // Done with prev events, now handling the incoming event - let start_time = Instant::now(); - self.federation_handletime - .write() - .expect("locked") - .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); - - let r = self - .upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id) - .await; - - self.federation_handletime - .write() - .expect("locked") - .remove(&room_id.to_owned()); - - r - } - - #[allow(clippy::type_complexity)] - #[allow(clippy::too_many_arguments)] - #[tracing::instrument( - skip(self, origin, event_id, room_id, eventid_info, create_event, first_pdu_in_room), - name = "prev" - )] - pub async fn handle_prev_pdu<'a>( - &self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId, - eventid_info: &mut HashMap, (Arc, BTreeMap)>, - create_event: &Arc, first_pdu_in_room: &Arc, prev_id: &EventId, - ) -> Result<()> { - // Check for disabled again because it might have changed - if self.services.metadata.is_disabled(room_id).await { - debug!( - "Federaton of room {room_id} is currently disabled on this server. 
Request by origin {origin} and \ - event ID {event_id}" - ); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Federation of this room is currently disabled on this server.", - )); - } - - if let Some((time, tries)) = self - .services - .globals - .bad_event_ratelimiter - .read() - .expect("locked") - .get(prev_id) - { - // Exponential backoff - const MIN_DURATION: u64 = 5 * 60; - const MAX_DURATION: u64 = 60 * 60 * 24; - if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { - debug!( - ?tries, - duration = ?time.elapsed(), - "Backing off from prev_event" - ); - return Ok(()); - } - } - - if let Some((pdu, json)) = eventid_info.remove(prev_id) { - // Skip old events - if pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { - return Ok(()); - } - - let start_time = Instant::now(); - self.federation_handletime - .write() - .expect("locked") - .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); - - self.upgrade_outlier_to_timeline_pdu(pdu, json, create_event, origin, room_id) - .await?; - - self.federation_handletime - .write() - .expect("locked") - .remove(&room_id.to_owned()); - - debug!( - elapsed = ?start_time.elapsed(), - "Handled prev_event", - ); - } - - Ok(()) - } - - #[allow(clippy::too_many_arguments)] - async fn handle_outlier_pdu<'a>( - &self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId, - mut value: CanonicalJsonObject, auth_events_known: bool, - ) -> Result<(Arc, BTreeMap)> { - // 1. Remove unsigned field - value.remove("unsigned"); - - // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json - - // 2. Check signatures, otherwise drop - // 3. 
check content hash, redact if doesn't match - let room_version_id = get_room_version_id(create_event)?; - let mut val = match self - .services - .server_keys - .verify_event(&value, Some(&room_version_id)) - .await - { - Ok(ruma::signatures::Verified::All) => value, - Ok(ruma::signatures::Verified::Signatures) => { - // Redact - debug_info!("Calculated hash does not match (redaction): {event_id}"); - let Ok(obj) = ruma::canonical_json::redact(value, &room_version_id, None) else { - return Err!(Request(InvalidParam("Redaction failed"))); - }; - - // Skip the PDU if it is redacted and we already have it as an outlier event - if self.services.timeline.pdu_exists(event_id).await { - return Err!(Request(InvalidParam("Event was redacted and we already knew about it"))); - } - - obj - }, - Err(e) => { - return Err!(Request(InvalidParam(debug_error!( - "Signature verification failed for {event_id}: {e}" - )))) - }, - }; - - // Now that we have checked the signature and hashes we can add the eventID and - // convert to our PduEvent type - val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); - let incoming_pdu = serde_json::from_value::( - serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), - ) - .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; - - check_room_id(room_id, &incoming_pdu)?; - - if !auth_events_known { - // 4. fetch any missing auth events doing all checks listed here starting at 1. - // These are not timeline events - // 5. Reject "due to auth events" if can't get all the auth events or some of - // the auth events are also rejected "due to auth events" - // NOTE: Step 5 is not applied anymore because it failed too often - debug!("Fetching auth events"); - Box::pin( - self.fetch_and_handle_outliers( - origin, - &incoming_pdu - .auth_events - .iter() - .map(|x| Arc::from(&**x)) - .collect::>>(), - create_event, - room_id, - &room_version_id, - ), - ) - .await; - } - - // 6. 
Reject "due to auth events" if the event doesn't pass auth based on the - // auth events - debug!("Checking based on auth events"); - // Build map of auth events - let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len()); - for id in &incoming_pdu.auth_events { - let Ok(auth_event) = self.services.timeline.get_pdu(id).await else { - warn!("Could not find auth event {id}"); - continue; - }; - - check_room_id(room_id, &auth_event)?; - - match auth_events.entry(( - auth_event.kind.to_string().into(), - auth_event - .state_key - .clone() - .expect("all auth events have state keys"), - )) { - hash_map::Entry::Vacant(v) => { - v.insert(auth_event); - }, - hash_map::Entry::Occupied(_) => { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Auth event's type and state_key combination exists multiple times.", - )); - }, - } - } - - // The original create event must be in the auth events - if !matches!( - auth_events - .get(&(StateEventType::RoomCreate, String::new())) - .map(AsRef::as_ref), - Some(_) | None - ) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Incoming event refers to wrong create event.", - )); - } - - let state_fetch = |ty: &'static StateEventType, sk: &str| { - let key = ty.with_state_key(sk); - ready(auth_events.get(&key)) - }; - - let auth_check = state_res::event_auth::auth_check( - &to_room_version(&room_version_id), - &incoming_pdu, - None, // TODO: third party invite - state_fetch, - ) - .await - .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; - - if !auth_check { - return Err!(Request(Forbidden("Auth check failed"))); - } - - trace!("Validation successful."); - - // 7. Persist the event as an outlier. 
- self.services - .outlier - .add_pdu_outlier(&incoming_pdu.event_id, &val); - - trace!("Added pdu as outlier."); - - Ok((Arc::new(incoming_pdu), val)) - } - - pub async fn upgrade_outlier_to_timeline_pdu( - &self, incoming_pdu: Arc, val: BTreeMap, create_event: &PduEvent, - origin: &ServerName, room_id: &RoomId, - ) -> Result> { - // Skip the PDU if we already have it as a timeline event - if let Ok(pduid) = self - .services - .timeline - .get_pdu_id(&incoming_pdu.event_id) - .await - { - return Ok(Some(pduid)); - } - - if self - .services - .pdu_metadata - .is_event_soft_failed(&incoming_pdu.event_id) - .await - { - return Err!(Request(InvalidParam("Event has been soft failed"))); - } - - debug!("Upgrading to timeline pdu"); - let timer = Instant::now(); - let room_version_id = get_room_version_id(create_event)?; - - // 10. Fetch missing state and auth chain events by calling /state_ids at - // backwards extremities doing all the checks in this list starting at 1. - // These are not timeline events. - - debug!("Resolving state at event"); - let mut state_at_incoming_event = if incoming_pdu.prev_events.len() == 1 { - self.state_at_incoming_degree_one(&incoming_pdu).await? - } else { - self.state_at_incoming_resolved(&incoming_pdu, room_id, &room_version_id) - .await? - }; - - if state_at_incoming_event.is_none() { - state_at_incoming_event = self - .fetch_state(origin, create_event, room_id, &room_version_id, &incoming_pdu.event_id) - .await?; - } - - let state_at_incoming_event = state_at_incoming_event.expect("we always set this to some above"); - let room_version = to_room_version(&room_version_id); - - debug!("Performing auth check"); - // 11. 
Check the auth of the event passes based on the state of the event - let state_fetch_state = &state_at_incoming_event; - let state_fetch = |k: &'static StateEventType, s: String| async move { - let shortstatekey = self.services.short.get_shortstatekey(k, &s).await.ok()?; - - let event_id = state_fetch_state.get(&shortstatekey)?; - self.services.timeline.get_pdu(event_id).await.ok() - }; - - let auth_check = state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - None, // TODO: third party invite - |k, s| state_fetch(k, s.to_owned()), - ) - .await - .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; - - if !auth_check { - return Err!(Request(Forbidden("Event has failed auth check with state at the event."))); - } - - debug!("Gathering auth events"); - let auth_events = self - .services - .state - .get_auth_events( - room_id, - &incoming_pdu.kind, - &incoming_pdu.sender, - incoming_pdu.state_key.as_deref(), - &incoming_pdu.content, - ) - .await?; - - let state_fetch = |k: &'static StateEventType, s: &str| { - let key = k.with_state_key(s); - ready(auth_events.get(&key).cloned()) - }; - - let auth_check = state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - None, // third-party invite - state_fetch, - ) - .await - .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; - - // Soft fail check before doing state res - debug!("Performing soft-fail check"); - let soft_fail = { - use RoomVersionId::*; - - !auth_check - || incoming_pdu.kind == TimelineEventType::RoomRedaction - && match room_version_id { - V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { - if let Some(redact_id) = &incoming_pdu.redacts { - !self - .services - .state_accessor - .user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true) - .await? 
- } else { - false - } - }, - _ => { - let content: RoomRedactionEventContent = incoming_pdu.get_content()?; - if let Some(redact_id) = &content.redacts { - !self - .services - .state_accessor - .user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true) - .await? - } else { - false - } - }, - } - }; - - // 13. Use state resolution to find new room state - - // We start looking at current room state now, so lets lock the room - trace!("Locking the room"); - let state_lock = self.services.state.mutex.lock(room_id).await; - - // Now we calculate the set of extremities this room has after the incoming - // event has been applied. We start with the previous extremities (aka leaves) - trace!("Calculating extremities"); - let mut extremities: HashSet<_> = self - .services - .state - .get_forward_extremities(room_id) - .map(ToOwned::to_owned) - .collect() - .await; - - // Remove any forward extremities that are referenced by this incoming event's - // prev_events - trace!( - "Calculated {} extremities; checking against {} prev_events", - extremities.len(), - incoming_pdu.prev_events.len() - ); - for prev_event in &incoming_pdu.prev_events { - extremities.remove(&(**prev_event)); - } - - // Only keep those extremities were not referenced yet - let mut retained = HashSet::new(); - for id in &extremities { - if !self - .services - .pdu_metadata - .is_event_referenced(room_id, id) - .await - { - retained.insert(id.clone()); - } - } - - extremities.retain(|id| retained.contains(id)); - debug!("Retained {} extremities. Compressing state", extremities.len()); - - let mut state_ids_compressed = HashSet::new(); - for (shortstatekey, id) in &state_at_incoming_event { - state_ids_compressed.insert( - self.services - .state_compressor - .compress_state_event(*shortstatekey, id) - .await, - ); - } - - let state_ids_compressed = Arc::new(state_ids_compressed); - - if incoming_pdu.state_key.is_some() { - debug!("Event is a state-event. 
Deriving new room state"); - - // We also add state after incoming event to the fork states - let mut state_after = state_at_incoming_event.clone(); - if let Some(state_key) = &incoming_pdu.state_key { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key) - .await; - - let event_id = &incoming_pdu.event_id; - state_after.insert(shortstatekey, event_id.clone()); - } - - let new_room_state = self - .resolve_state(room_id, &room_version_id, state_after) - .await?; - - // Set the new room state to the resolved state - debug!("Forcing new room state"); - let HashSetCompressStateEvent { - shortstatehash, - added, - removed, - } = self - .services - .state_compressor - .save_state(room_id, new_room_state) - .await?; - - self.services - .state - .force_state(room_id, shortstatehash, added, removed, &state_lock) - .await?; - } - - // 14. Check if the event passes auth based on the "current state" of the room, - // if not soft fail it - if soft_fail { - debug!("Soft failing event"); - self.services - .timeline - .append_incoming_pdu( - &incoming_pdu, - val, - extremities.iter().map(|e| (**e).to_owned()).collect(), - state_ids_compressed, - soft_fail, - &state_lock, - ) - .await?; - - // Soft fail, we keep the event as an outlier but don't add it to the timeline - warn!("Event was soft failed: {incoming_pdu:?}"); - self.services - .pdu_metadata - .mark_event_soft_failed(&incoming_pdu.event_id); - - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); - } - - trace!("Appending pdu to timeline"); - extremities.insert(incoming_pdu.event_id.clone().into()); - - // Now that the event has passed all auth it is added into the timeline. - // We use the `state_at_event` instead of `state_after` so we accurately - // represent the state for this event. 
- let pdu_id = self - .services - .timeline - .append_incoming_pdu( - &incoming_pdu, - val, - extremities.into_iter().collect(), - state_ids_compressed, - soft_fail, - &state_lock, - ) - .await?; - - // Event has passed all auth/stateres checks - drop(state_lock); - debug_info!( - elapsed = ?timer.elapsed(), - "Accepted", - ); - - Ok(pdu_id) - } - - #[tracing::instrument(skip_all, name = "resolve")] - pub async fn resolve_state( - &self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap>, - ) -> Result>> { - debug!("Loading current room state ids"); - let current_sstatehash = self - .services - .state - .get_room_shortstatehash(room_id) - .await - .map_err(|e| err!(Database(error!("No state for {room_id:?}: {e:?}"))))?; - - let current_state_ids = self - .services - .state_accessor - .state_full_ids(current_sstatehash) - .await?; - - let fork_states = [current_state_ids, incoming_state]; - let mut auth_chain_sets = Vec::with_capacity(fork_states.len()); - for state in &fork_states { - let starting_events: Vec<&EventId> = state.values().map(Borrow::borrow).collect(); - - let auth_chain: HashSet> = self - .services - .auth_chain - .get_event_ids(room_id, &starting_events) - .await? 
- .into_iter() - .collect(); - - auth_chain_sets.push(auth_chain); - } - - debug!("Loading fork states"); - let fork_states: Vec>> = fork_states - .into_iter() - .stream() - .then(|fork_state| { - fork_state - .into_iter() - .stream() - .filter_map(|(k, id)| { - self.services - .short - .get_statekey_from_short(k) - .map_ok_or_else(|_| None, move |(ty, st_key)| Some(((ty, st_key), id))) - }) - .collect() - }) - .collect() - .boxed() - .await; - - debug!("Resolving state"); - let lock = self.services.globals.stateres_mutex.lock(); - - let event_fetch = |event_id| self.event_fetch(event_id); - let event_exists = |event_id| self.event_exists(event_id); - let state = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists) - .await - .map_err(|e| err!(Database(error!("State resolution failed: {e:?}"))))?; - - drop(lock); - - debug!("State resolution done. Compressing state"); - let mut new_room_state = HashSet::new(); - for ((event_type, state_key), event_id) in state { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key) - .await; - - let compressed = self - .services - .state_compressor - .compress_state_event(shortstatekey, &event_id) - .await; - - new_room_state.insert(compressed); - } - - Ok(Arc::new(new_room_state)) - } - - // TODO: if we know the prev_events of the incoming event we can avoid the - // request and build the state from a known point and resolve if > 1 prev_event - #[tracing::instrument(skip_all, name = "state")] - pub async fn state_at_incoming_degree_one( - &self, incoming_pdu: &Arc, - ) -> Result>>> { - let prev_event = &*incoming_pdu.prev_events[0]; - let Ok(prev_event_sstatehash) = self - .services - .state_accessor - .pdu_shortstatehash(prev_event) - .await - else { - return Ok(None); - }; - - let Ok(mut state) = self - .services - .state_accessor - .state_full_ids(prev_event_sstatehash) - .await - .log_err() - else { - return 
Ok(None); - }; - - debug!("Using cached state"); - let prev_pdu = self - .services - .timeline - .get_pdu(prev_event) - .await - .map_err(|e| err!(Database("Could not find prev event, but we know the state: {e:?}")))?; - - if let Some(state_key) = &prev_pdu.state_key { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key) - .await; - - state.insert(shortstatekey, Arc::from(prev_event)); - // Now it's the state after the pdu - } - - debug_assert!(!state.is_empty(), "should be returning None for empty HashMap result"); - - Ok(Some(state)) - } - - #[tracing::instrument(skip_all, name = "state")] - pub async fn state_at_incoming_resolved( - &self, incoming_pdu: &Arc, room_id: &RoomId, room_version_id: &RoomVersionId, - ) -> Result>>> { - debug!("Calculating state at event using state res"); - let mut extremity_sstatehashes = HashMap::with_capacity(incoming_pdu.prev_events.len()); - - let mut okay = true; - for prev_eventid in &incoming_pdu.prev_events { - let Ok(prev_event) = self.services.timeline.get_pdu(prev_eventid).await else { - okay = false; - break; - }; - - let Ok(sstatehash) = self - .services - .state_accessor - .pdu_shortstatehash(prev_eventid) - .await - else { - okay = false; - break; - }; - - extremity_sstatehashes.insert(sstatehash, prev_event); - } - - if !okay { - return Ok(None); - } - - let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); - let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); - for (sstatehash, prev_event) in extremity_sstatehashes { - let Ok(mut leaf_state) = self - .services - .state_accessor - .state_full_ids(sstatehash) - .await - else { - continue; - }; - - if let Some(state_key) = &prev_event.state_key { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key) - .await; - - let event_id = &prev_event.event_id; - leaf_state.insert(shortstatekey, 
event_id.clone()); - // Now it's the state after the pdu - } - - let mut state = StateMap::with_capacity(leaf_state.len()); - let mut starting_events = Vec::with_capacity(leaf_state.len()); - for (k, id) in &leaf_state { - if let Ok((ty, st_key)) = self - .services - .short - .get_statekey_from_short(*k) - .await - .log_err() - { - // FIXME: Undo .to_string().into() when StateMap - // is updated to use StateEventType - state.insert((ty.to_string().into(), st_key), id.clone()); - } - - starting_events.push(id.borrow()); - } - - let auth_chain: HashSet> = self - .services - .auth_chain - .get_event_ids(room_id, &starting_events) - .await? - .into_iter() - .collect(); - - auth_chain_sets.push(auth_chain); - fork_states.push(state); - } - - let lock = self.services.globals.stateres_mutex.lock(); - - let event_fetch = |event_id| self.event_fetch(event_id); - let event_exists = |event_id| self.event_exists(event_id); - let result = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists) - .await - .map_err(|e| err!(Database(warn!(?e, "State resolution on prev events failed.")))); - - drop(lock); - - let Ok(new_state) = result else { - return Ok(None); - }; - - new_state - .iter() - .stream() - .then(|((event_type, state_key), event_id)| { - self.services - .short - .get_or_create_shortstatekey(event_type, state_key) - .map(move |shortstatekey| (shortstatekey, event_id.clone())) - }) - .collect() - .map(Some) - .map(Ok) - .await - } - - /// Call /state_ids to find out what the state at this pdu is. 
We trust the - /// server's response to some extend (sic), but we still do a lot of checks - /// on the events - #[tracing::instrument(skip(self, create_event, room_version_id))] - async fn fetch_state( - &self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId, - event_id: &EventId, - ) -> Result>>> { - debug!("Fetching state ids"); - let res = self - .services - .sending - .send_synapse_request( - origin, - get_room_state_ids::v1::Request { - room_id: room_id.to_owned(), - event_id: (*event_id).to_owned(), - }, - ) - .await - .inspect_err(|e| warn!("Fetching state for event failed: {e}"))?; - - debug!("Fetching state events"); - let collect = res - .pdu_ids - .iter() - .map(|x| Arc::from(&**x)) - .collect::>(); - - let state_vec = self - .fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id) - .boxed() - .await; - - let mut state: HashMap<_, Arc> = HashMap::with_capacity(state_vec.len()); - for (pdu, _) in state_vec { - let state_key = pdu - .state_key - .clone() - .ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?; - - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key) - .await; - - match state.entry(shortstatekey) { - hash_map::Entry::Vacant(v) => { - v.insert(Arc::from(&*pdu.event_id)); - }, - hash_map::Entry::Occupied(_) => { - return Err(Error::bad_database( - "State event's type and state_key combination exists multiple times.", - )) - }, - } - } - - // The original create event must still be in the state - let create_shortstatekey = self - .services - .short - .get_shortstatekey(&StateEventType::RoomCreate, "") - .await?; - - if state.get(&create_shortstatekey).map(AsRef::as_ref) != Some(&create_event.event_id) { - return Err!(Database("Incoming event refers to wrong create event.")); - } - - Ok(Some(state)) - } - - /// Find the event and auth it. 
Once the event is validated (steps 1 - 8) - /// it is appended to the outliers Tree. - /// - /// Returns pdu and if we fetched it over federation the raw json. - /// - /// a. Look in the main timeline (pduid_pdu tree) - /// b. Look at outlier pdu tree - /// c. Ask origin server over federation - /// d. TODO: Ask other servers over federation? - pub async fn fetch_and_handle_outliers<'a>( - &self, origin: &'a ServerName, events: &'a [Arc], create_event: &'a PduEvent, room_id: &'a RoomId, - room_version_id: &'a RoomVersionId, - ) -> Vec<(Arc, Option>)> { - let back_off = |id| match self - .services - .globals - .bad_event_ratelimiter - .write() - .expect("locked") - .entry(id) - { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - }, - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)), - }; - - let mut events_with_auth_events = Vec::with_capacity(events.len()); - for id in events { - // a. Look in the main timeline (pduid_pdu tree) - // b. Look at outlier pdu tree - // (get_pdu_json checks both) - if let Ok(local_pdu) = self.services.timeline.get_pdu(id).await { - trace!("Found {id} in db"); - events_with_auth_events.push((id, Some(local_pdu), vec![])); - continue; - } - - // c. Ask origin server over federation - // We also handle its auth chain here so we don't get a stack overflow in - // handle_outlier_pdu. 
- let mut todo_auth_events = vec![Arc::clone(id)]; - let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len()); - let mut events_all = HashSet::with_capacity(todo_auth_events.len()); - while let Some(next_id) = todo_auth_events.pop() { - if let Some((time, tries)) = self - .services - .globals - .bad_event_ratelimiter - .read() - .expect("locked") - .get(&*next_id) - { - // Exponential backoff - const MIN_DURATION: u64 = 5 * 60; - const MAX_DURATION: u64 = 60 * 60 * 24; - if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { - info!("Backing off from {next_id}"); - continue; - } - } - - if events_all.contains(&next_id) { - continue; - } - - if self.services.timeline.pdu_exists(&next_id).await { - trace!("Found {next_id} in db"); - continue; - } - - debug!("Fetching {next_id} over federation."); - match self - .services - .sending - .send_federation_request( - origin, - get_event::v1::Request { - event_id: (*next_id).to_owned(), - include_unredacted_content: None, - }, - ) - .await - { - Ok(res) => { - debug!("Got {next_id} over federation"); - let Ok((calculated_event_id, value)) = - pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) - else { - back_off((*next_id).to_owned()); - continue; - }; - - if calculated_event_id != *next_id { - warn!( - "Server didn't return event id we requested: requested: {next_id}, we got \ - {calculated_event_id}. 
Event: {:?}", - &res.pdu - ); - } - - if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array()) { - for auth_event in auth_events { - if let Ok(auth_event) = serde_json::from_value(auth_event.clone().into()) { - let a: Arc = auth_event; - todo_auth_events.push(a); - } else { - warn!("Auth event id is not valid"); - } - } - } else { - warn!("Auth event list invalid"); - } - - events_in_reverse_order.push((next_id.clone(), value)); - events_all.insert(next_id); - }, - Err(e) => { - debug_error!("Failed to fetch event {next_id}: {e}"); - back_off((*next_id).to_owned()); - }, - } - } - events_with_auth_events.push((id, None, events_in_reverse_order)); - } - - let mut pdus = Vec::with_capacity(events_with_auth_events.len()); - for (id, local_pdu, events_in_reverse_order) in events_with_auth_events { - // a. Look in the main timeline (pduid_pdu tree) - // b. Look at outlier pdu tree - // (get_pdu_json checks both) - if let Some(local_pdu) = local_pdu { - trace!("Found {id} in db"); - pdus.push((local_pdu.clone(), None)); - } - - for (next_id, value) in events_in_reverse_order.into_iter().rev() { - if let Some((time, tries)) = self - .services - .globals - .bad_event_ratelimiter - .read() - .expect("locked") - .get(&*next_id) - { - // Exponential backoff - const MIN_DURATION: u64 = 5 * 60; - const MAX_DURATION: u64 = 60 * 60 * 24; - if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { - debug!("Backing off from {next_id}"); - continue; - } - } - - match Box::pin(self.handle_outlier_pdu(origin, create_event, &next_id, room_id, value.clone(), true)) - .await - { - Ok((pdu, json)) => { - if next_id == *id { - pdus.push((pdu, Some(json))); - } - }, - Err(e) => { - warn!("Authentication of event {next_id} failed: {e:?}"); - back_off(next_id.into()); - }, - } - } - } - pdus - } - - #[allow(clippy::type_complexity)] - #[tracing::instrument(skip_all)] - async fn fetch_prev( - &self, origin: &ServerName, create_event: 
&PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId, - initial_set: Vec>, - ) -> Result<( - Vec>, - HashMap, (Arc, BTreeMap)>, - )> { - let mut graph: HashMap, _> = HashMap::with_capacity(initial_set.len()); - let mut eventid_info = HashMap::new(); - let mut todo_outlier_stack: Vec> = initial_set; - - let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?; - - let mut amount = 0; - - while let Some(prev_event_id) = todo_outlier_stack.pop() { - self.services.server.check_running()?; - - if let Some((pdu, mut json_opt)) = self - .fetch_and_handle_outliers(origin, &[prev_event_id.clone()], create_event, room_id, room_version_id) - .boxed() - .await - .pop() - { - check_room_id(room_id, &pdu)?; - - let limit = self.services.server.config.max_fetch_prev_events; - if amount > limit { - debug_warn!("Max prev event limit reached! Limit: {limit}"); - graph.insert(prev_event_id.clone(), HashSet::new()); - continue; - } - - if json_opt.is_none() { - json_opt = self - .services - .outlier - .get_outlier_pdu_json(&prev_event_id) - .await - .ok(); - } - - if let Some(json) = json_opt { - if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { - amount = amount.saturating_add(1); - for prev_prev in &pdu.prev_events { - if !graph.contains_key(prev_prev) { - todo_outlier_stack.push(prev_prev.clone()); - } - } - - graph.insert(prev_event_id.clone(), pdu.prev_events.iter().cloned().collect()); - } else { - // Time based check failed - graph.insert(prev_event_id.clone(), HashSet::new()); - } - - eventid_info.insert(prev_event_id.clone(), (pdu, json)); - } else { - // Get json failed, so this was not fetched over federation - graph.insert(prev_event_id.clone(), HashSet::new()); - } - } else { - // Fetch and handle failed - graph.insert(prev_event_id.clone(), HashSet::new()); - } - } - - let event_fetch = |event_id| { - let origin_server_ts = eventid_info - .get(&event_id) - .cloned() - .map_or_else(|| uint!(0), |info| info.0.origin_server_ts); 
- - // This return value is the key used for sorting events, - // events are then sorted by power level, time, - // and lexically by event_id. - future::ok((int!(0), MilliSecondsSinceUnixEpoch(origin_server_ts))) - }; - - let sorted = state_res::lexicographical_topological_sort(&graph, &event_fetch) - .await - .map_err(|e| err!(Database(error!("Error sorting prev events: {e}"))))?; - - Ok((sorted, eventid_info)) - } - - /// Returns Ok if the acl allows the server - #[tracing::instrument(skip_all)] - pub async fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { - let Ok(acl_event_content) = self - .services - .state_accessor - .room_state_get_content(room_id, &StateEventType::RoomServerAcl, "") - .await - .map(|c: RoomServerAclEventContent| c) - .inspect(|acl| trace!("ACL content found: {acl:?}")) - .inspect_err(|e| trace!("No ACL content found: {e:?}")) - else { - return Ok(()); - }; - - if acl_event_content.allow.is_empty() { - warn!("Ignoring broken ACL event (allow key is empty)"); - return Ok(()); - } - - if acl_event_content.is_allowed(server_name) { - trace!("server {server_name} is allowed by ACL"); - Ok(()) - } else { - debug!("Server {server_name} was denied by room ACL in {room_id}"); - Err!(Request(Forbidden("Server was denied by room ACL"))) - } - } - async fn event_exists(&self, event_id: Arc) -> bool { self.services.timeline.pdu_exists(&event_id).await } async fn event_fetch(&self, event_id: Arc) -> Option> { diff --git a/src/service/rooms/event_handler/parse_incoming_pdu.rs b/src/service/rooms/event_handler/parse_incoming_pdu.rs index 39920219a..42f44deec 100644 --- a/src/service/rooms/event_handler/parse_incoming_pdu.rs +++ b/src/service/rooms/event_handler/parse_incoming_pdu.rs @@ -1,30 +1,27 @@ -use conduit::{err, pdu::gen_event_id_canonical_json, result::FlatOk, Result}; +use conduit::{err, implement, pdu::gen_event_id_canonical_json, result::FlatOk, Result}; use ruma::{CanonicalJsonObject, CanonicalJsonValue, 
OwnedEventId, OwnedRoomId, RoomId}; use serde_json::value::RawValue as RawJsonValue; -impl super::Service { - pub async fn parse_incoming_pdu( - &self, pdu: &RawJsonValue, - ) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { - let value = serde_json::from_str::(pdu.get()) - .map_err(|e| err!(BadServerResponse(debug_warn!("Error parsing incoming event {e:?}"))))?; +#[implement(super::Service)] +pub async fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { + let value = serde_json::from_str::(pdu.get()) + .map_err(|e| err!(BadServerResponse(debug_warn!("Error parsing incoming event {e:?}"))))?; - let room_id: OwnedRoomId = value - .get("room_id") - .and_then(CanonicalJsonValue::as_str) - .map(RoomId::parse) - .flat_ok_or(err!(Request(InvalidParam("Invalid room_id in pdu"))))?; + let room_id: OwnedRoomId = value + .get("room_id") + .and_then(CanonicalJsonValue::as_str) + .map(RoomId::parse) + .flat_ok_or(err!(Request(InvalidParam("Invalid room_id in pdu"))))?; - let room_version_id = self - .services - .state - .get_room_version(&room_id) - .await - .map_err(|_| err!("Server is not in room {room_id}"))?; + let room_version_id = self + .services + .state + .get_room_version(&room_id) + .await + .map_err(|_| err!("Server is not in room {room_id}"))?; - let (event_id, value) = gen_event_id_canonical_json(pdu, &room_version_id) - .map_err(|e| err!(Request(InvalidParam("Could not convert event to canonical json: {e}"))))?; + let (event_id, value) = gen_event_id_canonical_json(pdu, &room_version_id) + .map_err(|e| err!(Request(InvalidParam("Could not convert event to canonical json: {e}"))))?; - Ok((event_id, value, room_id)) - } + Ok((event_id, value, room_id)) } diff --git a/src/service/rooms/event_handler/resolve_state.rs b/src/service/rooms/event_handler/resolve_state.rs new file mode 100644 index 000000000..0c9525dd7 --- /dev/null +++ b/src/service/rooms/event_handler/resolve_state.rs @@ -0,0 
+1,101 @@ +use std::{ + borrow::Borrow, + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use conduit::{debug, err, implement, utils::IterStream, Result}; +use futures::{FutureExt, StreamExt, TryFutureExt}; +use ruma::{ + state_res::{self, StateMap}, + EventId, RoomId, RoomVersionId, +}; + +use crate::rooms::state_compressor::CompressedStateEvent; + +#[implement(super::Service)] +#[tracing::instrument(skip_all, name = "resolve")] +pub async fn resolve_state( + &self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap>, +) -> Result>> { + debug!("Loading current room state ids"); + let current_sstatehash = self + .services + .state + .get_room_shortstatehash(room_id) + .await + .map_err(|e| err!(Database(error!("No state for {room_id:?}: {e:?}"))))?; + + let current_state_ids = self + .services + .state_accessor + .state_full_ids(current_sstatehash) + .await?; + + let fork_states = [current_state_ids, incoming_state]; + let mut auth_chain_sets = Vec::with_capacity(fork_states.len()); + for state in &fork_states { + let starting_events: Vec<&EventId> = state.values().map(Borrow::borrow).collect(); + + let auth_chain: HashSet> = self + .services + .auth_chain + .get_event_ids(room_id, &starting_events) + .await? 
+ .into_iter() + .collect(); + + auth_chain_sets.push(auth_chain); + } + + debug!("Loading fork states"); + let fork_states: Vec>> = fork_states + .into_iter() + .stream() + .then(|fork_state| { + fork_state + .into_iter() + .stream() + .filter_map(|(k, id)| { + self.services + .short + .get_statekey_from_short(k) + .map_ok_or_else(|_| None, move |(ty, st_key)| Some(((ty, st_key), id))) + }) + .collect() + }) + .collect() + .boxed() + .await; + + debug!("Resolving state"); + let lock = self.services.globals.stateres_mutex.lock(); + + let event_fetch = |event_id| self.event_fetch(event_id); + let event_exists = |event_id| self.event_exists(event_id); + let state = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists) + .await + .map_err(|e| err!(Database(error!("State resolution failed: {e:?}"))))?; + + drop(lock); + + debug!("State resolution done. Compressing state"); + let mut new_room_state = HashSet::new(); + for ((event_type, state_key), event_id) in state { + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key) + .await; + + let compressed = self + .services + .state_compressor + .compress_state_event(shortstatekey, &event_id) + .await; + + new_room_state.insert(compressed); + } + + Ok(Arc::new(new_room_state)) +} diff --git a/src/service/rooms/event_handler/state_at_incoming.rs b/src/service/rooms/event_handler/state_at_incoming.rs new file mode 100644 index 000000000..a200ab568 --- /dev/null +++ b/src/service/rooms/event_handler/state_at_incoming.rs @@ -0,0 +1,178 @@ +use std::{ + borrow::Borrow, + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use conduit::{debug, err, implement, result::LogErr, utils::IterStream, PduEvent, Result}; +use futures::{FutureExt, StreamExt}; +use ruma::{ + state_res::{self, StateMap}, + EventId, RoomId, RoomVersionId, +}; + +// TODO: if we know the prev_events of the incoming event we can avoid the 
+#[implement(super::Service)] +// request and build the state from a known point and resolve if > 1 prev_event +#[tracing::instrument(skip_all, name = "state")] +pub(super) async fn state_at_incoming_degree_one( + &self, incoming_pdu: &Arc, +) -> Result>>> { + let prev_event = &*incoming_pdu.prev_events[0]; + let Ok(prev_event_sstatehash) = self + .services + .state_accessor + .pdu_shortstatehash(prev_event) + .await + else { + return Ok(None); + }; + + let Ok(mut state) = self + .services + .state_accessor + .state_full_ids(prev_event_sstatehash) + .await + .log_err() + else { + return Ok(None); + }; + + debug!("Using cached state"); + let prev_pdu = self + .services + .timeline + .get_pdu(prev_event) + .await + .map_err(|e| err!(Database("Could not find prev event, but we know the state: {e:?}")))?; + + if let Some(state_key) = &prev_pdu.state_key { + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key) + .await; + + state.insert(shortstatekey, Arc::from(prev_event)); + // Now it's the state after the pdu + } + + debug_assert!(!state.is_empty(), "should be returning None for empty HashMap result"); + + Ok(Some(state)) +} + +#[implement(super::Service)] +#[tracing::instrument(skip_all, name = "state")] +pub(super) async fn state_at_incoming_resolved( + &self, incoming_pdu: &Arc, room_id: &RoomId, room_version_id: &RoomVersionId, +) -> Result>>> { + debug!("Calculating state at event using state res"); + let mut extremity_sstatehashes = HashMap::with_capacity(incoming_pdu.prev_events.len()); + + let mut okay = true; + for prev_eventid in &incoming_pdu.prev_events { + let Ok(prev_event) = self.services.timeline.get_pdu(prev_eventid).await else { + okay = false; + break; + }; + + let Ok(sstatehash) = self + .services + .state_accessor + .pdu_shortstatehash(prev_eventid) + .await + else { + okay = false; + break; + }; + + extremity_sstatehashes.insert(sstatehash, prev_event); + } + + if !okay { + 
return Ok(None); + } + + let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); + let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); + for (sstatehash, prev_event) in extremity_sstatehashes { + let Ok(mut leaf_state) = self + .services + .state_accessor + .state_full_ids(sstatehash) + .await + else { + continue; + }; + + if let Some(state_key) = &prev_event.state_key { + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key) + .await; + + let event_id = &prev_event.event_id; + leaf_state.insert(shortstatekey, event_id.clone()); + // Now it's the state after the pdu + } + + let mut state = StateMap::with_capacity(leaf_state.len()); + let mut starting_events = Vec::with_capacity(leaf_state.len()); + for (k, id) in &leaf_state { + if let Ok((ty, st_key)) = self + .services + .short + .get_statekey_from_short(*k) + .await + .log_err() + { + // FIXME: Undo .to_string().into() when StateMap + // is updated to use StateEventType + state.insert((ty.to_string().into(), st_key), id.clone()); + } + + starting_events.push(id.borrow()); + } + + let auth_chain: HashSet> = self + .services + .auth_chain + .get_event_ids(room_id, &starting_events) + .await? 
+ .into_iter() + .collect(); + + auth_chain_sets.push(auth_chain); + fork_states.push(state); + } + + let lock = self.services.globals.stateres_mutex.lock(); + + let event_fetch = |event_id| self.event_fetch(event_id); + let event_exists = |event_id| self.event_exists(event_id); + let result = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists) + .await + .map_err(|e| err!(Database(warn!(?e, "State resolution on prev events failed.")))); + + drop(lock); + + let Ok(new_state) = result else { + return Ok(None); + }; + + new_state + .iter() + .stream() + .then(|((event_type, state_key), event_id)| { + self.services + .short + .get_or_create_shortstatekey(event_type, state_key) + .map(move |shortstatekey| (shortstatekey, event_id.clone())) + }) + .collect() + .map(Some) + .map(Ok) + .await +} diff --git a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs new file mode 100644 index 000000000..2a1e46625 --- /dev/null +++ b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs @@ -0,0 +1,298 @@ +use std::{ + collections::{BTreeMap, HashSet}, + sync::Arc, + time::Instant, +}; + +use conduit::{debug, debug_info, err, implement, trace, warn, Err, Error, PduEvent, Result}; +use futures::{future::ready, StreamExt}; +use ruma::{ + api::client::error::ErrorKind, + events::{room::redaction::RoomRedactionEventContent, StateEventType, TimelineEventType}, + state_res::{self, EventTypeExt}, + CanonicalJsonValue, RoomId, RoomVersionId, ServerName, +}; + +use super::{get_room_version_id, to_room_version}; +use crate::rooms::{state_compressor::HashSetCompressStateEvent, timeline::RawPduId}; + +#[implement(super::Service)] +pub(super) async fn upgrade_outlier_to_timeline_pdu( + &self, incoming_pdu: Arc, val: BTreeMap, create_event: &PduEvent, + origin: &ServerName, room_id: &RoomId, +) -> Result> { + // Skip the PDU if we already have it as a timeline event + if let Ok(pduid) = 
self + .services + .timeline + .get_pdu_id(&incoming_pdu.event_id) + .await + { + return Ok(Some(pduid)); + } + + if self + .services + .pdu_metadata + .is_event_soft_failed(&incoming_pdu.event_id) + .await + { + return Err!(Request(InvalidParam("Event has been soft failed"))); + } + + debug!("Upgrading to timeline pdu"); + let timer = Instant::now(); + let room_version_id = get_room_version_id(create_event)?; + + // 10. Fetch missing state and auth chain events by calling /state_ids at + // backwards extremities doing all the checks in this list starting at 1. + // These are not timeline events. + + debug!("Resolving state at event"); + let mut state_at_incoming_event = if incoming_pdu.prev_events.len() == 1 { + self.state_at_incoming_degree_one(&incoming_pdu).await? + } else { + self.state_at_incoming_resolved(&incoming_pdu, room_id, &room_version_id) + .await? + }; + + if state_at_incoming_event.is_none() { + state_at_incoming_event = self + .fetch_state(origin, create_event, room_id, &room_version_id, &incoming_pdu.event_id) + .await?; + } + + let state_at_incoming_event = state_at_incoming_event.expect("we always set this to some above"); + let room_version = to_room_version(&room_version_id); + + debug!("Performing auth check"); + // 11. 
Check the auth of the event passes based on the state of the event + let state_fetch_state = &state_at_incoming_event; + let state_fetch = |k: &'static StateEventType, s: String| async move { + let shortstatekey = self.services.short.get_shortstatekey(k, &s).await.ok()?; + + let event_id = state_fetch_state.get(&shortstatekey)?; + self.services.timeline.get_pdu(event_id).await.ok() + }; + + let auth_check = state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + None, // TODO: third party invite + |k, s| state_fetch(k, s.to_owned()), + ) + .await + .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; + + if !auth_check { + return Err!(Request(Forbidden("Event has failed auth check with state at the event."))); + } + + debug!("Gathering auth events"); + let auth_events = self + .services + .state + .get_auth_events( + room_id, + &incoming_pdu.kind, + &incoming_pdu.sender, + incoming_pdu.state_key.as_deref(), + &incoming_pdu.content, + ) + .await?; + + let state_fetch = |k: &'static StateEventType, s: &str| { + let key = k.with_state_key(s); + ready(auth_events.get(&key).cloned()) + }; + + let auth_check = state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + None, // third-party invite + state_fetch, + ) + .await + .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; + + // Soft fail check before doing state res + debug!("Performing soft-fail check"); + let soft_fail = { + use RoomVersionId::*; + + !auth_check + || incoming_pdu.kind == TimelineEventType::RoomRedaction + && match room_version_id { + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { + if let Some(redact_id) = &incoming_pdu.redacts { + !self + .services + .state_accessor + .user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true) + .await? 
+ } else { + false + } + }, + _ => { + let content: RoomRedactionEventContent = incoming_pdu.get_content()?; + if let Some(redact_id) = &content.redacts { + !self + .services + .state_accessor + .user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true) + .await? + } else { + false + } + }, + } + }; + + // 13. Use state resolution to find new room state + + // We start looking at current room state now, so lets lock the room + trace!("Locking the room"); + let state_lock = self.services.state.mutex.lock(room_id).await; + + // Now we calculate the set of extremities this room has after the incoming + // event has been applied. We start with the previous extremities (aka leaves) + trace!("Calculating extremities"); + let mut extremities: HashSet<_> = self + .services + .state + .get_forward_extremities(room_id) + .map(ToOwned::to_owned) + .collect() + .await; + + // Remove any forward extremities that are referenced by this incoming event's + // prev_events + trace!( + "Calculated {} extremities; checking against {} prev_events", + extremities.len(), + incoming_pdu.prev_events.len() + ); + for prev_event in &incoming_pdu.prev_events { + extremities.remove(&(**prev_event)); + } + + // Only keep those extremities were not referenced yet + let mut retained = HashSet::new(); + for id in &extremities { + if !self + .services + .pdu_metadata + .is_event_referenced(room_id, id) + .await + { + retained.insert(id.clone()); + } + } + + extremities.retain(|id| retained.contains(id)); + debug!("Retained {} extremities. Compressing state", extremities.len()); + + let mut state_ids_compressed = HashSet::new(); + for (shortstatekey, id) in &state_at_incoming_event { + state_ids_compressed.insert( + self.services + .state_compressor + .compress_state_event(*shortstatekey, id) + .await, + ); + } + + let state_ids_compressed = Arc::new(state_ids_compressed); + + if incoming_pdu.state_key.is_some() { + debug!("Event is a state-event. 
Deriving new room state"); + + // We also add state after incoming event to the fork states + let mut state_after = state_at_incoming_event.clone(); + if let Some(state_key) = &incoming_pdu.state_key { + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key) + .await; + + let event_id = &incoming_pdu.event_id; + state_after.insert(shortstatekey, event_id.clone()); + } + + let new_room_state = self + .resolve_state(room_id, &room_version_id, state_after) + .await?; + + // Set the new room state to the resolved state + debug!("Forcing new room state"); + let HashSetCompressStateEvent { + shortstatehash, + added, + removed, + } = self + .services + .state_compressor + .save_state(room_id, new_room_state) + .await?; + + self.services + .state + .force_state(room_id, shortstatehash, added, removed, &state_lock) + .await?; + } + + // 14. Check if the event passes auth based on the "current state" of the room, + // if not soft fail it + if soft_fail { + debug!("Soft failing event"); + self.services + .timeline + .append_incoming_pdu( + &incoming_pdu, + val, + extremities.iter().map(|e| (**e).to_owned()).collect(), + state_ids_compressed, + soft_fail, + &state_lock, + ) + .await?; + + // Soft fail, we keep the event as an outlier but don't add it to the timeline + warn!("Event was soft failed: {incoming_pdu:?}"); + self.services + .pdu_metadata + .mark_event_soft_failed(&incoming_pdu.event_id); + + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); + } + + trace!("Appending pdu to timeline"); + extremities.insert(incoming_pdu.event_id.clone().into()); + + // Now that the event has passed all auth it is added into the timeline. + // We use the `state_at_event` instead of `state_after` so we accurately + // represent the state for this event. 
+ let pdu_id = self + .services + .timeline + .append_incoming_pdu( + &incoming_pdu, + val, + extremities.into_iter().collect(), + state_ids_compressed, + soft_fail, + &state_lock, + ) + .await?; + + // Event has passed all auth/stateres checks + drop(state_lock); + debug_info!( + elapsed = ?timer.elapsed(), + "Accepted", + ); + + Ok(pdu_id) +} From 10be3016466076a76ab0e9270dabb80e2acf1afa Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 9 Nov 2024 01:09:09 +0000 Subject: [PATCH 190/245] split large notary requests into batches Signed-off-by: Jason Volk --- src/core/config/mod.rs | 8 +++ src/service/server_keys/acquire.rs | 4 ++ src/service/server_keys/get.rs | 4 +- src/service/server_keys/mod.rs | 2 +- src/service/server_keys/request.rs | 89 +++++++++++++++++++----------- 5 files changed, 71 insertions(+), 36 deletions(-) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 43cca4b8a..cd9c1b38a 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -582,6 +582,12 @@ pub struct Config { #[serde(default)] pub only_query_trusted_key_servers: bool, + /// Maximum number of keys to request in each trusted server query. + /// + /// default: 1024 + #[serde(default = "default_trusted_server_batch_size")] + pub trusted_server_batch_size: usize, + /// max log level for conduwuit. 
allows debug, info, warn, or error /// see also: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives /// **Caveat**: @@ -2062,3 +2068,5 @@ fn parallelism_scaled_u32(val: u32) -> u32 { } fn parallelism_scaled(val: usize) -> usize { val.saturating_mul(sys::available_parallelism()) } + +fn default_trusted_server_batch_size() -> usize { 256 } diff --git a/src/service/server_keys/acquire.rs b/src/service/server_keys/acquire.rs index cdaf28b4a..190b42392 100644 --- a/src/service/server_keys/acquire.rs +++ b/src/service/server_keys/acquire.rs @@ -110,6 +110,10 @@ where {requested_servers} total servers; some events may not be verifiable" ); } + + for (server, key_ids) in missing { + debug_warn!(?server, ?key_ids, "missing"); + } } #[implement(super::Service)] diff --git a/src/service/server_keys/get.rs b/src/service/server_keys/get.rs index 441e33d45..dc4627f7a 100644 --- a/src/service/server_keys/get.rs +++ b/src/service/server_keys/get.rs @@ -89,8 +89,8 @@ pub async fn get_verify_key(&self, origin: &ServerName, key_id: &ServerSigningKe async fn get_verify_key_from_notaries(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result { for notary in self.services.globals.trusted_servers() { if let Ok(server_keys) = self.notary_request(notary, origin).await { - for server_key in &server_keys { - self.add_signing_keys(server_key.clone()).await; + for server_key in server_keys.clone() { + self.add_signing_keys(server_key).await; } for server_key in server_keys { diff --git a/src/service/server_keys/mod.rs b/src/service/server_keys/mod.rs index dae45a51c..333970df3 100644 --- a/src/service/server_keys/mod.rs +++ b/src/service/server_keys/mod.rs @@ -7,7 +7,7 @@ mod verify; use std::{collections::BTreeMap, sync::Arc, time::Duration}; -use conduit::{implement, utils::time::timepoint_from_now, Result, Server}; +use conduit::{implement, utils::timepoint_from_now, Result, Server}; use database::{Deserialized, Json, Map}; use 
ruma::{ api::federation::discovery::{ServerSigningKeys, VerifyKey}, diff --git a/src/service/server_keys/request.rs b/src/service/server_keys/request.rs index 84dd28717..7078f7cd0 100644 --- a/src/service/server_keys/request.rs +++ b/src/service/server_keys/request.rs @@ -1,6 +1,6 @@ -use std::collections::BTreeMap; +use std::{collections::BTreeMap, fmt::Debug}; -use conduit::{implement, Err, Result}; +use conduit::{debug, implement, Err, Result}; use ruma::{ api::federation::discovery::{ get_remote_server_keys, @@ -25,34 +25,57 @@ where minimum_valid_until_ts: Some(self.minimum_valid_ts()), }; - let mut server_keys = RumaBatch::new(); - for (server, key_ids) in batch { - let entry = server_keys.entry(server.into()).or_default(); - for key_id in key_ids { - entry.insert(key_id.into(), criteria.clone()); - } - } + let mut server_keys = batch.fold(RumaBatch::new(), |mut batch, (server, key_ids)| { + batch + .entry(server.into()) + .or_default() + .extend(key_ids.map(|key_id| (key_id.into(), criteria.clone()))); + + batch + }); debug_assert!(!server_keys.is_empty(), "empty batch request to notary"); - let request = Request { - server_keys, - }; - self.services - .sending - .send_federation_request(notary, request) - .await - .map(|response| response.server_keys) - .map(|keys| { - keys.into_iter() - .map(|key| key.deserialize()) - .filter_map(Result::ok) - .collect() - }) + let mut results = Vec::new(); + while let Some(batch) = server_keys + .keys() + .rev() + .take(self.services.server.config.trusted_server_batch_size) + .last() + .cloned() + { + let request = Request { + server_keys: server_keys.split_off(&batch), + }; + + debug!( + ?notary, + ?batch, + remaining = %server_keys.len(), + requesting = ?request.server_keys.keys(), + "notary request" + ); + + let response = self + .services + .sending + .send_synapse_request(notary, request) + .await? 
+ .server_keys + .into_iter() + .map(|key| key.deserialize()) + .filter_map(Result::ok); + + results.extend(response); + } + + Ok(results) } #[implement(super::Service)] -pub async fn notary_request(&self, notary: &ServerName, target: &ServerName) -> Result> { +pub async fn notary_request( + &self, notary: &ServerName, target: &ServerName, +) -> Result + Clone + Debug + Send> { use get_remote_server_keys::v2::Request; let request = Request { @@ -60,17 +83,17 @@ pub async fn notary_request(&self, notary: &ServerName, target: &ServerName) -> minimum_valid_until_ts: self.minimum_valid_ts(), }; - self.services + let response = self + .services .sending .send_federation_request(notary, request) - .await - .map(|response| response.server_keys) - .map(|keys| { - keys.into_iter() - .map(|key| key.deserialize()) - .filter_map(Result::ok) - .collect() - }) + .await? + .server_keys + .into_iter() + .map(|key| key.deserialize()) + .filter_map(Result::ok); + + Ok(response) } #[implement(super::Service)] From 14fce384034c348c6ba35fc946b6cbffaa970f3e Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 9 Nov 2024 02:42:09 +0000 Subject: [PATCH 191/245] cork around send_join response processing Signed-off-by: Jason Volk --- src/api/client/membership.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index bf8e5c33b..2906d35bf 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -877,6 +877,7 @@ async fn join_room_by_id_helper_remote( .await; info!("Going through send_join response room_state"); + let cork = services.db.cork_and_flush(); let mut state = HashMap::new(); for result in send_join_response.room_state.state.iter().map(|pdu| { services @@ -902,8 +903,10 @@ async fn join_room_by_id_helper_remote( state.insert(shortstatekey, pdu.event_id.clone()); } } + drop(cork); info!("Going through send_join response auth_chain"); + let cork = services.db.cork_and_flush(); for result in 
send_join_response.room_state.auth_chain.iter().map(|pdu| { services .server_keys @@ -915,6 +918,7 @@ async fn join_room_by_id_helper_remote( services.rooms.outlier.add_pdu_outlier(&event_id, &value); } + drop(cork); debug!("Running send_join auth check"); let fetch_state = &state; From cc86feded32bb94b5171462bf9ce9c7b1adde04d Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 10 Nov 2024 01:49:16 +0000 Subject: [PATCH 192/245] bump ruma fixes for key type changes Signed-off-by: Jason Volk --- Cargo.lock | 475 +++++++++++++++++++++++++++++++-------- Cargo.toml | 2 +- src/api/client/keys.rs | 19 +- src/api/server/key.rs | 5 +- src/service/users/mod.rs | 20 +- 5 files changed, 402 insertions(+), 119 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f729d3d4a..a1654ff96 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -43,15 +43,15 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8365de52b16c035ff4fcafe0092ba9390540e3e352870ac09933bebcaa2c8c56" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anyhow" -version = "1.0.91" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c042108f3ed77fd83760a5fd79b53be043192bb3b9dba91d8c574c0ada7850c8" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" [[package]] name = "arc-swap" @@ -127,7 +127,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -138,7 +138,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -373,7 +373,7 @@ dependencies = [ "regex", "rustc-hash 1.1.0", "shlex", - "syn 2.0.85", + "syn 2.0.87", "which", ] @@ -481,9 +481,9 @@ dependencies = 
[ [[package]] name = "cc" -version = "1.1.31" +version = "1.1.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" +checksum = "40545c26d092346d8a8dab71ee48e7685a7a9cba76e634790c215b41a4a7b4cf" dependencies = [ "jobserver", "libc", @@ -569,7 +569,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -713,7 +713,7 @@ dependencies = [ "serde_json", "serde_regex", "serde_yaml", - "thiserror", + "thiserror 1.0.68", "tikv-jemalloc-ctl", "tikv-jemalloc-sys", "tikv-jemallocator", @@ -749,7 +749,7 @@ dependencies = [ "itertools 0.13.0", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1047,7 +1047,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" dependencies = [ "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1074,7 +1074,7 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1135,6 +1135,17 @@ dependencies = [ "subtle", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "dunce" version = "1.0.5" @@ -1184,7 +1195,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1276,7 +1287,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9" dependencies = [ "nonempty", - "thiserror", + "thiserror 1.0.68", ] [[package]] @@ -1350,7 +1361,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" 
dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1461,9 +1472,9 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" +checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" [[package]] name = "hdrhistogram" @@ -1537,7 +1548,7 @@ dependencies = [ "ipnet", "once_cell", "rand", - "thiserror", + "thiserror 1.0.68", "tinyvec", "tokio", "tracing", @@ -1560,7 +1571,7 @@ dependencies = [ "rand", "resolv-conf", "smallvec", - "thiserror", + "thiserror 1.0.68", "tokio", "tracing", ] @@ -1616,7 +1627,7 @@ dependencies = [ "markup5ever", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1753,6 +1764,124 @@ dependencies = [ "tracing", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "idna" version = "0.4.0" @@ -1765,19 +1894,30 @@ dependencies = [ [[package]] name = 
"idna" -version = "0.5.0" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +dependencies = [ + "icu_normalizer", + "icu_properties", ] [[package]] name = "image" -version = "0.25.4" +version = "0.25.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc144d44a31d753b02ce64093d532f55ff8dc4ebf2ffb8a63c0dda691385acae" +checksum = "cd6f44aed642f18953a158afeb30206f4d50da59fbc66ecb53c66488de73563b" dependencies = [ "bytemuck", "byteorder-lite", @@ -1817,7 +1957,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown 0.15.0", + "hashbrown 0.15.1", "serde", ] @@ -1980,7 +2120,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1997,9 +2137,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.161" +version = "0.2.162" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" +checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" [[package]] name = "libloading" @@ -2034,6 +2174,12 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "litemap" +version = 
"0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" + [[package]] name = "lock_api" version = "0.4.12" @@ -2356,7 +2502,7 @@ dependencies = [ "js-sys", "once_cell", "pin-project-lite", - "thiserror", + "thiserror 1.0.68", "urlencoding", ] @@ -2399,10 +2545,10 @@ dependencies = [ "glob", "once_cell", "opentelemetry", - "ordered-float 4.4.0", + "ordered-float 4.5.0", "percent-encoding", "rand", - "thiserror", + "thiserror 1.0.68", "tokio", "tokio-stream", ] @@ -2418,9 +2564,9 @@ dependencies = [ [[package]] name = "ordered-float" -version = "4.4.0" +version = "4.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83e7ccb95e240b7c9506a3d544f10d935e142cc90b0a1d56954fb44d89ad6b97" +checksum = "c65ee1f9701bf938026630b455d5315f490640234259037edb259798b3bcf85e" dependencies = [ "num-traits", ] @@ -2502,7 +2648,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -2595,7 +2741,7 @@ checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -2667,7 +2813,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" dependencies = [ "proc-macro2", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -2696,7 +2842,7 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", "version_check", "yansi", ] @@ -2721,7 +2867,7 @@ dependencies = [ "itertools 0.13.0", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -2776,7 +2922,7 @@ dependencies = [ "rustc-hash 2.0.0", "rustls 0.23.16", "socket2", - "thiserror", + "thiserror 1.0.68", "tokio", "tracing", ] @@ -2793,17 +2939,18 @@ 
dependencies = [ "rustc-hash 2.0.0", "rustls 0.23.16", "slab", - "thiserror", + "thiserror 1.0.68", "tinyvec", "tracing", ] [[package]] name = "quinn-udp" -version = "0.5.5" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fe68c2e9e1a1234e218683dbdf9f9dfcb094113c5ac2b938dfcb9bab4c4140b" +checksum = "7d5a626c6807713b15cac82a6acaccd6043c9a5408c24baae07611fec3f243da" dependencies = [ + "cfg_aliases", "libc", "once_cell", "socket2", @@ -2980,7 +3127,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" +source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" dependencies = [ "assign", "js_int", @@ -3002,7 +3149,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" +source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" dependencies = [ "js_int", "ruma-common", @@ -3014,7 +3161,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" +source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" dependencies = [ "as_variant", "assign", @@ -3029,7 +3176,7 @@ dependencies = [ "serde", "serde_html_form", "serde_json", - "thiserror", + "thiserror 2.0.1", "url", "web-time 1.1.0", ] @@ -3037,7 +3184,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = 
"git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" +source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" dependencies = [ "as_variant", "base64 0.22.1", @@ -3055,7 +3202,7 @@ dependencies = [ "serde", "serde_html_form", "serde_json", - "thiserror", + "thiserror 2.0.1", "time", "tracing", "url", @@ -3067,7 +3214,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" +source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" dependencies = [ "as_variant", "indexmap 2.6.0", @@ -3081,7 +3228,7 @@ dependencies = [ "ruma-macros", "serde", "serde_json", - "thiserror", + "thiserror 2.0.1", "tracing", "url", "web-time 1.1.0", @@ -3091,7 +3238,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" +source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" dependencies = [ "bytes", "http", @@ -3109,16 +3256,16 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" +source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" dependencies = [ "js_int", - "thiserror", + "thiserror 2.0.1", ] [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = 
"git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" +source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" dependencies = [ "js_int", "ruma-common", @@ -3128,7 +3275,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" +source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" dependencies = [ "cfg-if", "once_cell", @@ -3137,14 +3284,14 @@ dependencies = [ "quote", "ruma-identifiers-validation", "serde", - "syn 2.0.85", + "syn 2.0.87", "toml", ] [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" +source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" dependencies = [ "js_int", "ruma-common", @@ -3156,20 +3303,20 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.3.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" +source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" dependencies = [ "headers", "http", "http-auth", "ruma-common", - "thiserror", + "thiserror 2.0.1", "tracing", ] [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" +source = 
"git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -3179,13 +3326,13 @@ dependencies = [ "serde_json", "sha2", "subslice", - "thiserror", + "thiserror 2.0.1", ] [[package]] name = "ruma-state-res" version = "0.11.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a#dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" +source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" dependencies = [ "futures-util", "itertools 0.13.0", @@ -3194,7 +3341,7 @@ dependencies = [ "ruma-events", "serde", "serde_json", - "thiserror", + "thiserror 2.0.1", "tracing", ] @@ -3261,9 +3408,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.38" +version = "0.38.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa260229e6538e52293eeb577aabd09945a09d6d9cc0fc550ed7529056c2e32a" +checksum = "375116bee2be9ed569afe2154ea6a99dfdffd257f533f187498c2a8f5feaf4ee" dependencies = [ "bitflags 2.6.0", "errno", @@ -3358,7 +3505,7 @@ dependencies = [ "futures-util", "pin-project", "thingbuf", - "thiserror", + "thiserror 1.0.68", "unicode-segmentation", "unicode-width", ] @@ -3415,9 +3562,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.12.0" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" +checksum = "fa39c7303dc58b5543c94d22c1766b0d31f2ee58306363ea622b10bbc075eaa2" dependencies = [ "core-foundation-sys", "libc", @@ -3558,7 +3705,7 @@ dependencies = [ "rand", "serde", "serde_json", - "thiserror", + "thiserror 1.0.68", "time", "url", "uuid", @@ -3581,7 +3728,7 @@ checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ 
"proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -3764,7 +3911,7 @@ checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" dependencies = [ "num-bigint", "num-traits", - "thiserror", + "thiserror 1.0.68", "time", ] @@ -3815,6 +3962,12 @@ dependencies = [ "der", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "strict" version = "0.2.0" @@ -3875,9 +4028,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.85" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -3899,6 +4052,17 @@ dependencies = [ "futures-core", ] +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "tendril" version = "0.4.3" @@ -3922,7 +4086,7 @@ dependencies = [ "lazy-regex", "minimad", "serde", - "thiserror", + "thiserror 1.0.68", "unicode-width", ] @@ -3938,22 +4102,42 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.65" +version = "1.0.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02dd99dc800bbb97186339685293e1cc5d9df1f8fae2d0aecd9ff1c77efea892" +dependencies = [ + "thiserror-impl 1.0.68", +] + +[[package]] +name = "thiserror" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" +checksum = 
"07c1e40dd48a282ae8edc36c732cbc219144b87fb6a4c7316d611c6b1f06ec0c" dependencies = [ - "thiserror-impl", + "thiserror-impl 2.0.1", ] [[package]] name = "thiserror-impl" -version = "1.0.65" +version = "1.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" +checksum = "a7c61ec9a6f64d2793d8a45faba21efbe3ced62a886d44c36a009b2b519b4c7e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "874aa7e446f1da8d9c3a5c95b1c5eb41d800045252121dc7f8e0ba370cee55f5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", ] [[package]] @@ -4047,6 +4231,16 @@ dependencies = [ "time-core", ] +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tinyvec" version = "1.8.0" @@ -4064,9 +4258,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.41.0" +version = "1.41.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" +checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" dependencies = [ "backtrace", "bytes", @@ -4088,7 +4282,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -4122,7 +4316,7 @@ checksum = "0d4770b8024672c1101b3f6733eab95b18007dbe0847a8afe341fcf79e06043f" dependencies = [ "either", "futures-util", - "thiserror", + "thiserror 1.0.68", "tokio", ] @@ -4302,7 +4496,7 @@ source = 
"git+https://github.com/girlbossceo/tracing?rev=4d78a14a5e03f539b8c6b47 dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -4483,12 +4677,12 @@ dependencies = [ [[package]] name = "url" -version = "2.5.2" +version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada" dependencies = [ "form_urlencoded", - "idna 0.5.0", + "idna 1.0.3", "percent-encoding", "serde", ] @@ -4505,6 +4699,18 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "uuid" version = "1.11.0" @@ -4570,7 +4776,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", "wasm-bindgen-shared", ] @@ -4604,7 +4810,7 @@ checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4934,6 +5140,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + [[package]] name = "xml5ever" version = "0.18.1" @@ -4951,6 +5169,30 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" +[[package]] +name = "yoke" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.7.35" @@ -4969,7 +5211,28 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", +] + +[[package]] +name = "zerofrom" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", + "synstructure", ] [[package]] @@ -4978,6 +5241,28 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "zstd" version = "0.13.2" diff --git a/Cargo.toml b/Cargo.toml index 3ac1556c6..5ea6b4e09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -316,7 +316,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = "https://github.com/girlbossceo/ruwuma" #branch = "conduwuit-changes" -rev = "dd8b13ed2fa2ec4d9fe5c6fbb18e701ac4d4d08a" +rev = "67ffedabbf43e1ff6934df0fbf770b21e101406f" features = [ "compat", "rand", diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index 44d9164c9..53ec12f92 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -16,7 +16,7 @@ use ruma::{ federation, }, serde::Raw, - DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId, + OneTimeKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId, }; use serde_json::json; @@ -36,13 +36,12 @@ use crate::{ pub(crate) async fn upload_keys_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let (sender_user, sender_device) = body.sender(); - for (key_key, key_value) in &body.one_time_keys { + for (key_id, one_time_key) in &body.one_time_keys { services .users - .add_one_time_key(sender_user, sender_device, key_key, key_value) + .add_one_time_key(sender_user, sender_device, key_id, one_time_key) .await?; } @@ -400,16 +399,16 @@ where while let Some((server, response)) = futures.next().await { if let Ok(Ok(response)) = response { - for (user, masterkey) in response.master_keys { - let (master_key_id, mut 
master_key) = parse_master_key(&user, &masterkey)?; + for (user, master_key) in response.master_keys { + let (master_key_id, mut master_key) = parse_master_key(&user, &master_key)?; if let Ok(our_master_key) = services .users .get_key(&master_key_id, sender_user, &user, &allowed_signatures) .await { - let (_, our_master_key) = parse_master_key(&user, &our_master_key)?; - master_key.signatures.extend(our_master_key.signatures); + let (_, mut our_master_key) = parse_master_key(&user, &our_master_key)?; + master_key.signatures.append(&mut our_master_key.signatures); } let json = serde_json::to_value(master_key).expect("to_value always works"); let raw = serde_json::from_value(json).expect("Raw::from_value always works"); @@ -467,7 +466,7 @@ fn add_unsigned_device_display_name( } pub(crate) async fn claim_keys_helper( - services: &Services, one_time_keys_input: &BTreeMap>, + services: &Services, one_time_keys_input: &BTreeMap>, ) -> Result { let mut one_time_keys = BTreeMap::new(); diff --git a/src/api/server/key.rs b/src/api/server/key.rs index 5284593d2..37fffa9fb 100644 --- a/src/api/server/key.rs +++ b/src/api/server/key.rs @@ -1,5 +1,4 @@ use std::{ - collections::BTreeMap, mem::take, time::{Duration, SystemTime}, }; @@ -12,7 +11,7 @@ use ruma::{ OutgoingResponse, }, serde::Raw, - MilliSecondsSinceUnixEpoch, + MilliSecondsSinceUnixEpoch, Signatures, }; /// # `GET /_matrix/key/v2/server` @@ -42,7 +41,7 @@ pub(crate) async fn get_server_keys_route(State(services): State) old_verify_keys, server_name: server_name.to_owned(), valid_until_ts: valid_until_ts(), - signatures: BTreeMap::new(), + signatures: Signatures::new(), }; let server_key = Raw::new(&server_key)?; diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index b9183e128..1f8c56dfa 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -12,8 +12,8 @@ use ruma::{ encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, events::{ignored_user_list::IgnoredUserListEvent, 
AnyToDeviceEvent, GlobalAccountDataEventType}, serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId, - OwnedMxcUri, OwnedUserId, RoomId, UInt, UserId, + DeviceId, KeyId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OneTimeKeyId, OneTimeKeyName, OwnedDeviceId, + OwnedKeyId, OwnedMxcUri, OwnedUserId, RoomId, UInt, UserId, }; use serde_json::json; @@ -341,9 +341,9 @@ impl Service { } pub async fn add_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, + &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &KeyId, one_time_key_value: &Raw, - ) -> Result<()> { + ) -> Result { // All devices have metadata // Only existing devices should be able to call this, but we shouldn't assert // either... @@ -388,8 +388,8 @@ impl Service { } pub async fn take_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, - ) -> Result<(OwnedDeviceKeyId, Raw)> { + &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &OneTimeKeyAlgorithm, + ) -> Result<(OwnedKeyId, Raw)> { let count = self.services.globals.next_count()?.to_be_bytes(); self.db.userid_lastonetimekeyupdate.insert(user_id, count); @@ -433,23 +433,23 @@ impl Service { pub async fn count_one_time_keys( &self, user_id: &UserId, device_id: &DeviceId, - ) -> BTreeMap { + ) -> BTreeMap { type KeyVal<'a> = ((Ignore, Ignore, &'a Unquoted), Ignore); - let mut algorithm_counts = BTreeMap::::new(); + let mut algorithm_counts = BTreeMap::::new(); let query = (user_id, device_id); self.db .onetimekeyid_onetimekeys .stream_prefix(&query) .ignore_err() .ready_for_each(|((Ignore, Ignore, device_key_id), Ignore): KeyVal<'_>| { - let device_key_id: &DeviceKeyId = device_key_id + let one_time_key_id: &OneTimeKeyId = device_key_id .as_str() .try_into() .expect("Invalid DeviceKeyID in database"); let count: &mut UInt = algorithm_counts - .entry(device_key_id.algorithm()) + 
.entry(one_time_key_id.algorithm()) .or_default(); *count = count.saturating_add(1_u32.into()); From 5e74391c6c94e2843f6cf18aaf0b10e2a613690c Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 10 Nov 2024 02:29:45 +0000 Subject: [PATCH 193/245] fix config generator macro matchers Signed-off-by: Jason Volk --- src/macros/config.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/macros/config.rs b/src/macros/config.rs index 6ccdb73cd..d7f115359 100644 --- a/src/macros/config.rs +++ b/src/macros/config.rs @@ -168,7 +168,7 @@ fn get_default(field: &Field) -> Option { .segments .iter() .next() - .is_none_or(|s| s.ident == "serde") + .is_none_or(|s| s.ident != "serde") { continue; } @@ -218,7 +218,7 @@ fn get_doc_default(field: &Field) -> Option { continue; }; - if path.segments.iter().next().is_none_or(|s| s.ident == "doc") { + if path.segments.iter().next().is_none_or(|s| s.ident != "doc") { continue; } @@ -261,7 +261,7 @@ fn get_doc_comment(field: &Field) -> Option { continue; }; - if path.segments.iter().next().is_none_or(|s| s.ident == "doc") { + if path.segments.iter().next().is_none_or(|s| s.ident != "doc") { continue; } From 7e087bb93c316fb52ab1b0dad77530eaa6608dfa Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 10 Nov 2024 03:25:57 +0000 Subject: [PATCH 194/245] Fixes for CI Signed-off-by: Jason Volk --- .github/workflows/ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4d2995145..f59c50485 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -65,7 +65,7 @@ permissions: jobs: tests: name: Test - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - name: Free Disk Space (Ubuntu) uses: jlumbroso/free-disk-space@main @@ -231,7 +231,7 @@ jobs: build: name: Build - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 needs: tests strategy: matrix: @@ -245,7 +245,7 @@ jobs: - name: Sync repository uses: actions/checkout@v4 - - 
uses: nixbuild/nix-quick-install-action@v28 + - uses: nixbuild/nix-quick-install-action@master - name: Restore and cache Nix store uses: nix-community/cache-nix-action@v5.1.0 @@ -508,7 +508,7 @@ jobs: docker: name: Docker publish - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 needs: build if: (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main' || (github.event.pull_request.draft != true)) && (vars.DOCKER_USERNAME != '') && (vars.GITLAB_USERNAME != '') && github.event.pull_request.user.login != 'renovate[bot]' env: From f290d1a9c850008ce932680c91ac5a039d23c9f7 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 10 Nov 2024 08:39:30 +0000 Subject: [PATCH 195/245] prevent retry for missing keys later in join process Signed-off-by: Jason Volk --- src/api/client/membership.rs | 79 ++++++++++++++++++------------- src/service/server_keys/mod.rs | 28 +++++++++-- src/service/server_keys/verify.rs | 20 ++++++++ 3 files changed, 91 insertions(+), 36 deletions(-) diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 2906d35bf..97aa1c691 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -878,46 +878,59 @@ async fn join_room_by_id_helper_remote( info!("Going through send_join response room_state"); let cork = services.db.cork_and_flush(); - let mut state = HashMap::new(); - for result in send_join_response.room_state.state.iter().map(|pdu| { - services - .server_keys - .validate_and_add_event_id(pdu, &room_version_id) - }) { - let Ok((event_id, value)) = result.await else { - continue; - }; + let state = send_join_response + .room_state + .state + .iter() + .stream() + .then(|pdu| { + services + .server_keys + .validate_and_add_event_id_no_fetch(pdu, &room_version_id) + }) + .ready_filter_map(Result::ok) + .fold(HashMap::new(), |mut state, (event_id, value)| async move { + let pdu = match PduEvent::from_id_val(&event_id, value.clone()) { + Ok(pdu) => pdu, + Err(e) => { + debug_warn!("Invalid PDU in 
send_join response: {e:?}: {value:#?}"); + return state; + }, + }; - let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| { - debug_warn!("Invalid PDU in send_join response: {value:#?}"); - err!(BadServerResponse("Invalid PDU in send_join response: {e:?}")) - })?; + services.rooms.outlier.add_pdu_outlier(&event_id, &value); + if let Some(state_key) = &pdu.state_key { + let shortstatekey = services + .rooms + .short + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key) + .await; + + state.insert(shortstatekey, pdu.event_id.clone()); + } + + state + }) + .await; - services.rooms.outlier.add_pdu_outlier(&event_id, &value); - if let Some(state_key) = &pdu.state_key { - let shortstatekey = services - .rooms - .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key) - .await; - state.insert(shortstatekey, pdu.event_id.clone()); - } - } drop(cork); info!("Going through send_join response auth_chain"); let cork = services.db.cork_and_flush(); - for result in send_join_response.room_state.auth_chain.iter().map(|pdu| { - services - .server_keys - .validate_and_add_event_id(pdu, &room_version_id) - }) { - let Ok((event_id, value)) = result.await else { - continue; - }; + send_join_response + .room_state + .auth_chain + .iter() + .stream() + .then(|pdu| { + services + .server_keys + .validate_and_add_event_id_no_fetch(pdu, &room_version_id) + }) + .ready_filter_map(Result::ok) + .ready_for_each(|(event_id, value)| services.rooms.outlier.add_pdu_outlier(&event_id, &value)) + .await; - services.rooms.outlier.add_pdu_outlier(&event_id, &value); - } drop(cork); debug!("Running send_join auth check"); diff --git a/src/service/server_keys/mod.rs b/src/service/server_keys/mod.rs index 333970df3..08bcefb63 100644 --- a/src/service/server_keys/mod.rs +++ b/src/service/server_keys/mod.rs @@ -7,13 +7,19 @@ mod verify; use std::{collections::BTreeMap, sync::Arc, time::Duration}; -use conduit::{implement, utils::timepoint_from_now, 
Result, Server}; +use conduit::{ + implement, + utils::{timepoint_from_now, IterStream}, + Result, Server, +}; use database::{Deserialized, Json, Map}; +use futures::StreamExt; use ruma::{ api::federation::discovery::{ServerSigningKeys, VerifyKey}, serde::Raw, signatures::{Ed25519KeyPair, PublicKeyMap, PublicKeySet}, - MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, ServerSigningKeyId, + CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, RoomVersionId, ServerName, + ServerSigningKeyId, }; use serde_json::value::RawValue as RawJsonValue; @@ -107,7 +113,23 @@ async fn add_signing_keys(&self, new_keys: ServerSigningKeys) { } #[implement(Service)] -async fn verify_key_exists(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> bool { +pub async fn required_keys_exist(&self, object: &CanonicalJsonObject, version: &RoomVersionId) -> bool { + use ruma::signatures::required_keys; + + let Ok(required_keys) = required_keys(object, version) else { + return false; + }; + + required_keys + .iter() + .flat_map(|(server, key_ids)| key_ids.iter().map(move |key_id| (server, key_id))) + .stream() + .all(|(server, key_id)| self.verify_key_exists(server, key_id)) + .await +} + +#[implement(Service)] +pub async fn verify_key_exists(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> bool { type KeysMap<'a> = BTreeMap<&'a ServerSigningKeyId, &'a RawJsonValue>; let Ok(keys) = self diff --git a/src/service/server_keys/verify.rs b/src/service/server_keys/verify.rs index ad20fec7f..c836e324a 100644 --- a/src/service/server_keys/verify.rs +++ b/src/service/server_keys/verify.rs @@ -16,6 +16,26 @@ pub async fn validate_and_add_event_id( Ok((event_id, value)) } +#[implement(super::Service)] +pub async fn validate_and_add_event_id_no_fetch( + &self, pdu: &RawJsonValue, room_version: &RoomVersionId, +) -> Result<(OwnedEventId, CanonicalJsonObject)> { + let (event_id, mut value) = gen_event_id_canonical_json(pdu, room_version)?; + if 
!self.required_keys_exist(&value, room_version).await { + return Err!(BadServerResponse(debug_warn!( + "Event {event_id} cannot be verified: missing keys." + ))); + } + + if let Err(e) = self.verify_event(&value, Some(room_version)).await { + return Err!(BadServerResponse(debug_error!("Event {event_id} failed verification: {e:?}"))); + } + + value.insert("event_id".into(), CanonicalJsonValue::String(event_id.as_str().into())); + + Ok((event_id, value)) +} + #[implement(super::Service)] pub async fn verify_event( &self, event: &CanonicalJsonObject, room_version: Option<&RoomVersionId>, From 1efc52c4401f3237124495c7120746a8f7aa4909 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 10 Nov 2024 11:09:48 +0000 Subject: [PATCH 196/245] increase logging during server keys acquire Signed-off-by: Jason Volk --- src/service/server_keys/acquire.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/service/server_keys/acquire.rs b/src/service/server_keys/acquire.rs index 190b42392..1080d79ef 100644 --- a/src/service/server_keys/acquire.rs +++ b/src/service/server_keys/acquire.rs @@ -4,7 +4,7 @@ use std::{ time::Duration, }; -use conduit::{debug, debug_error, debug_warn, error, implement, result::FlatOk, trace, warn}; +use conduit::{debug, debug_error, debug_warn, error, implement, info, result::FlatOk, trace, warn}; use futures::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::federation::discovery::ServerSigningKeys, serde::Raw, CanonicalJsonObject, OwnedServerName, @@ -69,7 +69,7 @@ where return; } - debug!("missing {missing_keys} keys for {missing_servers} servers locally"); + info!("{missing_keys} keys for {missing_servers} servers will be acquired"); if notary_first_always || notary_first_on_join { missing = self.acquire_notary(missing.into_iter()).await; @@ -79,7 +79,7 @@ where return; } - debug_warn!("missing {missing_keys} keys for {missing_servers} servers from all notaries first"); + warn!("missing {missing_keys} keys for 
{missing_servers} servers from all notaries first"); } if !notary_only { @@ -107,7 +107,7 @@ where if missing_keys > 0 { warn!( "did not obtain {missing_keys} keys for {missing_servers} servers out of {requested_keys} total keys for \ - {requested_servers} total servers; some events may not be verifiable" + {requested_servers} total servers." ); } From 24a5ecb6b4dedf39184e9c38282ab94db1b12d5d Mon Sep 17 00:00:00 2001 From: OverPhoenix Date: Sun, 10 Nov 2024 21:45:37 +0000 Subject: [PATCH 197/245] fix incorrect user id for non-admin invites checking --- src/api/client/membership.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 97aa1c691..bde8dee85 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -1306,7 +1306,7 @@ pub(crate) async fn invite_helper( services: &Services, sender_user: &UserId, user_id: &UserId, room_id: &RoomId, reason: Option, is_direct: bool, ) -> Result<()> { - if !services.users.is_admin(user_id).await && services.globals.block_non_admin_invites() { + if !services.users.is_admin(sender_user).await && services.globals.block_non_admin_invites() { info!("User {sender_user} is not an admin and attempted to send an invite to room {room_id}"); return Err(Error::BadRequest( ErrorKind::forbidden(), From 08a4e931a0d5353edc01716e371a489f2c14dba3 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 11 Nov 2024 20:12:20 +0000 Subject: [PATCH 198/245] supplement a from_str for FmtSpan Signed-off-by: Jason Volk --- src/core/log/fmt_span.rs | 17 +++++++++++++++++ src/core/log/mod.rs | 1 + 2 files changed, 18 insertions(+) create mode 100644 src/core/log/fmt_span.rs diff --git a/src/core/log/fmt_span.rs b/src/core/log/fmt_span.rs new file mode 100644 index 000000000..5a340d0fa --- /dev/null +++ b/src/core/log/fmt_span.rs @@ -0,0 +1,17 @@ +use tracing_subscriber::fmt::format::FmtSpan; + +use crate::Result; + +#[inline] +pub fn from_str(str: &str) -> 
Result { + match str.to_uppercase().as_str() { + "ENTER" => Ok(FmtSpan::ENTER), + "EXIT" => Ok(FmtSpan::EXIT), + "NEW" => Ok(FmtSpan::NEW), + "CLOSE" => Ok(FmtSpan::CLOSE), + "ACTIVE" => Ok(FmtSpan::ACTIVE), + "FULL" => Ok(FmtSpan::FULL), + "NONE" => Ok(FmtSpan::NONE), + _ => Err(FmtSpan::NONE), + } +} diff --git a/src/core/log/mod.rs b/src/core/log/mod.rs index 1cba236f0..1c415c6a0 100644 --- a/src/core/log/mod.rs +++ b/src/core/log/mod.rs @@ -1,6 +1,7 @@ pub mod capture; pub mod color; pub mod fmt; +pub mod fmt_span; mod reload; mod suppress; From 9790a6edc992d24490e19161394c3041e137331d Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 11 Nov 2024 20:33:56 +0000 Subject: [PATCH 199/245] add unwrap_or_err to result Signed-off-by: Jason Volk --- src/core/utils/result.rs | 2 ++ src/core/utils/result/unwrap_or_err.rs | 15 +++++++++++++++ 2 files changed, 17 insertions(+) create mode 100644 src/core/utils/result/unwrap_or_err.rs diff --git a/src/core/utils/result.rs b/src/core/utils/result.rs index fb1b7b959..6b11ea66f 100644 --- a/src/core/utils/result.rs +++ b/src/core/utils/result.rs @@ -7,10 +7,12 @@ mod log_err; mod map_expect; mod not_found; mod unwrap_infallible; +mod unwrap_or_err; pub use self::{ debug_inspect::DebugInspect, filter::Filter, flat_ok::FlatOk, into_is_ok::IntoIsOk, log_debug_err::LogDebugErr, log_err::LogErr, map_expect::MapExpect, not_found::NotFound, unwrap_infallible::UnwrapInfallible, + unwrap_or_err::UnwrapOrErr, }; pub type Result = std::result::Result; diff --git a/src/core/utils/result/unwrap_or_err.rs b/src/core/utils/result/unwrap_or_err.rs new file mode 100644 index 000000000..69901958f --- /dev/null +++ b/src/core/utils/result/unwrap_or_err.rs @@ -0,0 +1,15 @@ +use std::convert::identity; + +use super::Result; + +/// Returns the Ok value or the Err value. Available when the Ok and Err types +/// are the same. This is a way to default the result using the specific Err +/// value rather than unwrap_or_default() using Ok's default. 
+pub trait UnwrapOrErr { + fn unwrap_or_err(self) -> T; +} + +impl UnwrapOrErr for Result { + #[inline] + fn unwrap_or_err(self) -> T { self.unwrap_or_else(identity::) } +} From e2afaa9f039d26b85bcd518013aa6bb80ce11866 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 11 Nov 2024 20:49:25 +0000 Subject: [PATCH 200/245] add config item for with_span_events Signed-off-by: Jason Volk --- src/core/config/mod.rs | 9 +++++++++ src/main/tracing.rs | 8 ++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index cd9c1b38a..eddab2fe7 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -606,6 +606,12 @@ pub struct Config { #[serde(default = "true_fn", alias = "log_colours")] pub log_colors: bool, + /// configures the span events which will be outputted with the log + /// + /// default: "none" + #[serde(default = "default_log_span_events")] + pub log_span_events: String, + /// OpenID token expiration/TTL in seconds /// /// These are the OpenID tokens that are primarily used for Matrix account @@ -1958,6 +1964,9 @@ pub fn default_log() -> String { .to_owned() } +#[must_use] +pub fn default_log_span_events() -> String { "none".into() } + fn default_notification_push_path() -> String { "/_matrix/push/v1/notify".to_owned() } fn default_openid_token_ttl() -> u64 { 60 * 60 } diff --git a/src/main/tracing.rs b/src/main/tracing.rs index 9b4ad659d..c28fef6b8 100644 --- a/src/main/tracing.rs +++ b/src/main/tracing.rs @@ -3,7 +3,8 @@ use std::sync::Arc; use conduit::{ config::Config, debug_warn, err, - log::{capture, LogLevelReloadHandles}, + log::{capture, fmt_span, LogLevelReloadHandles}, + result::UnwrapOrErr, Result, }; use tracing_subscriber::{layer::SubscriberExt, reload, EnvFilter, Layer, Registry}; @@ -18,7 +19,10 @@ pub(crate) fn init(config: &Config) -> Result<(LogLevelReloadHandles, TracingFla let reload_handles = LogLevelReloadHandles::default(); let console_filter = 
EnvFilter::try_new(&config.log).map_err(|e| err!(Config("log", "{e}.")))?; - let console_layer = tracing_subscriber::fmt::Layer::new().with_ansi(config.log_colors); + let console_span_events = fmt_span::from_str(&config.log_span_events).unwrap_or_err(); + let console_layer = tracing_subscriber::fmt::Layer::new() + .with_ansi(config.log_colors) + .with_span_events(console_span_events); let (console_reload_filter, console_reload_handle) = reload::Layer::new(console_filter.clone()); reload_handles.add("console", Box::new(console_reload_handle)); From 61174dd0d3632f551735bea9c8ea22c0bf218427 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 11 Nov 2024 21:27:40 +0000 Subject: [PATCH 201/245] check if lazyset already contains user prior to querying Signed-off-by: Jason Volk --- src/api/client/message.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/api/client/message.rs b/src/api/client/message.rs index e8306de9f..cc6365113 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -192,6 +192,10 @@ pub(crate) async fn update_lazy( return lazy; } + if lazy.contains(event.sender()) { + return lazy; + } + if !services .rooms .lazy_loading From 396233304328c75d1271465f28f55e4121e956b4 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 11 Nov 2024 05:00:29 +0000 Subject: [PATCH 202/245] partially revert e507c3130673099692143a59adc30a414ef6ca54 Signed-off-by: Jason Volk --- src/api/client/context.rs | 6 ++---- src/api/client/message.rs | 5 +---- src/api/client/relations.rs | 1 - src/api/client/threads.rs | 1 - src/api/server/backfill.rs | 2 +- src/service/rooms/pdu_metadata/data.rs | 2 +- src/service/rooms/threads/mod.rs | 2 +- src/service/rooms/timeline/data.rs | 12 +++++++----- 8 files changed, 13 insertions(+), 18 deletions(-) diff --git a/src/api/client/context.rs b/src/api/client/context.rs index f5f981ba0..4359ae121 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -82,7 +82,7 @@ pub(crate) async fn 
get_context_route( let events_before: Vec<_> = services .rooms .timeline - .pdus_rev(Some(sender_user), room_id, Some(base_token.saturating_sub(1))) + .pdus_rev(Some(sender_user), room_id, Some(base_token)) .await? .ready_filter_map(|item| event_filter(item, filter)) .filter_map(|item| ignored_filter(&services, item, sender_user)) @@ -94,7 +94,7 @@ pub(crate) async fn get_context_route( let events_after: Vec<_> = services .rooms .timeline - .pdus(Some(sender_user), room_id, Some(base_token.saturating_add(1))) + .pdus(Some(sender_user), room_id, Some(base_token)) .await? .ready_filter_map(|item| event_filter(item, filter)) .filter_map(|item| ignored_filter(&services, item, sender_user)) @@ -169,14 +169,12 @@ pub(crate) async fn get_context_route( start: events_before .last() .map(at!(0)) - .map(|count| count.saturating_sub(1)) .as_ref() .map(ToString::to_string), end: events_after .last() .map(at!(0)) - .map(|count| count.saturating_add(1)) .as_ref() .map(ToString::to_string), diff --git a/src/api/client/message.rs b/src/api/client/message.rs index cc6365113..88453de0c 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -138,10 +138,7 @@ pub(crate) async fn get_message_events_route( let start_token = events.first().map(at!(0)).unwrap_or(from); - let next_token = events - .last() - .map(at!(0)) - .map(|count| count.saturating_inc(body.dir)); + let next_token = events.last().map(at!(0)); if !cfg!(feature = "element_hacks") { if let Some(next_token) = next_token { diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index ee62dbfc9..902e6be60 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -150,7 +150,6 @@ async fn paginate_relations_with_filter( Direction::Backward => events.first(), } .map(at!(0)) - .map(|count| count.saturating_inc(dir)) .as_ref() .map(ToString::to_string); diff --git a/src/api/client/threads.rs b/src/api/client/threads.rs index 8d4e399bb..906f779da 100644 --- 
a/src/api/client/threads.rs +++ b/src/api/client/threads.rs @@ -46,7 +46,6 @@ pub(crate) async fn get_threads_route( .last() .filter(|_| threads.len() >= limit) .map(at!(0)) - .map(|count| count.saturating_sub(1)) .as_ref() .map(ToString::to_string), diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index 2858d9fda..b0bd48e80 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -55,7 +55,7 @@ pub(crate) async fn get_backfill_route( pdus: services .rooms .timeline - .pdus_rev(None, &body.room_id, Some(from)) + .pdus_rev(None, &body.room_id, Some(from.saturating_add(1))) .await? .take(limit) .filter_map(|(_, pdu)| async move { diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index f3e1ced8b..b06e988e8 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -57,7 +57,7 @@ impl Data { ) -> impl Stream + Send + '_ { let mut current = ArrayVec::::new(); current.extend(target.to_be_bytes()); - current.extend(from.into_unsigned().to_be_bytes()); + current.extend(from.saturating_inc(dir).into_unsigned().to_be_bytes()); let current = current.as_slice(); match dir { Direction::Forward => self.tofrom_relation.raw_keys_from(current).boxed(), diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index fcc629e1c..5821f2795 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -132,7 +132,7 @@ impl Service { let current: RawPduId = PduId { shortroomid, - shorteventid, + shorteventid: shorteventid.saturating_sub(1), } .into(); diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 7f1873ab0..22a6c1d0d 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -13,7 +13,7 @@ use conduit::{ }; use database::{Database, Deserialized, Json, KeyVal, Map}; use futures::{Stream, StreamExt}; -use ruma::{CanonicalJsonObject, EventId, 
OwnedRoomId, OwnedUserId, RoomId, UserId}; +use ruma::{api::Direction, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; use tokio::sync::Mutex; use super::{PduId, RawPduId}; @@ -205,7 +205,9 @@ impl Data { pub(super) async fn pdus_rev<'a>( &'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, until: PduCount, ) -> Result + Send + 'a> { - let current = self.count_to_id(room_id, until).await?; + let current = self + .count_to_id(room_id, until, Direction::Backward) + .await?; let prefix = current.shortroomid(); let stream = self .pduid_pdu @@ -220,7 +222,7 @@ impl Data { pub(super) async fn pdus<'a>( &'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, from: PduCount, ) -> Result + Send + 'a> { - let current = self.count_to_id(room_id, from).await?; + let current = self.count_to_id(room_id, from, Direction::Forward).await?; let prefix = current.shortroomid(); let stream = self .pduid_pdu @@ -267,7 +269,7 @@ impl Data { } } - async fn count_to_id(&self, room_id: &RoomId, shorteventid: PduCount) -> Result { + async fn count_to_id(&self, room_id: &RoomId, shorteventid: PduCount, dir: Direction) -> Result { let shortroomid: ShortRoomId = self .services .short @@ -278,7 +280,7 @@ impl Data { // +1 so we don't send the base event let pdu_id = PduId { shortroomid, - shorteventid, + shorteventid: shorteventid.saturating_inc(dir), }; Ok(pdu_id.into()) From 999d731a65fe8f1313d6fe63d5139ee9f357a820 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 11 Nov 2024 22:18:14 +0000 Subject: [PATCH 203/245] move err macro visitor out-of-line; reduce codegen Signed-off-by: Jason Volk --- src/core/error/err.rs | 72 +++++++++++++++++++++++-------------------- src/core/error/mod.rs | 3 +- 2 files changed, 40 insertions(+), 35 deletions(-) diff --git a/src/core/error/err.rs b/src/core/error/err.rs index baeb992d2..a24441e00 100644 --- a/src/core/error/err.rs +++ b/src/core/error/err.rs @@ -111,12 +111,8 @@ macro_rules! 
err { #[macro_export] macro_rules! err_log { ($out:ident, $level:ident, $($fields:tt)+) => {{ - use std::{fmt, fmt::Write}; - use $crate::tracing::{ - callsite, callsite2, level_enabled, metadata, valueset, Callsite, Event, __macro_support, - __tracing_log, - field::{Field, ValueSet, Visit}, + callsite, callsite2, metadata, valueset, Callsite, Level, }; @@ -134,34 +130,7 @@ macro_rules! err_log { fields: $($fields)+, }; - let visit = &mut |vs: ValueSet<'_>| { - struct Visitor<'a>(&'a mut String); - impl Visit for Visitor<'_> { - #[inline] - fn record_debug(&mut self, field: &Field, val: &dyn fmt::Debug) { - if field.name() == "message" { - write!(self.0, "{:?}", val).expect("stream error"); - } else { - write!(self.0, " {}={:?}", field.name(), val).expect("stream error"); - } - } - } - - let meta = __CALLSITE.metadata(); - let enabled = level_enabled!(LEVEL) && { - let interest = __CALLSITE.interest(); - !interest.is_never() && __macro_support::__is_enabled(meta, interest) - }; - - if enabled { - Event::dispatch(meta, &vs); - } - - __tracing_log!(LEVEL, __CALLSITE, &vs); - vs.record(&mut Visitor(&mut $out)); - }; - - (visit)(valueset!(__CALLSITE.metadata().fields(), $($fields)+)); + ($crate::error::visit)(&mut $out, LEVEL, &__CALLSITE, &mut valueset!(__CALLSITE.metadata().fields(), $($fields)+)); ($out).into() }} } @@ -192,3 +161,40 @@ macro_rules! 
err_lev { $crate::tracing::Level::ERROR }; } + +use std::{fmt, fmt::Write}; + +use tracing::{ + level_enabled, Callsite, Event, __macro_support, __tracing_log, + callsite::DefaultCallsite, + field::{Field, ValueSet, Visit}, + Level, +}; + +struct Visitor<'a>(&'a mut String); + +impl Visit for Visitor<'_> { + #[inline] + fn record_debug(&mut self, field: &Field, val: &dyn fmt::Debug) { + if field.name() == "message" { + write!(self.0, "{val:?}").expect("stream error"); + } else { + write!(self.0, " {}={val:?}", field.name()).expect("stream error"); + } + } +} + +pub fn visit(out: &mut String, level: Level, __callsite: &'static DefaultCallsite, vs: &mut ValueSet<'_>) { + let meta = __callsite.metadata(); + let enabled = level_enabled!(level) && { + let interest = __callsite.interest(); + !interest.is_never() && __macro_support::__is_enabled(meta, interest) + }; + + if enabled { + Event::dispatch(meta, vs); + } + + __tracing_log!(level, __callsite, vs); + vs.record(&mut Visitor(out)); +} diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs index 302d0f87e..35bf98009 100644 --- a/src/core/error/mod.rs +++ b/src/core/error/mod.rs @@ -6,8 +6,7 @@ mod serde; use std::{any::Any, borrow::Cow, convert::Infallible, fmt, sync::PoisonError}; -pub use self::log::*; -use crate::error; +pub use self::{err::visit, log::*}; #[derive(thiserror::Error)] pub enum Error { From 86694f2d1d55605af2058b5347c71ebf977c5daf Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 12 Nov 2024 08:01:23 +0000 Subject: [PATCH 204/245] move non-generic code out of generic; reduce codegen Signed-off-by: Jason Volk --- src/api/router/args.rs | 77 +++++++++++++++--------------- src/service/sending/send.rs | 93 ++++++++++++++++++++----------------- 2 files changed, 90 insertions(+), 80 deletions(-) diff --git a/src/api/router/args.rs b/src/api/router/args.rs index 4c0aff4c6..0b6939569 100644 --- a/src/api/router/args.rs +++ b/src/api/router/args.rs @@ -66,6 +66,15 @@ where } } +impl Deref for Args 
+where + T: IncomingRequest + Send + Sync + 'static, +{ + type Target = T; + + fn deref(&self) -> &Self::Target { &self.body } +} + #[async_trait] impl FromRequest for Args where @@ -78,7 +87,7 @@ where let mut json_body = serde_json::from_slice::(&request.body).ok(); let auth = auth::auth(services, &mut request, json_body.as_ref(), &T::METADATA).await?; Ok(Self { - body: make_body::(services, &mut request, &mut json_body, &auth)?, + body: make_body::(services, &mut request, json_body.as_mut(), &auth)?, origin: auth.origin, sender_user: auth.sender_user, sender_device: auth.sender_device, @@ -88,20 +97,11 @@ where } } -impl Deref for Args -where - T: IncomingRequest + Send + Sync + 'static, -{ - type Target = T; - - fn deref(&self) -> &Self::Target { &self.body } -} - fn make_body( - services: &Services, request: &mut Request, json_body: &mut Option, auth: &Auth, + services: &Services, request: &mut Request, json_body: Option<&mut CanonicalJsonValue>, auth: &Auth, ) -> Result where - T: IncomingRequest + Send + Sync + 'static, + T: IncomingRequest, { let body = take_body(services, request, json_body, auth); let http_request = into_http_request(request, body); @@ -125,36 +125,37 @@ fn into_http_request(request: &Request, body: Bytes) -> hyper::Request { http_request } +#[allow(clippy::needless_pass_by_value)] fn take_body( - services: &Services, request: &mut Request, json_body: &mut Option, auth: &Auth, + services: &Services, request: &mut Request, json_body: Option<&mut CanonicalJsonValue>, auth: &Auth, ) -> Bytes { - if let Some(CanonicalJsonValue::Object(json_body)) = json_body { - let user_id = auth.sender_user.clone().unwrap_or_else(|| { - let server_name = services.globals.server_name(); - UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id") + let Some(CanonicalJsonValue::Object(json_body)) = json_body else { + return mem::take(&mut request.body); + }; + + let user_id = auth.sender_user.clone().unwrap_or_else(|| { + let server_name = 
services.globals.server_name(); + UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id") + }); + + let uiaa_request = json_body + .get("auth") + .and_then(CanonicalJsonValue::as_object) + .and_then(|auth| auth.get("session")) + .and_then(CanonicalJsonValue::as_str) + .and_then(|session| { + services + .uiaa + .get_uiaa_request(&user_id, auth.sender_device.as_deref(), session) }); - let uiaa_request = json_body - .get("auth") - .and_then(CanonicalJsonValue::as_object) - .and_then(|auth| auth.get("session")) - .and_then(CanonicalJsonValue::as_str) - .and_then(|session| { - services - .uiaa - .get_uiaa_request(&user_id, auth.sender_device.as_deref(), session) - }); - - if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { - for (key, value) in initial_request { - json_body.entry(key).or_insert(value); - } + if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { + for (key, value) in initial_request { + json_body.entry(key).or_insert(value); } - - let mut buf = BytesMut::new().writer(); - serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail"); - buf.into_inner().freeze() - } else { - mem::take(&mut request.body) } + + let mut buf = BytesMut::new().writer(); + serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail"); + buf.into_inner().freeze() } diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 939d6e73d..5bf48aaab 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -1,5 +1,6 @@ -use std::{fmt::Debug, mem}; +use std::mem; +use bytes::Bytes; use conduit::{ debug, debug_error, debug_warn, err, error::inspect_debug_log, implement, trace, utils::string::EMPTY, Err, Error, Result, @@ -23,10 +24,10 @@ use crate::{ }; impl super::Service { - #[tracing::instrument(skip(self, client, req), name = "send")] - pub async fn send(&self, client: &Client, dest: &ServerName, req: T) -> Result + 
#[tracing::instrument(skip(self, client, request), name = "send")] + pub async fn send(&self, client: &Client, dest: &ServerName, request: T) -> Result where - T: OutgoingRequest + Debug + Send, + T: OutgoingRequest + Send, { if !self.server.config.allow_federation { return Err!(Config("allow_federation", "Federation is disabled.")); @@ -42,7 +43,8 @@ impl super::Service { } let actual = self.services.resolver.get_actual_dest(dest).await?; - let request = self.prepare::(dest, &actual, req).await?; + let request = into_http_request::(&actual, request)?; + let request = self.prepare(dest, request)?; self.execute::(dest, &actual, request, client).await } @@ -50,7 +52,7 @@ impl super::Service { &self, dest: &ServerName, actual: &ActualDest, request: Request, client: &Client, ) -> Result where - T: OutgoingRequest + Debug + Send, + T: OutgoingRequest + Send, { let url = request.url().clone(); let method = request.method().clone(); @@ -58,25 +60,14 @@ impl super::Service { debug!(?method, ?url, "Sending request"); match client.execute(request).await { Ok(response) => handle_response::(&self.services.resolver, dest, actual, &method, &url, response).await, - Err(error) => handle_error::(dest, actual, &method, &url, error), + Err(error) => Err(handle_error(actual, &method, &url, error).expect_err("always returns error")), } } - async fn prepare(&self, dest: &ServerName, actual: &ActualDest, req: T) -> Result - where - T: OutgoingRequest + Debug + Send, - { - const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_11]; - const SATIR: SendAccessToken<'_> = SendAccessToken::IfRequired(EMPTY); - - trace!("Preparing request"); - let mut http_request = req - .try_into_http_request::>(actual.string().as_str(), SATIR, &VERSIONS) - .map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))?; + fn prepare(&self, dest: &ServerName, mut request: http::Request>) -> Result { + self.sign_request(&mut request, dest); - self.sign_request(&mut http_request, dest); - - let request = 
Request::try_from(http_request)?; + let request = Request::try_from(request)?; self.validate_url(request.url())?; Ok(request) @@ -96,11 +87,31 @@ impl super::Service { async fn handle_response( resolver: &resolver::Service, dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, - mut response: Response, + response: Response, ) -> Result where - T: OutgoingRequest + Debug + Send, + T: OutgoingRequest + Send, { + let response = into_http_response(dest, actual, method, url, response).await?; + let result = T::IncomingResponse::try_from_http_response(response); + + if result.is_ok() && !actual.cached { + resolver.set_cached_destination( + dest.to_owned(), + CachedDest { + dest: actual.dest.clone(), + host: actual.host.clone(), + expire: CachedDest::default_expire(), + }, + ); + } + + result.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}"))) +} + +async fn into_http_response( + dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut response: Response, +) -> Result> { let status = response.status(); trace!( ?status, ?method, @@ -113,6 +124,7 @@ where let mut http_response_builder = http::Response::builder() .status(status) .version(response.version()); + mem::swap( response.headers_mut(), http_response_builder @@ -137,27 +149,10 @@ where return Err(Error::Federation(dest.to_owned(), RumaError::from_http_response(http_response))); } - let response = T::IncomingResponse::try_from_http_response(http_response); - if response.is_ok() && !actual.cached { - resolver.set_cached_destination( - dest.to_owned(), - CachedDest { - dest: actual.dest.clone(), - host: actual.host.clone(), - expire: CachedDest::default_expire(), - }, - ); - } - - response.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}"))) + Ok(http_response) } -fn handle_error( - _dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut e: reqwest::Error, -) -> Result -where - T: OutgoingRequest + Debug + Send, -{ +fn 
handle_error(actual: &ActualDest, method: &Method, url: &Url, mut e: reqwest::Error) -> Result { if e.is_timeout() || e.is_connect() { e = e.without_url(); debug_warn!("{e:?}"); @@ -246,3 +241,17 @@ fn sign_request(&self, http_request: &mut http::Request>, dest: &ServerN debug_assert!(authorization.is_none(), "Authorization header already present"); } + +fn into_http_request(actual: &ActualDest, request: T) -> Result>> +where + T: OutgoingRequest + Send, +{ + const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_11]; + const SATIR: SendAccessToken<'_> = SendAccessToken::IfRequired(EMPTY); + + let http_request = request + .try_into_http_request::>(actual.string().as_str(), SATIR, &VERSIONS) + .map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))?; + + Ok(http_request) +} From c59f474aff0dbd96d2096d6d163629a7ecf460b5 Mon Sep 17 00:00:00 2001 From: strawberry Date: Tue, 12 Nov 2024 05:01:11 +0000 Subject: [PATCH 205/245] fixes for gh workflow Signed-off-by: Jason Volk --- .github/workflows/ci.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f59c50485..2d253f695 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -76,11 +76,10 @@ jobs: # large docker images sudo docker image prune --all --force || true # large packages - sudo apt-get purge -y '^llvm-.*' 'php.*' '^mongodb-.*' '^mysql-.*' azure-cli google-cloud-cli google-chrome-stable firefox powershell microsoft-edge-stable || true - sudo apt-get autoremove -y + sudo apt-get purge -y 'php.*' '^mongodb-.*' '^mysql-.*' azure-cli google-cloud-cli google-chrome-stable firefox powershell microsoft-edge-stable || true sudo apt-get clean # large folders - sudo rm -rf /var/lib/apt/lists/* /usr/local/games /usr/local/sqlpackage /usr/local/.ghcup /usr/local/share/powershell /usr/local/share/edge_driver /usr/local/share/gecko_driver /usr/local/share/chromium /usr/local/share/chromedriver-linux64 
/usr/local/share/vcpkg /usr/local/lib/python* /usr/local/lib/node_modules /usr/local/julia* /opt/mssql-tools /etc/skel /usr/share/vim /usr/share/postgresql /usr/share/man /usr/share/apache-maven-* /usr/share/R /usr/share/alsa /usr/share/miniconda /usr/share/grub /usr/share/gradle-* /usr/share/locale /usr/share/texinfo /usr/share/kotlinc /usr/share/swift /usr/share/doc /usr/share/az_9.3.0 /usr/share/sbt /usr/share/ri /usr/share/icons /usr/share/java /usr/share/fonts /usr/lib/google-cloud-sdk /usr/lib/jvm /usr/lib/mono /usr/lib/R /usr/lib/postgresql /usr/lib/heroku /usr/lib/gcc + sudo rm -rf /var/lib/apt/lists/* /usr/local/games /usr/local/sqlpackage /usr/local/share/powershell /usr/local/share/edge_driver /usr/local/share/gecko_driver /usr/local/share/chromium /usr/local/share/chromedriver-linux64 /usr/local/share/vcpkg /usr/local/julia* /opt/mssql-tools /usr/share/vim /usr/share/postgresql /usr/share/apache-maven-* /usr/share/R /usr/share/alsa /usr/share/miniconda /usr/share/grub /usr/share/gradle-* /usr/share/locale /usr/share/texinfo /usr/share/kotlinc /usr/share/swift /usr/share/sbt /usr/share/ri /usr/share/icons /usr/share/java /usr/share/fonts /usr/lib/google-cloud-sdk /usr/lib/jvm /usr/lib/mono /usr/lib/R /usr/lib/postgresql /usr/lib/heroku set -o pipefail - name: Sync repository From feefa43e65e56f6d23fa96981128841fef609414 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 12 Nov 2024 22:01:32 +0000 Subject: [PATCH 206/245] add pretty/si-unit byte size parsing/printing utils Signed-off-by: Jason Volk --- Cargo.lock | 7 +++++++ Cargo.toml | 3 +++ src/core/Cargo.toml | 1 + src/core/utils/bytes.rs | 30 +++++++++++++++++++++++++++++- 4 files changed, 40 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index a1654ff96..515712644 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -458,6 +458,12 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" +[[package]] +name = "bytesize" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3e368af43e418a04d52505cf3dbc23dda4e3407ae2fa99fd0e4f308ce546acc" + [[package]] name = "bzip2-sys" version = "0.1.11+1.0.8" @@ -683,6 +689,7 @@ dependencies = [ "arrayvec", "axum", "bytes", + "bytesize", "cargo_toml", "checked_ops", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 5ea6b4e09..0173e7cf9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -466,6 +466,9 @@ version = "1.0.36" [workspace.dependencies.proc-macro2] version = "1.0.89" +[workspace.dependencies.bytesize] +version = "1.3.0" + # # Patches # diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 4fe413e93..b93f9a777 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -57,6 +57,7 @@ argon2.workspace = true arrayvec.workspace = true axum.workspace = true bytes.workspace = true +bytesize.workspace = true cargo_toml.workspace = true checked_ops.workspace = true chrono.workspace = true diff --git a/src/core/utils/bytes.rs b/src/core/utils/bytes.rs index e8975a491..441ba422a 100644 --- a/src/core/utils/bytes.rs +++ b/src/core/utils/bytes.rs @@ -1,4 +1,32 @@ -use crate::Result; +use bytesize::ByteSize; + +use crate::{err, Result}; + +/// Parse a human-writable size string w/ si-unit suffix into integer +#[inline] +pub fn from_str(str: &str) -> Result { + let bytes: ByteSize = str + .parse() + .map_err(|e| err!(Arithmetic("Failed to parse byte size: {e}")))?; + + let bytes: usize = bytes + .as_u64() + .try_into() + .map_err(|e| err!(Arithmetic("Failed to convert u64 to usize: {e}")))?; + + Ok(bytes) +} + +/// Output a human-readable size string w/ si-unit suffix +#[inline] +#[must_use] +pub fn pretty(bytes: usize) -> String { + const SI_UNITS: bool = true; + + let bytes: u64 = bytes.try_into().expect("failed to convert usize to u64"); + + bytesize::to_string(bytes, SI_UNITS) +} #[inline] 
#[must_use] From 68582dd868032944a794f4eb7bfa2e71d29891f5 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 13 Nov 2024 00:59:53 +0000 Subject: [PATCH 207/245] add parallel query for current membership state Signed-off-by: Jason Volk --- src/service/rooms/state_cache/mod.rs | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 11684eab4..6e330fdc1 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -10,7 +10,7 @@ use conduit::{ warn, Result, }; use database::{serialize_to_vec, Deserialized, Ignore, Interfix, Json, Map}; -use futures::{stream::iter, Stream, StreamExt}; +use futures::{future::join4, stream::iter, Stream, StreamExt}; use itertools::Itertools; use ruma::{ events::{ @@ -566,6 +566,24 @@ impl Service { self.db.userroomid_leftstate.qry(&key).await.is_ok() } + pub async fn user_membership(&self, user_id: &UserId, room_id: &RoomId) -> Option { + let states = join4( + self.is_joined(user_id, room_id), + self.is_left(user_id, room_id), + self.is_invited(user_id, room_id), + self.once_joined(user_id, room_id), + ) + .await; + + match states { + (true, ..) => Some(MembershipState::Join), + (_, true, ..) => Some(MembershipState::Leave), + (_, _, true, ..) 
=> Some(MembershipState::Invite), + (false, false, false, true) => Some(MembershipState::Ban), + _ => None, + } + } + #[tracing::instrument(skip(self), level = "debug")] pub fn servers_invite_via<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { type KeyVal<'a> = (Ignore, Vec<&'a ServerName>); From 77fab2c323b65d7d97e78dcbee946e7860cf3d1d Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 13 Nov 2024 01:01:00 +0000 Subject: [PATCH 208/245] use ruma visibility enum in directory interface Signed-off-by: Jason Volk --- src/service/rooms/directory/mod.rs | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index f366ffe2d..63ed3519f 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use conduit::{implement, utils::stream::TryIgnore, Result}; use database::Map; use futures::Stream; -use ruma::RoomId; +use ruma::{api::client::room::Visibility, RoomId}; pub struct Service { db: Data, @@ -32,7 +32,16 @@ pub fn set_public(&self, room_id: &RoomId) { self.db.publicroomids.insert(room_i pub fn set_not_public(&self, room_id: &RoomId) { self.db.publicroomids.remove(room_id); } #[implement(Service)] -pub async fn is_public_room(&self, room_id: &RoomId) -> bool { self.db.publicroomids.get(room_id).await.is_ok() } +pub fn public_rooms(&self) -> impl Stream + Send { self.db.publicroomids.keys().ignore_err() } #[implement(Service)] -pub fn public_rooms(&self) -> impl Stream + Send { self.db.publicroomids.keys().ignore_err() } +pub async fn is_public_room(&self, room_id: &RoomId) -> bool { self.visibility(room_id).await == Visibility::Public } + +#[implement(Service)] +pub async fn visibility(&self, room_id: &RoomId) -> Visibility { + if self.db.publicroomids.get(room_id).await.is_ok() { + Visibility::Public + } else { + Visibility::Private + } +} From 004be3bf00f3d0aa22bb07e03bd6af146ad67c7b Mon 
Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 13 Nov 2024 05:28:15 +0000 Subject: [PATCH 209/245] prepare utf-8 check bypass for database deserializer Signed-off-by: Jason Volk --- src/database/de.rs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/database/de.rs b/src/database/de.rs index d7dc11022..f8a038ef8 100644 --- a/src/database/de.rs +++ b/src/database/de.rs @@ -277,7 +277,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { fn deserialize_str>(self, visitor: V) -> Result { let input = self.record_next(); - let out = string::str_from_bytes(input)?; + let out = deserialize_str(input)?; visitor.visit_borrowed_str(out) } @@ -360,3 +360,18 @@ impl<'a, 'de: 'a> de::MapAccess<'de> for &'a mut Deserializer<'de> { seed.deserialize(&mut **self) } } + +// activate when stable; too soon now +//#[cfg(debug_assertions)] +#[inline] +fn deserialize_str(input: &[u8]) -> Result<&str> { string::str_from_bytes(input) } + +//#[cfg(not(debug_assertions))] +#[cfg(disable)] +#[inline] +fn deserialize_str(input: &[u8]) -> Result<&str> { + // SAFETY: Strings were written by the serializer to the database. Assuming no + // database corruption, the string will be valid. Database corruption is + // detected via rocksdb checksums. 
+ unsafe { std::str::from_utf8_unchecked(input) } +} From 6ffdc1b2a654b2225b7ee6563e1defbf5019d32d Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 13 Nov 2024 22:01:46 +0000 Subject: [PATCH 210/245] bump serde, image, loole, termimad etc Signed-off-by: Jason Volk --- Cargo.lock | 124 +++++++++++++++++++++++++++++------------------------ Cargo.toml | 8 ++-- 2 files changed, 71 insertions(+), 61 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 515712644..0e1845dad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -487,9 +487,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.37" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40545c26d092346d8a8dab71ee48e7685a7a9cba76e634790c215b41a4a7b4cf" +checksum = "1aeb932158bd710538c73702db6945cb68a8fb08c519e6e12706b94263b36db8" dependencies = [ "jobserver", "libc", @@ -548,9 +548,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" +checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" dependencies = [ "clap_builder", "clap_derive", @@ -558,9 +558,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" +checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" dependencies = [ "anstyle", "clap_lex", @@ -580,9 +580,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" +checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" [[package]] name = "cmake" @@ -720,7 
+720,7 @@ dependencies = [ "serde_json", "serde_regex", "serde_yaml", - "thiserror 1.0.68", + "thiserror 1.0.69", "tikv-jemalloc-ctl", "tikv-jemalloc-sys", "tikv-jemallocator", @@ -913,9 +913,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" +checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6" dependencies = [ "libc", ] @@ -1294,7 +1294,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9" dependencies = [ "nonempty", - "thiserror 1.0.68", + "thiserror 1.0.69", ] [[package]] @@ -1555,7 +1555,7 @@ dependencies = [ "ipnet", "once_cell", "rand", - "thiserror 1.0.68", + "thiserror 1.0.69", "tinyvec", "tokio", "tracing", @@ -1578,7 +1578,7 @@ dependencies = [ "rand", "resolv-conf", "smallvec", - "thiserror 1.0.68", + "thiserror 1.0.69", "tokio", "tracing", ] @@ -2205,9 +2205,13 @@ checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "loole" -version = "0.3.1" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad95468e4700cb37d8d1f198050db18cebe55e4b4c8aa9180a715deedb2f8965" +checksum = "a2998397c725c822c6b2ba605fd9eb4c6a7a0810f1629ba3cc232ef4f0308d96" +dependencies = [ + "futures-core", + "futures-sink", +] [[package]] name = "lru-cache" @@ -2509,7 +2513,7 @@ dependencies = [ "js-sys", "once_cell", "pin-project-lite", - "thiserror 1.0.68", + "thiserror 1.0.69", "urlencoding", ] @@ -2555,7 +2559,7 @@ dependencies = [ "ordered-float 4.5.0", "percent-encoding", "rand", - "thiserror 1.0.68", + "thiserror 1.0.69", "tokio", "tokio-stream", ] @@ -2918,9 +2922,9 @@ checksum = 
"a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" [[package]] name = "quinn" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684" +checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" dependencies = [ "bytes", "pin-project-lite", @@ -2929,26 +2933,29 @@ dependencies = [ "rustc-hash 2.0.0", "rustls 0.23.16", "socket2", - "thiserror 1.0.68", + "thiserror 2.0.3", "tokio", "tracing", ] [[package]] name = "quinn-proto" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6" +checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" dependencies = [ "bytes", + "getrandom", "rand", "ring", "rustc-hash 2.0.0", "rustls 0.23.16", + "rustls-pki-types", "slab", - "thiserror 1.0.68", + "thiserror 2.0.3", "tinyvec", "tracing", + "web-time 1.1.0", ] [[package]] @@ -3021,7 +3028,7 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.8", + "regex-automata 0.4.9", "regex-syntax 0.8.5", ] @@ -3036,9 +3043,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", @@ -3183,7 +3190,7 @@ dependencies = [ "serde", "serde_html_form", "serde_json", - "thiserror 2.0.1", + "thiserror 2.0.3", "url", "web-time 1.1.0", ] @@ -3209,7 +3216,7 @@ dependencies = [ "serde", "serde_html_form", "serde_json", - "thiserror 2.0.1", + "thiserror 2.0.3", "time", "tracing", "url", @@ -3235,7 +3242,7 @@ 
dependencies = [ "ruma-macros", "serde", "serde_json", - "thiserror 2.0.1", + "thiserror 2.0.3", "tracing", "url", "web-time 1.1.0", @@ -3266,7 +3273,7 @@ version = "0.9.5" source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" dependencies = [ "js_int", - "thiserror 2.0.1", + "thiserror 2.0.3", ] [[package]] @@ -3316,7 +3323,7 @@ dependencies = [ "http", "http-auth", "ruma-common", - "thiserror 2.0.1", + "thiserror 2.0.3", "tracing", ] @@ -3333,7 +3340,7 @@ dependencies = [ "serde_json", "sha2", "subslice", - "thiserror 2.0.1", + "thiserror 2.0.3", ] [[package]] @@ -3348,7 +3355,7 @@ dependencies = [ "ruma-events", "serde", "serde_json", - "thiserror 2.0.1", + "thiserror 2.0.3", "tracing", ] @@ -3415,9 +3422,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.39" +version = "0.38.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "375116bee2be9ed569afe2154ea6a99dfdffd257f533f187498c2a8f5feaf4ee" +checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0" dependencies = [ "bitflags 2.6.0", "errno", @@ -3483,6 +3490,9 @@ name = "rustls-pki-types" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" +dependencies = [ + "web-time 1.1.0", +] [[package]] name = "rustls-webpki" @@ -3512,7 +3522,7 @@ dependencies = [ "futures-util", "pin-project", "thingbuf", - "thiserror 1.0.68", + "thiserror 1.0.69", "unicode-segmentation", "unicode-width", ] @@ -3712,7 +3722,7 @@ dependencies = [ "rand", "serde", "serde_json", - "thiserror 1.0.68", + "thiserror 1.0.69", "time", "url", "uuid", @@ -3720,18 +3730,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.214" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" +checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.214" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" +checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" dependencies = [ "proc-macro2", "quote", @@ -3918,7 +3928,7 @@ checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" dependencies = [ "num-bigint", "num-traits", - "thiserror 1.0.68", + "thiserror 1.0.69", "time", ] @@ -4083,9 +4093,9 @@ dependencies = [ [[package]] name = "termimad" -version = "0.30.1" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22117210909e9dfff30a558f554c7fb3edb198ef614e7691386785fb7679677c" +checksum = "9cda3a7471f9978706978454c45ef8dda67e9f8f3cdb9319eb2e9323deb6ae62" dependencies = [ "coolor", "crokey", @@ -4093,7 +4103,7 @@ dependencies = [ "lazy-regex", "minimad", "serde", - "thiserror 1.0.68", + "thiserror 1.0.69", "unicode-width", ] @@ -4109,27 +4119,27 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.68" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02dd99dc800bbb97186339685293e1cc5d9df1f8fae2d0aecd9ff1c77efea892" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl 1.0.68", + "thiserror-impl 1.0.69", ] [[package]] name = "thiserror" -version = "2.0.1" +version = "2.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07c1e40dd48a282ae8edc36c732cbc219144b87fb6a4c7316d611c6b1f06ec0c" +checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" dependencies = [ - "thiserror-impl 2.0.1", + "thiserror-impl 2.0.3", ] 
[[package]] name = "thiserror-impl" -version = "1.0.68" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7c61ec9a6f64d2793d8a45faba21efbe3ced62a886d44c36a009b2b519b4c7e" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", @@ -4138,9 +4148,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.1" +version = "2.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "874aa7e446f1da8d9c3a5c95b1c5eb41d800045252121dc7f8e0ba370cee55f5" +checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" dependencies = [ "proc-macro2", "quote", @@ -4323,7 +4333,7 @@ checksum = "0d4770b8024672c1101b3f6733eab95b18007dbe0847a8afe341fcf79e06043f" dependencies = [ "either", "futures-util", - "thiserror 1.0.68", + "thiserror 1.0.69", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 0173e7cf9..dde005a31 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -142,7 +142,7 @@ features = [ ] [workspace.dependencies.serde] -version = "1.0.214" +version = "1.0.215" default-features = false features = ["rc"] @@ -171,7 +171,7 @@ default-features = false # Used to generate thumbnails for images [workspace.dependencies.image] -version = "0.25.1" +version = "0.25.5" default-features = false features = [ "jpeg", @@ -304,7 +304,7 @@ version = "2.1.1" # used to replace the channels of the tokio runtime [workspace.dependencies.loole] -version = "0.3.1" +version = "0.4.0" [workspace.dependencies.async-trait] version = "0.1.81" @@ -449,7 +449,7 @@ version = "0.4.3" default-features = false [workspace.dependencies.termimad] -version = "0.30.1" +version = "0.31.0" default-features = false [workspace.dependencies.checked_ops] From e228dec4f2d5abe02624f3d1a7cf572aab645e90 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 13 Nov 2024 01:01:33 +0000 Subject: [PATCH 211/245] add byte counting for compressed state caches Signed-off-by: 
Jason Volk --- src/service/rooms/state_compressor/mod.rs | 48 +++++++++++++++++++---- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index bf90d5c4d..6b520ad3d 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -1,18 +1,22 @@ use std::{ - collections::HashSet, + collections::{HashMap, HashSet}, fmt::Write, mem::size_of, sync::{Arc, Mutex}, }; -use conduit::{checked, err, expected, utils, utils::math::usize_from_f64, Result}; +use conduit::{ + at, checked, err, expected, utils, + utils::{bytes, math::usize_from_f64}, + Result, +}; use database::Map; use lru_cache::LruCache; use ruma::{EventId, RoomId}; use crate::{ rooms, - rooms::short::{ShortStateHash, ShortStateKey}, + rooms::short::{ShortId, ShortStateHash, ShortStateKey}, Dep, }; @@ -53,12 +57,13 @@ pub struct HashSetCompressStateEvent { pub removed: Arc, } -pub(crate) type CompressedState = HashSet; -pub(crate) type CompressedStateEvent = [u8; 2 * size_of::()]; type StateInfoLruCache = LruCache; type ShortStateInfoVec = Vec; type ParentStatesVec = Vec; +pub(crate) type CompressedState = HashSet; +pub(crate) type CompressedStateEvent = [u8; 2 * size_of::()]; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { let config = &args.server.config; @@ -75,9 +80,28 @@ impl crate::Service for Service { })) } - fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { - let stateinfo_cache = self.stateinfo_cache.lock().expect("locked").len(); - writeln!(out, "stateinfo_cache: {stateinfo_cache}")?; + fn memory_usage(&self, out: &mut dyn Write) -> Result { + let (cache_len, ents) = { + let cache = self.stateinfo_cache.lock().expect("locked"); + let ents = cache + .iter() + .map(at!(1)) + .flat_map(|vec| vec.iter()) + .fold(HashMap::new(), |mut ents, ssi| { + ents.insert(Arc::as_ptr(&ssi.added), compressed_state_size(&ssi.added)); + 
ents.insert(Arc::as_ptr(&ssi.removed), compressed_state_size(&ssi.removed)); + ents.insert(Arc::as_ptr(&ssi.full_state), compressed_state_size(&ssi.full_state)); + ents + }); + + (cache.len(), ents) + }; + + let ents_len = ents.len(); + let bytes = ents.values().copied().fold(0_usize, usize::saturating_add); + + let bytes = bytes::pretty(bytes); + writeln!(out, "stateinfo_cache: {cache_len} {ents_len} ({bytes})")?; Ok(()) } @@ -435,3 +459,11 @@ impl Service { .insert(&shortstatehash.to_be_bytes(), &value); } } + +#[inline] +fn compressed_state_size(compressed_state: &CompressedState) -> usize { + compressed_state + .len() + .checked_mul(size_of::()) + .expect("CompressedState size overflow") +} From 4ec5d1e28e6cfff3d98c36c2b02aece196ee93c0 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 14 Nov 2024 04:31:29 +0000 Subject: [PATCH 212/245] replace additional use tracing:: add log:: to disallowed-macros Signed-off-by: Jason Volk --- clippy.toml | 8 ++++++++ src/api/client/report.rs | 3 +-- src/api/server/make_join.rs | 6 ++++-- src/core/debug.rs | 2 ++ src/core/log/mod.rs | 2 ++ src/core/utils/sys.rs | 4 +--- src/router/serve/plain.rs | 3 +-- 7 files changed, 19 insertions(+), 9 deletions(-) diff --git a/clippy.toml b/clippy.toml index 08641fcc1..b93b23775 100644 --- a/clippy.toml +++ b/clippy.toml @@ -5,3 +5,11 @@ future-size-threshold = 7745 # TODO reduce me ALARA stack-size-threshold = 196608 # reduce me ALARA too-many-lines-threshold = 700 # TODO reduce me to <= 100 type-complexity-threshold = 250 # reduce me to ~200 + +disallowed-macros = [ + { path = "log::error", reason = "use conduit_core::error" }, + { path = "log::warn", reason = "use conduit_core::warn" }, + { path = "log::info", reason = "use conduit_core::info" }, + { path = "log::debug", reason = "use conduit_core::debug" }, + { path = "log::trace", reason = "use conduit_core::trace" }, +] diff --git a/src/api/client/report.rs b/src/api/client/report.rs index e20fa8c22..a01337045 100644 --- 
a/src/api/client/report.rs +++ b/src/api/client/report.rs @@ -2,7 +2,7 @@ use std::time::Duration; use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{utils::ReadyExt, Err}; +use conduit::{info, utils::ReadyExt, Err}; use rand::Rng; use ruma::{ api::client::{ @@ -13,7 +13,6 @@ use ruma::{ int, EventId, RoomId, UserId, }; use tokio::time::sleep; -use tracing::info; use crate::{ debug_info, diff --git a/src/api/server/make_join.rs b/src/api/server/make_join.rs index c3524f0e4..af5700647 100644 --- a/src/api/server/make_join.rs +++ b/src/api/server/make_join.rs @@ -1,5 +1,8 @@ use axum::extract::State; -use conduit::utils::{IterStream, ReadyExt}; +use conduit::{ + utils::{IterStream, ReadyExt}, + warn, +}; use futures::StreamExt; use ruma::{ api::{client::error::ErrorKind, federation::membership::prepare_join_event}, @@ -13,7 +16,6 @@ use ruma::{ CanonicalJsonObject, RoomId, RoomVersionId, UserId, }; use serde_json::value::to_raw_value; -use tracing::warn; use crate::{ service::{pdu::PduBuilder, Services}, diff --git a/src/core/debug.rs b/src/core/debug.rs index 85574a2f3..f7420784e 100644 --- a/src/core/debug.rs +++ b/src/core/debug.rs @@ -1,3 +1,5 @@ +#![allow(clippy::disallowed_macros)] + use std::{any::Any, panic}; // Export debug proc_macros diff --git a/src/core/log/mod.rs b/src/core/log/mod.rs index 1c415c6a0..48b7f0f38 100644 --- a/src/core/log/mod.rs +++ b/src/core/log/mod.rs @@ -1,3 +1,5 @@ +#![allow(clippy::disallowed_macros)] + pub mod capture; pub mod color; pub mod fmt; diff --git a/src/core/utils/sys.rs b/src/core/utils/sys.rs index 6c396921c..af8bd70b7 100644 --- a/src/core/utils/sys.rs +++ b/src/core/utils/sys.rs @@ -1,6 +1,4 @@ -use tracing::debug; - -use crate::Result; +use crate::{debug, Result}; /// This is needed for opening lots of file descriptors, which tends to /// happen more often when using RocksDB and making lots of federation diff --git a/src/router/serve/plain.rs b/src/router/serve/plain.rs index 
08263353b..144bff85d 100644 --- a/src/router/serve/plain.rs +++ b/src/router/serve/plain.rs @@ -5,9 +5,8 @@ use std::{ use axum::Router; use axum_server::{bind, Handle as ServerHandle}; -use conduit::{debug_info, Result, Server}; +use conduit::{debug_info, info, Result, Server}; use tokio::task::JoinSet; -use tracing::info; pub(super) async fn serve( server: &Arc, app: Router, handle: ServerHandle, addrs: Vec, From 08365bf5f440a4c9f086142c23044f1884c68033 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 10 Nov 2024 20:16:38 -0500 Subject: [PATCH 213/245] update config documentation, commit generated example config also removes the no-op/useless "database_backend" config option Signed-off-by: strawberry --- conduwuit-example.toml | 1747 +++++++++++++++++++++------------- src/api/server/make_knock.rs | 107 +++ src/api/server/send_knock.rs | 190 ++++ src/core/config/mod.rs | 573 +++++++---- src/service/migrations.rs | 11 +- 5 files changed, 1758 insertions(+), 870 deletions(-) create mode 100644 src/api/server/make_knock.rs create mode 100644 src/api/server/send_knock.rs diff --git a/conduwuit-example.toml b/conduwuit-example.toml index 117356165..aa0d1e5df 100644 --- a/conduwuit-example.toml +++ b/conduwuit-example.toml @@ -1,945 +1,1378 @@ -# ============================================================================= -# This is the official example config for conduwuit. -# If you use it for your server, you will need to adjust it to your own needs. -# At the very least, change the server_name field! -# -# This documentation can also be found at https://conduwuit.puppyirl.gay/configuration.html -# ============================================================================= +### conduwuit Configuration +### +### THIS FILE IS GENERATED. CHANGES/CONTRIBUTIONS IN THE REPO WILL +### BE OVERWRITTEN! +### +### You should rename this file before configuring your server. 
Changes +### to documentation and defaults can be contributed in source code at +### src/core/config/mod.rs. This file is generated when building. +### +### Any values pre-populated are the default values for said config option. +### +### At the minimum, you MUST edit all the config options to your environment +### that say "YOU NEED TO EDIT THIS". +### See https://conduwuit.puppyirl.gay/configuration.html for ways to +### configure conduwuit [global] -# The server_name is the pretty name of this server. It is used as a suffix for user -# and room ids. Examples: matrix.org, conduit.rs - -# The Conduit server needs all /_matrix/ requests to be reachable at -# https://your.server.name/ on port 443 (client-server) and 8448 (federation). +# The server_name is the pretty name of this server. It is used as a +# suffix for user and room IDs/aliases. +# +# See the docs for reverse proxying and delegation: https://conduwuit.puppyirl.gay/deploying/generic.html#setting-up-the-reverse-proxy +# Also see the `[global.well_known]` config section at the very bottom. +# +# Examples of delegation: +# - https://puppygock.gay/.well-known/matrix/server +# - https://puppygock.gay/.well-known/matrix/client +# +# YOU NEED TO EDIT THIS. THIS CANNOT BE CHANGED AFTER WITHOUT A DATABASE +# WIPE. +# +# example: "conduwuit.woof" +# +#server_name = -# If that's not possible for you, you can create /.well-known files to redirect -# requests (delegation). See -# https://spec.matrix.org/latest/client-server-api/#getwell-knownmatrixclient -# and -# https://spec.matrix.org/v1.9/server-server-api/#getwell-knownmatrixserver -# for more information +# default address (IPv4 or IPv6) conduwuit will listen on. +# +# If you are using Docker or a container NAT networking setup, this must +# be "0.0.0.0". +# +# To listen on multiple addresses, specify a vector e.g. 
["127.0.0.1", +# "::1"] +# +#address = ["127.0.0.1", "::1"] -# YOU NEED TO EDIT THIS -#server_name = "your.server.name" +# The port(s) conduwuit will be running on. +# +# See https://conduwuit.puppyirl.gay/deploying/generic.html#setting-up-the-reverse-proxy for reverse proxying. +# +# Docker users: Don't change this, you'll need to map an external port to +# this. +# +# To listen on multiple ports, specify a vector e.g. [8080, 8448] +# +#port = 8008 -# Servers listed here will be used to gather public keys of other servers (notary trusted key servers). +# Uncomment unix_socket_path to listen on a UNIX socket at the specified +# path. If listening on a UNIX socket, you MUST remove/comment the +# 'address' key if definedm AND add your reverse proxy to the 'conduwuit' +# group, unless world RW permissions are specified with unix_socket_perms +# (666 minimum). # -# The default behaviour for conduwuit is to attempt to query trusted key servers before querying the individual servers. -# This is done for performance reasons, but if you would like to query individual servers before the notary servers -# configured below, set to +# example: "/run/conduwuit/conduwuit.sock" # -# (Currently, conduwuit doesn't support batched key requests, so this list should only contain Synapse servers) -# Defaults to `matrix.org` -# trusted_servers = ["matrix.org"] +#unix_socket_path = -# Sentry.io crash/panic reporting, performance monitoring/metrics, etc. This is NOT enabled by default. -# conduwuit's default Sentry reporting endpoint is o4506996327251968.ingest.us.sentry.io +# The default permissions (in octal) to create the UNIX socket with. # -# Defaults to *false* -#sentry = false +#unix_socket_perms = 660 -# Sentry reporting URL if a custom one is desired +# This is the only directory where conduwuit will save its data, including +# media. +# Note: this was previously "/var/lib/matrix-conduit" +# +# YOU NEED TO EDIT THIS. 
# -# Defaults to conduwuit's default Sentry endpoint: "https://fe2eb4536aa04949e28eff3128d64757@o4506996327251968.ingest.us.sentry.io/4506996334657536" -#sentry_endpoint = "" +# example: "/var/lib/conduwuit" +# +#database_path = -# Report your Conduwuit server_name in Sentry.io crash reports and metrics +# conduwuit supports online database backups using RocksDB's Backup engine +# API. To use this, set a database backup path that conduwuit can write +# to. # -# Defaults to false -#sentry_send_server_name = false +# See https://conduwuit.puppyirl.gay/maintenance.html#backups for more information. +# +# example: "/opt/conduwuit-db-backups" +# +#database_backup_path = -# Performance monitoring/tracing sample rate for Sentry.io +# The amount of online RocksDB database backups to keep/retain, if using +# "database_backup_path", before deleting the oldest one. +# +#database_backups_to_keep = 1 + +# Set this to any float value in megabytes for conduwuit to tell the +# database engine that this much memory is available for database-related +# caches. # -# Note that too high values may impact performance, and can be disabled by setting it to 0.0 (0%) -# This value is read as a percentage to Sentry, represented as a decimal +# May be useful if you have significant memory to spare to increase +# performance. # -# Defaults to 15% of traces (0.15) -#sentry_traces_sample_rate = 0.15 +# Similar to the individual LRU caches, this is scaled up with your CPU +# core count. +# +# This defaults to 128.0 + (64.0 * CPU core count) +# +#db_cache_capacity_mb = -# Whether to attach a stacktrace to Sentry reports. -#sentry_attach_stacktrace = false +# Option to control adding arbitrary text to the end of the user's +# displayname upon registration with a space before the text. This was the +# lightning bolt emoji option, just replaced with support for adding your +# own custom text or emojis. To disable, set this to "" (an empty string). +# +# The default is the trans pride flag. 
+# +# example: "🏳️⚧️" +# +#new_user_displayname_suffix = "🏳️⚧️" -# Send panics to sentry. This is true by default, but sentry has to be enabled. -#sentry_send_panic = true +# If enabled, conduwuit will send a simple GET request periodically to +# `https://pupbrain.dev/check-for-updates/stable` for any new +# announcements made. Despite the name, this is not an update check +# endpoint, it is simply an announcement check endpoint. +# +# This is disabled by default as this is rarely used except for security +# updates or major updates. +# +#allow_check_for_updates = false -# Send errors to sentry. This is true by default, but sentry has to be enabled. This option is -# only effective in release-mode; forced to false in debug-mode. -#sentry_send_error = true +# Set this to any float value to multiply conduwuit's in-memory LRU caches +# with such as "auth_chain_cache_capacity". +# +# May be useful if you have significant memory to spare to increase +# performance. This was previously called +# `conduit_cache_capacity_modifier`. +# +# If you have low memory, reducing this may be viable. +# +# By default, the individual caches such as "auth_chain_cache_capacity" +# are scaled by your CPU core count. +# +#cache_capacity_modifier = 1.0 -# Controls the tracing log level for Sentry to send things like breadcrumbs and transactions -# Defaults to "info" -#sentry_filter = "info" +# This item is undocumented. Please contribute documentation for it. +# +#pdu_cache_capacity = varies by system +# This item is undocumented. Please contribute documentation for it. +# +#auth_chain_cache_capacity = varies by system -### Database configuration +# This item is undocumented. Please contribute documentation for it. +# +#shorteventid_cache_capacity = varies by system -# This is the only directory where conduwuit will save its data, including media. -# Note: this was previously "/var/lib/matrix-conduit" -database_path = "/var/lib/conduwuit" +# This item is undocumented. 
Please contribute documentation for it. +# +#eventidshort_cache_capacity = varies by system -# Database backend: Only rocksdb is supported. -database_backend = "rocksdb" +# This item is undocumented. Please contribute documentation for it. +# +#shortstatekey_cache_capacity = varies by system +# This item is undocumented. Please contribute documentation for it. +# +#statekeyshort_cache_capacity = varies by system -### Network +# This item is undocumented. Please contribute documentation for it. +# +#server_visibility_cache_capacity = varies by system -# The port(s) conduwuit will be running on. You need to set up a reverse proxy such as -# Caddy or Nginx so all requests to /_matrix on port 443 and 8448 will be -# forwarded to the conduwuit instance running on this port -# Docker users: Don't change this, you'll need to map an external port to this. -# To listen on multiple ports, specify a vector e.g. [8080, 8448] +# This item is undocumented. Please contribute documentation for it. # -# default if unspecified is 8008 -port = 6167 +#user_visibility_cache_capacity = varies by system -# default address (IPv4 or IPv6) conduwuit will listen on. Generally you want this to be -# localhost (127.0.0.1 / ::1). If you are using Docker or a container NAT networking setup, you -# likely need this to be 0.0.0.0. -# To listen multiple addresses, specify a vector e.g. ["127.0.0.1", "::1"] +# This item is undocumented. Please contribute documentation for it. # -# default if unspecified is both IPv4 and IPv6 localhost: ["127.0.0.1", "::1"] -address = "127.0.0.1" +#stateinfo_cache_capacity = varies by system -# Max request size for file uploads -max_request_size = 20_000_000 # in bytes +# This item is undocumented. Please contribute documentation for it. +# +#roomid_spacehierarchy_cache_capacity = varies by system -# Uncomment unix_socket_path to listen on a UNIX socket at the specified path. 
-# If listening on a UNIX socket, you must remove/comment the 'address' key if defined and add your -# reverse proxy to the 'conduwuit' group, unless world RW permissions are specified with unix_socket_perms (666 minimum). -#unix_socket_path = "/run/conduwuit/conduwuit.sock" -#unix_socket_perms = 660 +# Maximum entries stored in DNS memory-cache. The size of an entry may +# vary so please take care if raising this value excessively. Only +# decrease this when using an external DNS cache. Please note +# that systemd-resolved does *not* count as an external cache, even when +# configured to do so. +# +#dns_cache_entries = 32768 -# Set this to true for conduwuit to compress HTTP response bodies using zstd. -# This option does nothing if conduwuit was not built with `zstd_compression` feature. -# Please be aware that enabling HTTP compression may weaken TLS. -# Most users should not need to enable this. -# See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH before deciding to enable this. -zstd_compression = false - -# Set this to true for conduwuit to compress HTTP response bodies using gzip. -# This option does nothing if conduwuit was not built with `gzip_compression` feature. -# Please be aware that enabling HTTP compression may weaken TLS. -# Most users should not need to enable this. -# See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH before deciding to enable this. -gzip_compression = false - -# Set this to true for conduwuit to compress HTTP response bodies using brotli. -# This option does nothing if conduwuit was not built with `brotli_compression` feature. -# Please be aware that enabling HTTP compression may weaken TLS. -# Most users should not need to enable this. -# See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH before deciding to enable this. -brotli_compression = false - -# Vector list of IPv4 and IPv6 CIDR ranges / subnets *in quotes* that you do not want conduwuit to send outbound requests to. 
-# Defaults to RFC1918, unroutable, loopback, multicast, and testnet addresses for security. +# Minimum time-to-live in seconds for entries in the DNS cache. The +# default may appear high to most administrators; this is by design as the +# majority of NXDOMAINs are correct for a long time (e.g. the server is no +# longer running Matrix). Only decrease this if you are using an external +# DNS cache. # -# To disable, set this to be an empty vector (`[]`). -# Please be aware that this is *not* a guarantee. You should be using a firewall with zones as doing this on the application layer may have bypasses. +# default_dns_min_ttl: 259200 # -# Currently this does not account for proxies in use like Synapse does. -ip_range_denylist = [ - "127.0.0.0/8", - "10.0.0.0/8", - "172.16.0.0/12", - "192.168.0.0/16", - "100.64.0.0/10", - "192.0.0.0/24", - "169.254.0.0/16", - "192.88.99.0/24", - "198.18.0.0/15", - "192.0.2.0/24", - "198.51.100.0/24", - "203.0.113.0/24", - "224.0.0.0/4", - "::1/128", - "fe80::/10", - "fc00::/7", - "2001:db8::/32", - "ff00::/8", - "fec0::/10", -] - - -### Moderation / Privacy / Security - -# Config option to control whether the legacy unauthenticated Matrix media repository endpoints will be enabled. -# These endpoints consist of: -# - /_matrix/media/*/config -# - /_matrix/media/*/upload -# - /_matrix/media/*/preview_url -# - /_matrix/media/*/download/* -# - /_matrix/media/*/thumbnail/* +#dns_min_ttl = + +# Minimum time-to-live in seconds for NXDOMAIN entries in the DNS cache. +# This value is critical for the server to federate efficiently. +# NXDOMAIN's are assumed to not be returning to the federation +# and aggressively cached rather than constantly rechecked. # -# The authenticated equivalent endpoints are always enabled. +# Defaults to 3 days as these are *very rarely* false negatives. # -# Defaults to true for now, but this is highly subject to change, likely in the next release. 
-#allow_legacy_media = true +#dns_min_ttl_nxdomain = 259200 -# Set to true to allow user type "guest" registrations. Element attempts to register guest users automatically. -# Defaults to false -allow_guest_registration = false +# Number of retries after a timeout. +# +#dns_attempts = 10 -# Set to true to log guest registrations in the admin room. -# Defaults to false as it may be noisy or unnecessary. -log_guest_registrations = false +# The number of seconds to wait for a reply to a DNS query. Please note +# that recursive queries can take up to several seconds for some domains, +# so this value should not be too low, especially on slower hardware or +# resolvers. +# +#dns_timeout = 10 -# Set to true to allow guest registrations/users to auto join any rooms specified in `auto_join_rooms` -# Defaults to false -allow_guests_auto_join_rooms = false +# Fallback to TCP on DNS errors. Set this to false if unsupported by +# nameserver. +# +#dns_tcp_fallback = true -# Vector list of servers that conduwuit will refuse to download remote media from. -# No default. -# prevent_media_downloads_from = ["example.com", "example.local"] +# Enable to query all nameservers until the domain is found. Referred to +# as "trust_negative_responses" in hickory_resolver. This can avoid +# useless DNS queries if the first nameserver responds with NXDOMAIN or +# an empty NOERROR response. +# +#query_all_nameservers = true -# Enables registration. If set to false, no users can register on this -# server. +# Enables using *only* TCP for querying your specified nameservers instead +# of UDP. # -# If set to true without a token configured, users can register with no form of 2nd- -# step only if you set -# `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` to -# true in your config. +# If you are running conduwuit in a container environment, this config option may need to be enabled. 
See https://conduwuit.puppyirl.gay/troubleshooting.html#potential-dns-issues-when-using-docker for more details. # -# If you would like registration only via token reg, please configure -# `registration_token` or `registration_token_file`. -allow_registration = false -# Please note that an open registration homeserver with no second-step verification -# is highly prone to abuse and potential defederation by homeservers, including -# matrix.org. +#query_over_tcp_only = false -# A static registration token that new users will have to provide when creating -# an account. If unset and `allow_registration` is true, registration is open -# without any condition. YOU NEED TO EDIT THIS. -registration_token = "change this token/string here or set registration_token_file" +# DNS A/AAAA record lookup strategy +# +# Takes a number of one of the following options: +# 1 - Ipv4Only (Only query for A records, no AAAA/IPv6) +# +# 2 - Ipv6Only (Only query for AAAA records, no A/IPv4) +# +# 3 - Ipv4AndIpv6 (Query for A and AAAA records in parallel, uses whatever +# returns a successful response first) +# +# 4 - Ipv6thenIpv4 (Query for AAAA record, if that fails then query the A +# record) +# +# 5 - Ipv4thenIpv6 (Query for A record, if that fails then query the AAAA +# record) +# +# If you don't have IPv6 networking, then for better DNS performance it +# may be suitable to set this to Ipv4Only (1) as you will never ever use +# the AAAA record contents even if the AAAA record is successful instead +# of the A record. +# +#ip_lookup_strategy = 5 -# Path to a file on the system that gets read for the registration token +# Max request size for file uploads in bytes. Defaults to 20MB. # -# conduwuit must be able to access the file, and it must not be empty +#max_request_size = 20971520 + +# This item is undocumented. Please contribute documentation for it. 
# -# no default -#registration_token_file = "/etc/conduwuit/.reg_token" +#max_fetch_prev_events = 192 -# controls whether federation is allowed or not -# defaults to true -# allow_federation = true +# Default/base connection timeout (seconds). This is used only by URL +# previews and update/news endpoint checks. +# +#request_conn_timeout = 10 -# controls whether users are allowed to create rooms. -# appservices and admins are always allowed to create rooms -# defaults to true -# allow_room_creation = true +# Default/base request timeout (seconds). The time waiting to receive more +# data from another server. This is used only by URL previews, +# update/news, and misc endpoint checks. +# +#request_timeout = 35 -# controls whether non-admin local users are forbidden from sending room invites (local and remote), -# and if non-admin users can receive remote room invites. admins are always allowed to send and receive all room invites. -# defaults to false -# block_non_admin_invites = false +# Default/base request total timeout (seconds). The time limit for a whole +# request. This is set very high to not cancel healthy requests while +# serving as a backstop. This is used only by URL previews and +# update/news endpoint checks. +# +#request_total_timeout = 320 -# List of forbidden username patterns/strings. Values in this list are matched as *contains*. -# This is checked upon username availability check, registration, and startup as warnings if any local users in your database -# have a forbidden username. -# No default. -# forbidden_usernames = [] +# Default/base idle connection pool timeout (seconds). This is used only +# by URL previews and update/news endpoint checks. +# +#request_idle_timeout = 5 -# List of forbidden room aliases and room IDs as patterns/strings. Values in this list are matched as *contains*. 
-# This is checked upon room alias creation, custom room ID creation if used, and startup as warnings if any room aliases -# in your database have a forbidden room alias/ID. -# No default. -# forbidden_alias_names = [] +# Default/base max idle connections per host. This is used only by URL +# previews and update/news endpoint checks. Defaults to 1 as generally the +# same open connection can be re-used. +# +#request_idle_per_host = 1 -# List of forbidden server names that we will block incoming AND outgoing federation with, and block client room joins / remote user invites. +# Federation well-known resolution connection timeout (seconds) # -# This check is applied on the room ID, room alias, sender server name, sender user's server name, inbound federation X-Matrix origin, and outbound federation handler. +#well_known_conn_timeout = 6 + +# Federation HTTP well-known resolution request timeout (seconds) # -# Basically "global" ACLs. No default. -# forbidden_remote_server_names = [] +#well_known_timeout = 10 -# List of forbidden server names that we will block all outgoing federated room directory requests for. Useful for preventing our users from wandering into bad servers or spaces. -# No default. -# forbidden_remote_room_directory_server_names = [] +# Federation client request timeout (seconds). You most definitely want +# this to be high to account for extremely large room joins, slow +# homeservers, your own resources etc. +# +#federation_timeout = 300 -# Set this to true to allow your server's public room directory to be federated. -# Set this to false to protect against /publicRooms spiders, but will forbid external users -# from viewing your server's public room directory. If federation is disabled entirely -# (`allow_federation`), this is inherently false. 
-allow_public_room_directory_over_federation = false +# Federation client idle connection pool timeout (seconds) +# +#federation_idle_timeout = 25 -# Set this to true to allow your server's public room directory to be queried without client -# authentication (access token) through the Client APIs. Set this to false to protect against /publicRooms spiders. -allow_public_room_directory_without_auth = false +# Federation client max idle connections per host. Defaults to 1 as +# generally the same open connection can be re-used +# +#federation_idle_per_host = 1 -# Set this to true to lock down your server's public room directory and only allow admins to publish rooms to the room directory. -# Unpublishing is still allowed by all users with this enabled. +# Federation sender request timeout (seconds). The time it takes for the +# remote server to process sent transactions can take a while. # -# Defaults to false -lockdown_public_room_directory = false +#sender_timeout = 180 -# Set this to true to allow federating device display names / allow external users to see your device display name. -# If federation is disabled entirely (`allow_federation`), this is inherently false. For privacy, this is best disabled. -allow_device_name_federation = false +# Federation sender idle connection pool timeout (seconds) +# +#sender_idle_timeout = 180 -# Vector list of domains allowed to send requests to for URL previews. Defaults to none. -# Note: this is a *contains* match, not an explicit match. Putting "google.com" will match "https://google.com" and "http://mymaliciousdomainexamplegoogle.com" -# Setting this to "*" will allow all URL previews. Please note that this opens up significant attack surface to your server, you are expected to be aware of the risks by doing so. 
-url_preview_domain_contains_allowlist = [] +# Federation sender transaction retry backoff limit (seconds) +# +#sender_retry_backoff_limit = 86400 -# Vector list of explicit domains allowed to send requests to for URL previews. Defaults to none. -# Note: This is an *explicit* match, not a contains match. Putting "google.com" will match "https://google.com", "http://google.com", but not "https://mymaliciousdomainexamplegoogle.com" -# Setting this to "*" will allow all URL previews. Please note that this opens up significant attack surface to your server, you are expected to be aware of the risks by doing so. -url_preview_domain_explicit_allowlist = [] +# Appservice URL request connection timeout. Defaults to 35 seconds as +# generally appservices are hosted within the same network. +# +#appservice_timeout = 35 -# Vector list of URLs allowed to send requests to for URL previews. Defaults to none. -# Note that this is a *contains* match, not an explicit match. Putting "google.com" will match "https://google.com/", "https://google.com/url?q=https://mymaliciousdomainexample.com", and "https://mymaliciousdomainexample.com/hi/google.com" -# Setting this to "*" will allow all URL previews. Please note that this opens up significant attack surface to your server, you are expected to be aware of the risks by doing so. -url_preview_url_contains_allowlist = [] +# Appservice URL idle connection pool timeout (seconds) +# +#appservice_idle_timeout = 300 -# Vector list of explicit domains not allowed to send requests to for URL previews. Defaults to none. -# Note: This is an *explicit* match, not a contains match. Putting "google.com" will match "https://google.com", "http://google.com", but not "https://mymaliciousdomainexamplegoogle.com" -# The denylist is checked first before allowlist. Setting this to "*" will not do anything. 
-url_preview_domain_explicit_denylist = [] +# Notification gateway pusher idle connection pool timeout +# +#pusher_idle_timeout = 15 -# Maximum amount of bytes allowed in a URL preview body size when spidering. Defaults to 384KB (384_000 bytes) -url_preview_max_spider_size = 384_000 +# Enables registration. If set to false, no users can register on this +# server. +# +# If set to true without a token configured, users can register with no +# form of 2nd-step only if you set +# `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` to +# true in your config. +# +# If you would like registration only via token reg, please configure +# `registration_token` or `registration_token_file`. +# +#allow_registration = false -# Option to decide whether you would like to run the domain allowlist checks (contains and explicit) on the root domain or not. Does not apply to URL contains allowlist. Defaults to false. -# Example: If this is enabled and you have "wikipedia.org" allowed in the explicit and/or contains domain allowlist, it will allow all subdomains under "wikipedia.org" such as "en.m.wikipedia.org" as the root domain is checked and matched. -# Useful if the domain contains allowlist is still too broad for you but you still want to allow all the subdomains under a root domain. -url_preview_check_root_domain = false +# This item is undocumented. Please contribute documentation for it. +# +#yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse = false -# Config option to allow or disallow incoming federation requests that obtain the profiles -# of our local users from `/_matrix/federation/v1/query/profile` +# A static registration token that new users will have to provide when +# creating an account. If unset and `allow_registration` is true, +# registration is open without any condition. # -# This is inherently false if `allow_federation` is disabled +# YOU NEED TO EDIT THIS OR USE registration_token_file. 
+# +# example: "o&^uCtes4HPf0Vu@F20jQeeWE7" # -# Defaults to true -allow_profile_lookup_federation_requests = true +#registration_token = -# Config option to automatically deactivate the account of any user who attempts to join a: -# - banned room -# - forbidden room alias -# - room alias or ID with a forbidden server name +# Path to a file on the system that gets read for the registration token. +# this config option takes precedence/priority over "registration_token". # -# This may be useful if all your banned lists consist of toxic rooms or servers that no good faith user would ever attempt to join, and -# to automatically remediate the problem without any admin user intervention. +# conduwuit must be able to access the file, and it must not be empty # -# This will also make the user leave all rooms. Federation (e.g. remote room invites) are ignored here. +# example: "/etc/conduwuit/.reg_token" # -# Defaults to false as rooms can be banned for non-moderation-related reasons -#auto_deactivate_banned_room_attempts = false +#registration_token_file = +# Controls whether encrypted rooms and events are allowed. +# +#allow_encryption = true -### Admin Room and Console +# Controls whether federation is allowed or not. It is not recommended to +# disable this after the fact due to potential federation breakage. +# +#allow_federation = true -# Controls whether the conduwuit admin room console / CLI will immediately activate on startup. -# This option can also be enabled with `--console` conduwuit argument +# This item is undocumented. Please contribute documentation for it. # -# Defaults to false -#admin_console_automatic = false +#federation_loopback = false -# Controls what admin commands will be executed on startup. This is a vector list of strings of admin commands to run. +# Set this to true to require authentication on the normally +# unauthenticated profile retrieval endpoints (GET) +# "/_matrix/client/v3/profile/{userId}". 
# -# An example of this can be: `admin_execute = ["debug ping puppygock.gay", "debug echo hi"]` +# This can prevent profile scraping. # -# This option can also be configured with the `--execute` conduwuit argument and can take standard shell commands and environment variables +#require_auth_for_profile_requests = false + +# Set this to true to allow your server's public room directory to be +# federated. Set this to false to protect against /publicRooms spiders, +# but will forbid external users from viewing your server's public room +# directory. If federation is disabled entirely (`allow_federation`), +# this is inherently false. # -# Such example could be: `./conduwuit --execute "server admin-notice conduwuit has started up at $(date)"` +#allow_public_room_directory_over_federation = false + +# Set this to true to allow your server's public room directory to be +# queried without client authentication (access token) through the Client +# APIs. Set this to false to protect against /publicRooms spiders. # -# Defaults to nothing. -#admin_execute = [""] +#allow_public_room_directory_without_auth = false -# Controls whether conduwuit should error and fail to start if an admin execute command (`--execute` / `admin_execute`) fails +# allow guests/unauthenticated users to access TURN credentials # -# Defaults to false -#admin_execute_errors_ignore = false +# this is the equivalent of Synapse's `turn_allow_guests` config option. +# this allows any unauthenticated user to call the endpoint +# `/_matrix/client/v3/voip/turnServer`. +# +# It is unlikely you need to enable this as all major clients support +# authentication for this endpoint and prevents misuse of your TURN server +# from potential bots. +# +#turn_allow_guests = false -# Controls the max log level for admin command log captures (logs generated from running admin commands) +# Set this to true to lock down your server's public room directory and +# only allow admins to publish rooms to the room directory. 
Unpublishing +# is still allowed by all users with this enabled. # -# Defaults to "info" on release builds, else "debug" on debug builds -#admin_log_capture = "info" +#lockdown_public_room_directory = false -# Allows admins to enter commands in rooms other than #admins by prefixing with \!admin. The reply -# will be publicly visible to the room, originating from the sender. -# defaults to true -#admin_escape_commands = true +# Set this to true to allow federating device display names / allow +# external users to see your device display name. If federation is +# disabled entirely (`allow_federation`), this is inherently false. For +# privacy reasons, this is best left disabled. +# +#allow_device_name_federation = false -# Controls whether admin room notices like account registrations, password changes, account deactivations, -# room directory publications, etc will be sent to the admin room. +# Config option to allow or disallow incoming federation requests that +# obtain the profiles of our local users from +# `/_matrix/federation/v1/query/profile` # -# Update notices and normal admin command responses will still be sent. +# Increases privacy of your local user's such as display names, but some +# remote users may get a false "this user does not exist" error when they +# try to invite you to a DM or room. Also can protect against profile +# spiders. # -# defaults to true -#admin_room_notices = true +# This is inherently false if `allow_federation` is disabled +# +#allow_inbound_profile_lookup_federation_requests = true +# controls whether standard users are allowed to create rooms. appservices +# and admins are always allowed to create rooms +# +#allow_room_creation = true -### Misc +# Set to false to disable users from joining or creating room versions +# that aren't 100% officially supported by conduwuit. +# +# conduwuit officially supports room versions 6 - 11. 
+# +# conduwuit has slightly experimental (though works fine in practice) +# support for versions 3 - 5 +# +#allow_unstable_room_versions = true -# max log level for conduwuit. allows debug, info, warn, or error -# see also: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives -# **Caveat**: -# For release builds, the tracing crate is configured to only implement levels higher than error to avoid unnecessary overhead in the compiled binary from trace macros. -# For debug builds, this restriction is not applied. +# default room version conduwuit will create rooms with. # -# Defaults to "info" -#log = "info" +# per spec, room version 10 is the default. +# +#default_room_version = 10 -# controls whether logs will be outputted with ANSI colours +# This item is undocumented. Please contribute documentation for it. # -# defaults to true -#log_colors = true +#allow_jaeger = false -# controls whether encrypted rooms and events are allowed (default true) -#allow_encryption = false +# This item is undocumented. Please contribute documentation for it. +# +#jaeger_filter = "info" -# if enabled, conduwuit will send a simple GET request periodically to `https://pupbrain.dev/check-for-updates/stable` -# for any new announcements made. Despite the name, this is not an update check -# endpoint, it is simply an announcement check endpoint. -# Defaults to false. -#allow_check_for_updates = false +# If the 'perf_measurements' compile-time feature is enabled, enables +# collecting folded stack trace profile of tracing spans using +# tracing_flame. The resulting profile can be visualized with inferno[1], +# speedscope[2], or a number of other tools. +# +# [1]: https://github.com/jonhoo/inferno +# [2]: www.speedscope.app +# +#tracing_flame = false -# Set to false to disable users from joining or creating room versions that aren't 100% officially supported by conduwuit. -# conduwuit officially supports room versions 6 - 10. 
conduwuit has experimental/unstable support for 3 - 5, and 11. -# Defaults to true. -#allow_unstable_room_versions = true +# This item is undocumented. Please contribute documentation for it. +# +#tracing_flame_filter = "info" -# Option to control adding arbitrary text to the end of the user's displayname upon registration with a space before the text. -# This was the lightning bolt emoji option, just replaced with support for adding your own custom text or emojis. -# To disable, set this to "" (an empty string) -# Defaults to "🏳️‍⚧️" (trans pride flag) -#new_user_displayname_suffix = "🏳️‍⚧️" +# This item is undocumented. Please contribute documentation for it. +# +#tracing_flame_output_path = "./tracing.folded" -# Option to control whether conduwuit will query your list of trusted notary key servers (`trusted_servers`) for -# remote homeserver signing keys it doesn't know *first*, or query the individual servers first before falling back to the trusted -# key servers. +# Examples: +# - No proxy (default): +# proxy ="none" # -# The former/default behaviour makes federated/remote rooms joins generally faster because we're querying a single (or list of) server -# that we know works, is reasonably fast, and is reliable for just about all the homeserver signing keys in the room. Querying individual -# servers may take longer depending on the general infrastructure of everyone in there, how many dead servers there are, etc. +# - For global proxy, create the section at the bottom of this file: +# [global.proxy] +# global = { url = "socks5h://localhost:9050" } # -# However, this does create an increased reliance on one single or multiple large entities as `trusted_servers` should generally -# contain long-term and large servers who know a very large number of homeservers. 
+# - To proxy some domains: +# [global.proxy] +# [[global.proxy.by_domain]] +# url = "socks5h://localhost:9050" +# include = ["*.onion", "matrix.myspecial.onion"] +# exclude = ["*.myspecial.onion"] # -# If you don't know what any of this means, leave this and `trusted_servers` alone to their defaults. +# Include vs. Exclude: +# - If include is an empty list, it is assumed to be `["*"]`. +# - If a domain matches both the exclude and include list, the proxy will +# only be used if it was included because of a more specific rule than +# it was excluded. In the above example, the proxy would be used for +# `ordinary.onion`, `matrix.myspecial.onion`, but not +# `hello.myspecial.onion`. # -# Defaults to true as this is the fastest option for federation. -#query_trusted_key_servers_first = true +#proxy = "none" -# List/vector of room **IDs** that conduwuit will make newly registered users join. -# The room IDs specified must be rooms that you have joined at least once on the server, and must be public. +# This item is undocumented. Please contribute documentation for it. # -# No default. -#auto_join_rooms = [] +#jwt_secret = -# Retry failed and incomplete messages to remote servers immediately upon startup. This is called bursting. -# If this is disabled, said messages may not be delivered until more messages are queued for that server. -# Do not change this option unless server resources are extremely limited or the scale of the server's -# deployment is huge. Do not disable this unless you know what you are doing. -#startup_netburst = true - -# Limit the startup netburst to the most recent (default: 50) messages queued for each remote server. All older -# messages are dropped and not reattempted. The `startup_netburst` option must be enabled for this value to have -# any effect. Do not change this value unless you know what you are doing. Set this value to -1 to reattempt -# every message without trimming the queues; this may consume significant disk. 
Set this value to 0 to drop all -# messages without any attempt at redelivery. -#startup_netburst_keep = 50 - -# If the 'perf_measurements' feature is enabled, enables collecting folded stack trace profile of tracing spans using -# tracing_flame. The resulting profile can be visualized with inferno[1], speedscope[2], or a number of other tools. -# [1]: https://github.com/jonhoo/inferno -# [2]: www.speedscope.app -# tracing_flame = false - -# If 'tracing_flame' is enabled, sets a filter for which events will be included in the profile. -# Supported syntax is documented at https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives -# tracing_flame_filter = "trace,h2=off" +# Servers listed here will be used to gather public keys of other servers +# (notary trusted key servers). +# +# Currently, conduwuit doesn't support inbound batched key requests, so +# this list should only contain other Synapse servers +# +# example: ["matrix.org", "constellatory.net", "tchncs.de"] +# +#trusted_servers = ["matrix.org"] -# If 'tracing_flame' is enabled, set the path to write the generated profile. -# tracing_flame_output_path = "./tracing.folded" +# Whether to query the servers listed in trusted_servers first or query +# the origin server first. For best security, querying the origin server +# first is advised to minimize the exposure to a compromised trusted +# server. For maximum federation/join performance this can be set to true, +# however other options exist to query trusted servers first under +# specific high-load circumstances and should be evaluated before setting +# this to true. +# +#query_trusted_key_servers_first = false -# Enable the tokio-console. This option is only relevant to developers. -# See: docs/development.md#debugging-with-tokio-console for more information. -#tokio_console = false +# Whether to query the servers listed in trusted_servers first +# specifically on room joins. 
This option limits the exposure to a
+# compromised trusted server to room joins only. The join operation
+# requires gathering keys from many origin servers which can cause
+# significant delays. Therefore this defaults to true to mitigate
+# unexpected delays out-of-the-box. The security-paranoid or those
+# willing to tolerate delays are advised to set this to false. Note that
+# setting query_trusted_key_servers_first to true causes this option to
+# be ignored.
+#
+#query_trusted_key_servers_first_on_join = true

-# Enable backward-compatibility with Conduit's media directory by creating symlinks of media. This
-# option is only necessary if you plan on using Conduit again. Otherwise setting this to false
-# reduces filesystem clutter and overhead for managing these symlinks in the directory. This is now
-# disabled by default. You may still return to upstream Conduit but you have to run Conduwuit at
-# least once with this set to true and allow the media_startup_check to take place before shutting
-# down to return to Conduit.
+# Only query trusted servers for keys and never the origin server. This is
+# intended for clusters or custom deployments using their trusted_servers
+# as forwarding-agents to cache and deduplicate requests. Notary servers
+# do not act as forwarding-agents by default, therefore do not enable this
+# unless you know exactly what you are doing.
#
-# Disabled by default.
-#media_compat_file_link = false
+#only_query_trusted_key_servers = false

-# Prunes missing media from the database as part of the media startup checks. This means if you
-# delete files from the media directory the corresponding entries will be removed from the
-# database. This is disabled by default because if the media directory is accidentally moved or
-# inaccessible the metadata entries in the database will be lost with sadness.
+# Maximum number of keys to request in each trusted server batch query.
#
-# Disabled by default.
-#prune_missing_media = false +#trusted_server_batch_size = 1024 -# Checks consistency of the media directory at startup: -# 1. When `media_compat_file_link` is enbled, this check will upgrade media when switching back -# and forth between Conduit and Conduwuit. Both options must be enabled to handle this. -# 2. When media is deleted from the directory, this check will also delete its database entry. +# max log level for conduwuit. allows debug, info, warn, or error +# see also: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives # -# If none of these checks apply to your use cases, and your media directory is significantly large -# setting this to false may reduce startup time. +# **Caveat**: +# For release builds, the tracing crate is configured to only implement +# levels higher than error to avoid unnecessary overhead in the compiled +# binary from trace macros. For debug builds, this restriction is not +# applied. # -# Enabled by default. -#media_startup_check = true +#log = "info" + +# controls whether logs will be outputted with ANSI colours +# +#log_colors = true # OpenID token expiration/TTL in seconds # -# These are the OpenID tokens that are primarily used for Matrix account integrations, *not* OIDC/OpenID Connect/etc +# These are the OpenID tokens that are primarily used for Matrix account +# integrations (e.g. Vector Integrations in Element), *not* OIDC/OpenID +# Connect/etc # -# Defaults to 3600 (1 hour) #openid_token_ttl = 3600 -# Emergency password feature. This password set here will let you login to the server service account (e.g. `@conduit`) -# and let you run admin commands, invite yourself to the admin room, etc. +# static TURN username to provide the client if not using a shared secret +# ("turn_secret"), It is recommended to use a shared secret over static +# credentials. # -# no default. 
-#emergency_password = "" - +#turn_username = false -### Generic database options +# static TURN password to provide the client if not using a shared secret +# ("turn_secret"). It is recommended to use a shared secret over static +# credentials. +# +#turn_password = false -# Set this to any float value to multiply conduwuit's in-memory LRU caches with. -# By default, the caches scale automatically with cpu-core-count. -# May be useful if you have significant memory to spare to increase performance. +# vector list of TURN URIs/servers to use # -# This was previously called `conduit_cache_capacity_modifier` +# replace "example.turn.uri" with your TURN domain, such as the coturn +# "realm" config option. if using TURN over TLS, replace the URI prefix +# "turn:" with "turns:" # -# Defaults to 1.0. -#cache_capacity_modifier = 1.0 +# example: ["turn:example.turn.uri?transport=udp", +# "turn:example.turn.uri?transport=tcp"] +# +#turn_uris = [] -# Set this to any float value in megabytes for conduwuit to tell the database engine that this much memory is available for database-related caches. -# May be useful if you have significant memory to spare to increase performance. -# Defaults to 128.0 + (64.0 * CPU core count). -#db_cache_capacity_mb = 256.0 +# TURN secret to use for generating the HMAC-SHA1 hash apart of username +# and password generation +# +# this is more secure, but if needed you can use traditional +# static username/password credentials. +# +#turn_secret = false +# TURN secret to use that's read from the file path specified +# +# this takes priority over "turn_secret" first, and falls back to +# "turn_secret" if invalid or failed to open. 
+# +# example: "/etc/conduwuit/.turn_secret" +# +#turn_secret_file = -### RocksDB options +# TURN TTL in seconds +# +#turn_ttl = 86400 -# Set this to true to use RocksDB config options that are tailored to HDDs (slower device storage) +# List/vector of room IDs or room aliases that conduwuit will make newly +# registered users join. The rooms specified must be rooms that you +# have joined at least once on the server, and must be public. # -# It is worth noting that by default, conduwuit will use RocksDB with Direct IO enabled. *Generally* speaking this improves performance as it bypasses buffered I/O (system page cache). -# However there is a potential chance that Direct IO may cause issues with database operations if your setup is uncommon. This has been observed with FUSE filesystems, and possibly ZFS filesystem. -# RocksDB generally deals/corrects these issues but it cannot account for all setups. -# If you experience any weird RocksDB issues, try enabling this option as it turns off Direct IO and feel free to report in the conduwuit Matrix room if this option fixes your DB issues. -# See https://github.com/facebook/rocksdb/wiki/Direct-IO for more information. +# example: ["#conduwuit:puppygock.gay", +# "!eoIzvAvVwY23LPDay8:puppygock.gay"] # -# Defaults to false -#rocksdb_optimize_for_spinning_disks = false +#auto_join_rooms = [] -# Enables direct-io to increase database performance. This is enabled by default. Set this option to false if the -# database resides on a filesystem which does not support direct-io. -#rocksdb_direct_io = true +# Config option to automatically deactivate the account of any user who +# attempts to join a: +# - banned room +# - forbidden room alias +# - room alias or ID with a forbidden server name +# +# This may be useful if all your banned lists consist of toxic rooms or +# servers that no good faith user would ever attempt to join, and +# to automatically remediate the problem without any admin user +# intervention. 
+# +# This will also make the user leave all rooms. Federation (e.g. remote +# room invites) are ignored here. +# +# Defaults to false as rooms can be banned for non-moderation-related +# reasons +# +#auto_deactivate_banned_room_attempts = false -# RocksDB log level. This is not the same as conduwuit's log level. This is the log level for the RocksDB engine/library -# which show up in your database folder/path as `LOG` files. Defaults to error. conduwuit will typically log RocksDB errors as normal. +# RocksDB log level. This is not the same as conduwuit's log level. This +# is the log level for the RocksDB engine/library which show up in your +# database folder/path as `LOG` files. conduwuit will log RocksDB errors +# as normal through tracing. +# #rocksdb_log_level = "error" -# Max RocksDB `LOG` file size before rotating in bytes. Defaults to 4MB. +# This item is undocumented. Please contribute documentation for it. +# +#rocksdb_log_stderr = false + +# Max RocksDB `LOG` file size before rotating in bytes. Defaults to 4MB in +# bytes. +# #rocksdb_max_log_file_size = 4194304 -# Time in seconds before RocksDB will forcibly rotate logs. Defaults to 0. +# Time in seconds before RocksDB will forcibly rotate logs. +# #rocksdb_log_time_to_roll = 0 -# Amount of threads that RocksDB will use for parallelism on database operatons such as cleanup, sync, flush, compaction, etc. Set to 0 to use all your logical threads. +# Set this to true to use RocksDB config options that are tailored to HDDs +# (slower device storage) # -# Defaults to your CPU logical thread count. -#rocksdb_parallelism_threads = 0 +# It is worth noting that by default, conduwuit will use RocksDB with +# Direct IO enabled. *Generally* speaking this improves performance as it +# bypasses buffered I/O (system page cache). However there is a potential +# chance that Direct IO may cause issues with database operations if your +# setup is uncommon. 
This has been observed with FUSE filesystems, and
+# possibly ZFS filesystem. RocksDB generally deals/corrects these issues
+# but it cannot account for all setups. If you experience any weird
+# RocksDB issues, try enabling this option as it turns off Direct IO and
+# feel free to report in the conduwuit Matrix room if this option fixes
+# your DB issues.
+#
+# See https://github.com/facebook/rocksdb/wiki/Direct-IO for more information.
+#
+#rocksdb_optimize_for_spinning_disks = false

-# Enables idle IO priority for compaction thread. This prevents any unexpected lag in the server's operation and
-# is usually a good idea. Enabled by default.
-#rocksdb_compaction_ioprio_idle = true
+# Enables direct-io to increase database performance via unbuffered I/O.
+#
+# See https://github.com/facebook/rocksdb/wiki/Direct-IO for more details about Direct IO and RocksDB.
+#
+# Set this option to false if the database resides on a filesystem which
+# does not support direct-io like FUSE, or any form of complex filesystem
+# setup such as possibly ZFS.
+#
+#rocksdb_direct_io = true

-# Enables idle CPU priority for compaction thread. This is not enabled by default to prevent compaction from
-# falling too far behind on busy systems.
-#rocksdb_compaction_prio_idle = false
+# Amount of threads that RocksDB will use for parallelism on database
+# operations such as cleanup, sync, flush, compaction, etc. Set to 0 to use
+# all your logical threads. Defaults to your CPU logical thread count.
+#
+#rocksdb_parallelism_threads = 0

-# Maximum number of LOG files RocksDB will keep. This must *not* be set to 0. It must be at least 1.
-# Defaults to 3 as these are not very useful.
+# Maximum number of LOG files RocksDB will keep. This must *not* be set to
+# 0. It must be at least 1. Defaults to 3 as these are not very useful
+# unless troubleshooting/debugging a RocksDB bug.
+#
#rocksdb_max_log_files = 3

# Type of RocksDB database compression to use.
+# # Available options are "zstd", "zlib", "bz2", "lz4", or "none" -# It is best to use ZSTD as an overall good balance between speed/performance, storage, IO amplification, and CPU usage. -# For more performance but less compression (more storage used) and less CPU usage, use LZ4. -# See https://github.com/facebook/rocksdb/wiki/Compression for more details. +# +# It is best to use ZSTD as an overall good balance between +# speed/performance, storage, IO amplification, and CPU usage. +# For more performance but less compression (more storage used) and less +# CPU usage, use LZ4. See https://github.com/facebook/rocksdb/wiki/Compression for more details. # # "none" will disable compression. # -# Defaults to "zstd" #rocksdb_compression_algo = "zstd" -# Level of compression the specified compression algorithm for RocksDB to use. -# Default is 32767, which is internally read by RocksDB as the default magic number and -# translated to the library's default compression level as they all differ. +# Level of compression the specified compression algorithm for RocksDB to +# use. +# +# Default is 32767, which is internally read by RocksDB as the +# default magic number and translated to the library's default +# compression level as they all differ. # See their `kDefaultCompressionLevel`. # #rocksdb_compression_level = 32767 -# Level of compression the specified compression algorithm for the bottommost level/data for RocksDB to use. -# Default is 32767, which is internally read by RocksDB as the default magic number and -# translated to the library's default compression level as they all differ. +# Level of compression the specified compression algorithm for the +# bottommost level/data for RocksDB to use. Default is 32767, which is +# internally read by RocksDB as the default magic number and translated +# to the library's default compression level as they all differ. # See their `kDefaultCompressionLevel`. 
#
-# Since this is the bottommost level (generally old and least used data), it may be desirable to have a very
-# high compression level here as it's lesss likely for this data to be used. Research your chosen compression algorithm.
+# Since this is the bottommost level (generally old and least used data),
+# it may be desirable to have a very high compression level here as it's
+# less likely for this data to be used. Research your chosen compression
+# algorithm.
#
#rocksdb_bottommost_compression_level = 32767

-# Whether to enable RocksDB "bottommost_compression".
-# At the expense of more CPU usage, this will further compress the database to reduce more storage.
-# It is recommended to use ZSTD compression with this for best compression results.
+# Whether to enable RocksDB's "bottommost_compression".
+#
+# At the expense of more CPU usage, this will further compress the
+# database to reduce more storage. It is recommended to use ZSTD
+# compression with this for best compression results. This may be useful
+# if you're trying to reduce storage usage from the database.
+#
# See https://github.com/facebook/rocksdb/wiki/Compression for more details.
#
-# Defaults to false as this uses more CPU when compressing.
#rocksdb_bottommost_compression = false

-# Level of statistics collection. Some admin commands to display database statistics may require
-# this option to be set. Database performance may be impacted by higher settings.
+# Database recovery mode (for RocksDB WAL corruption)
#
-# Option is a number ranging from 0 to 6:
-# 0 = No statistics.
-# 1 = No statistics in release mode (default).
-# 2 to 3 = Statistics with no performance impact.
-# 3 to 5 = Statistics with possible performance impact.
-# 6 = All statistics.
+# Use this option when the server reports corruption and refuses to start.
+# Set mode 2 (PointInTime) to cleanly recover from this corruption.
The +# server will continue from the last good state, several seconds or +# minutes prior to the crash. Clients may have to run "clear-cache & +# reload" to account for the rollback. Upon success, you may reset the +# mode back to default and restart again. Please note in some cases the +# corruption error may not be cleared for at least 30 minutes of +# operation in PointInTime mode. # -# Defaults to 1 (No statistics, except in debug-mode) -#rocksdb_stats_level = 1 +# As a very last ditch effort, if PointInTime does not fix or resolve +# anything, you can try mode 3 (SkipAnyCorruptedRecord) but this will +# leave the server in a potentially inconsistent state. +# +# The default mode 1 (TolerateCorruptedTailRecords) will automatically +# drop the last entry in the database if corrupted during shutdown, but +# nothing more. It is extraordinarily unlikely this will desynchronize +# clients. To disable any form of silent rollback set mode 0 +# (AbsoluteConsistency). +# +# The options are: +# 0 = AbsoluteConsistency +# 1 = TolerateCorruptedTailRecords (default) +# 2 = PointInTime (use me if trying to recover) +# 3 = SkipAnyCorruptedRecord (you now voided your Conduwuit warranty) +# +# See https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes for more information on these modes. +# +# See https://conduwuit.puppyirl.gay/troubleshooting.html#database-corruption for more details on recovering a corrupt database. +# +#rocksdb_recovery_mode = 1 # Database repair mode (for RocksDB SST corruption) # -# Use this option when the server reports corruption while running or panics. If the server refuses -# to start use the recovery mode options first. Corruption errors containing the acronym 'SST' which -# occur after startup will likely require this option. +# Use this option when the server reports corruption while running or +# panics. If the server refuses to start use the recovery mode options +# first. 
Corruption errors containing the acronym 'SST' which occur after +# startup will likely require this option. +# +# - Backing up your database directory is recommended prior to running the +# repair. +# - Disabling repair mode and restarting the server is recommended after +# running the repair. # -# - Backing up your database directory is recommended prior to running the repair. -# - Disabling repair mode and restarting the server is recommended after running the repair. +# See https://conduwuit.puppyirl.gay/troubleshooting.html#database-corruption for more details on recovering a corrupt database. # -# Defaults to false #rocksdb_repair = false -# Database recovery mode (for RocksDB WAL corruption) +# This item is undocumented. Please contribute documentation for it. # -# Use this option when the server reports corruption and refuses to start. Set mode 2 (PointInTime) -# to cleanly recover from this corruption. The server will continue from the last good state, -# several seconds or minutes prior to the crash. Clients may have to run "clear-cache & reload" to -# account for the rollback. Upon success, you may reset the mode back to default and restart again. -# Please note in some cases the corruption error may not be cleared for at least 30 minutes of -# operation in PointInTime mode. +#rocksdb_read_only = false + +# This item is undocumented. Please contribute documentation for it. # -# As a very last ditch effort, if PointInTime does not fix or resolve anything, you can try mode -# 3 (SkipAnyCorruptedRecord) but this will leave the server in a potentially inconsistent state. +#rocksdb_secondary = false + +# Enables idle CPU priority for compaction thread. This is not enabled by +# default to prevent compaction from falling too far behind on busy +# systems. # -# The default mode 1 (TolerateCorruptedTailRecords) will automatically drop the last entry in the -# database if corrupted during shutdown, but nothing more. 
It is extraordinarily unlikely this will -# desynchronize clients. To disable any form of silent rollback set mode 0 (AbsoluteConsistency). +#rocksdb_compaction_prio_idle = false + +# Enables idle IO priority for compaction thread. This prevents any +# unexpected lag in the server's operation and is usually a good idea. +# Enabled by default. # -# The options are: -# 0 = AbsoluteConsistency -# 1 = TolerateCorruptedTailRecords (default) -# 2 = PointInTime (use me if trying to recover) -# 3 = SkipAnyCorruptedRecord (you now voided your Conduwuit warranty) +#rocksdb_compaction_ioprio_idle = true + +# Config option to disable RocksDB compaction. You should never ever have +# to disable this. If you for some reason find yourself needing to disable +# this as part of troubleshooting or a bug, please reach out to us in the +# conduwuit Matrix room with information and details. # -# See https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes for more information +# Disabling compaction will lead to a significantly bloated and +# explosively large database, gradually poor performance, unnecessarily +# excessive disk read/writes, and slower shutdowns and startups. # -# Defaults to 1 (TolerateCorruptedTailRecords) -#rocksdb_recovery_mode = 1 +#rocksdb_compaction = true +# Level of statistics collection. Some admin commands to display database +# statistics may require this option to be set. Database performance may +# be impacted by higher settings. +# +# Option is a number ranging from 0 to 6: +# 0 = No statistics. +# 1 = No statistics in release mode (default). +# 2 to 3 = Statistics with no performance impact. +# 3 to 5 = Statistics with possible performance impact. +# 6 = All statistics. 
+# +#rocksdb_stats_level = 1 -### Domain Name Resolution and Caching +# This is a password that can be configured that will let you login to the +# server bot account (currently `@conduit`) for emergency troubleshooting +# purposes such as recovering/recreating your admin room, or inviting +# yourself back. +# +# See https://conduwuit.puppyirl.gay/troubleshooting.html#lost-access-to-admin-room for other ways to get back into your admin room. +# +# Once this password is unset, all sessions will be logged out for +# security purposes. +# +# example: "F670$2CP@Hw8mG7RY1$%!#Ic7YA" +# +#emergency_password = -# Maximum entries stored in DNS memory-cache. The size of an entry may vary so please take care if -# raising this value excessively. Only decrease this when using an external DNS cache. Please note -# that systemd does *not* count as an external cache, even when configured to do so. -#dns_cache_entries = 32768 +# This item is undocumented. Please contribute documentation for it. +# +#notification_push_path = "/_matrix/push/v1/notify" -# Minimum time-to-live in seconds for entries in the DNS cache. The default may appear high to most -# administrators; this is by design. Only decrease this if you are using an external DNS cache. -#dns_min_ttl = 10800 +# Config option to control local (your server only) presence +# updates/requests. Note that presence on conduwuit is +# very fast unlike Synapse's. If using outgoing presence, this MUST be +# enabled. +# +#allow_local_presence = true -# Minimum time-to-live in seconds for NXDOMAIN entries in the DNS cache. This value is critical for -# the server to federate efficiently. NXDOMAIN's are assumed to not be returning to the federation -# and aggressively cached rather than constantly rechecked. +# Config option to control incoming federated presence updates/requests. # -# Defaults to 3 days as these are *very rarely* false negatives. 
-#dns_min_ttl_nxdomain = 259200 +# This option receives presence updates from other +# servers, but does not send any unless `allow_outgoing_presence` is true. +# Note that presence on conduwuit is very fast unlike Synapse's. +# +#allow_incoming_presence = true -# The number of seconds to wait for a reply to a DNS query. Please note that recursive queries can -# take up to several seconds for some domains, so this value should not be too low. -#dns_timeout = 10 +# Config option to control outgoing presence updates/requests. +# +# This option sends presence updates to other servers, but does not +# receive any unless `allow_incoming_presence` is true. +# Note that presence on conduwuit is very fast unlike Synapse's. +# If using outgoing presence, you MUST enable `allow_local_presence` as +# well. +# +#allow_outgoing_presence = true -# Number of retries after a timeout. -#dns_attempts = 10 +# Config option to control how many seconds before presence updates that +# you are idle. Defaults to 5 minutes. +# +#presence_idle_timeout_s = 300 -# Fallback to TCP on DNS errors. Set this to false if unsupported by nameserver. -#dns_tcp_fallback = true +# Config option to control how many seconds before presence updates that +# you are offline. Defaults to 30 minutes. +# +#presence_offline_timeout_s = 1800 -# Enable to query all nameservers until the domain is found. Referred to as "trust_negative_responses" in hickory_resolver. -# This can avoid useless DNS queries if the first nameserver responds with NXDOMAIN or an empty NOERROR response. +# Config option to enable the presence idle timer for remote users. +# Disabling is offered as an optimization for servers participating in +# many large rooms or when resources are limited. Disabling it may cause +# incorrect presence states (i.e. stuck online) to be seen for some +# remote users. # -# The default is to query one nameserver and stop (false). 
-#query_all_nameservers = true +#presence_timeout_remote_users = true -# Enables using *only* TCP for querying your specified nameservers instead of UDP. +# Config option to control whether we should receive remote incoming read +# receipts. # -# You very likely do *not* want this. hickory-resolver already falls back to TCP on UDP errors. -# Defaults to false -#query_over_tcp_only = false +#allow_incoming_read_receipts = true -# DNS A/AAAA record lookup strategy +# Config option to control whether we should send read receipts to remote +# servers. # -# Takes a number of one of the following options: -# 1 - Ipv4Only (Only query for A records, no AAAA/IPv6) -# 2 - Ipv6Only (Only query for AAAA records, no A/IPv4) -# 3 - Ipv4AndIpv6 (Query for A and AAAA records in parallel, uses whatever returns a successful response first) -# 4 - Ipv6thenIpv4 (Query for AAAA record, if that fails then query the A record) -# 5 - Ipv4thenIpv6 (Query for A record, if that fails then query the AAAA record) +#allow_outgoing_read_receipts = true + +# Config option to control outgoing typing updates to federation. # -# If you don't have IPv6 networking, then for better performance it may be suitable to set this to Ipv4Only (1) as -# you will never ever use the AAAA record contents even if the AAAA record is successful instead of the A record. +#allow_outgoing_typing = true + +# Config option to control incoming typing updates from federation. # -# Defaults to 5 - Ipv4ThenIpv6 as this is the most compatible and IPv4 networking is currently the most prevalent. -#ip_lookup_strategy = 5 +#allow_incoming_typing = true +# Config option to control maximum time federation user can indicate +# typing. +# +#typing_federation_timeout_s = 30 -### Request Timeouts, Connection Timeouts, and Connection Pooling +# Config option to control minimum time local client can indicate typing. +# This does not override a client's request to stop typing. 
It only +# enforces a minimum value in case of no stop request. +# +#typing_client_timeout_min_s = 15 -## Request Timeouts are HTTP response timeouts -## Connection Timeouts are TCP connection timeouts -## -## Connection Pooling Timeouts are timeouts for keeping an open idle connection alive. -## Connection pooling and keepalive is very useful for federation or other places where for performance reasons, -## we want to keep connections open that we will re-use frequently due to TCP and TLS 1.3 overhead/expensiveness. -## -## Generally these defaults are the best, but if you find a reason to need to change these they are here. +# Config option to control maximum time local client can indicate typing. +# +#typing_client_timeout_max_s = 45 -# Default/base connection timeout. -# This is used only by URL previews and update/news endpoint checks +# Set this to true for conduwuit to compress HTTP response bodies using +# zstd. This option does nothing if conduwuit was not built with +# `zstd_compression` feature. Please be aware that enabling HTTP +# compression may weaken TLS. Most users should not need to enable this. +# See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH +# before deciding to enable this. # -# Defaults to 10 seconds -#request_conn_timeout = 10 +#zstd_compression = false -# Default/base request timeout. The time waiting to receive more data from another server. -# This is used only by URL previews, update/news, and misc endpoint checks +# Set this to true for conduwuit to compress HTTP response bodies using +# gzip. This option does nothing if conduwuit was not built with +# `gzip_compression` feature. Please be aware that enabling HTTP +# compression may weaken TLS. Most users should not need to enable this. +# See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH before +# deciding to enable this. 
# -# Defaults to 35 seconds -#request_timeout = 35 +# If you are in a large amount of rooms, you may find that enabling this +# is necessary to reduce the significantly large response bodies. +# +#gzip_compression = false -# Default/base request total timeout. The time limit for a whole request. This is set very high to not -# cancel healthy requests while serving as a backstop. -# This is used only by URL previews and update/news endpoint checks +# Set this to true for conduwuit to compress HTTP response bodies using +# brotli. This option does nothing if conduwuit was not built with +# `brotli_compression` feature. Please be aware that enabling HTTP +# compression may weaken TLS. Most users should not need to enable this. +# See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH before +# deciding to enable this. # -# Defaults to 320 seconds -#request_total_timeout = 320 +#brotli_compression = false -# Default/base idle connection pool timeout -# This is used only by URL previews and update/news endpoint checks +# Set to true to allow user type "guest" registrations. Some clients like +# Element attempt to register guest users automatically. # -# Defaults to 5 seconds -#request_idle_timeout = 5 +#allow_guest_registration = false -# Default/base max idle connections per host -# This is used only by URL previews and update/news endpoint checks +# Set to true to log guest registrations in the admin room. Note that +# these may be noisy or unnecessary if you're a public homeserver. # -# Defaults to 1 as generally the same open connection can be re-used -#request_idle_per_host = 1 +#log_guest_registrations = false -# Federation well-known resolution connection timeout +# Set to true to allow guest registrations/users to auto join any rooms +# specified in `auto_join_rooms`. 
# -# Defaults to 6 seconds -#well_known_conn_timeout = 6 +#allow_guests_auto_join_rooms = false -# Federation HTTP well-known resolution request timeout +# Config option to control whether the legacy unauthenticated Matrix media +# repository endpoints will be enabled. These endpoints consist of: +# - /_matrix/media/*/config +# - /_matrix/media/*/upload +# - /_matrix/media/*/preview_url +# - /_matrix/media/*/download/* +# - /_matrix/media/*/thumbnail/* # -# Defaults to 10 seconds -#well_known_timeout = 10 +# The authenticated equivalent endpoints are always enabled. +# +# Defaults to true for now, but this is highly subject to change, likely +# in the next release. +# +#allow_legacy_media = true -# Federation client request timeout -# You most definitely want this to be high to account for extremely large room joins, slow homeservers, your own resources etc. +# This item is undocumented. Please contribute documentation for it. # -# Defaults to 300 seconds -#federation_timeout = 300 +#freeze_legacy_media = true -# Federation client idle connection pool timeout +# Checks consistency of the media directory at startup: +# 1. When `media_compat_file_link` is enbled, this check will upgrade +# media when switching back and forth between Conduit and conduwuit. +# Both options must be enabled to handle this. +# 2. When media is deleted from the directory, this check will also delete +# its database entry. # -# Defaults to 25 seconds -#federation_idle_timeout = 25 +# If none of these checks apply to your use cases, and your media +# directory is significantly large setting this to false may reduce +# startup time. +# +#media_startup_check = true -# Federation client max idle connections per host +# Enable backward-compatibility with Conduit's media directory by creating +# symlinks of media. This option is only necessary if you plan on using +# Conduit again. 
Otherwise setting this to false reduces filesystem +# clutter and overhead for managing these symlinks in the directory. This +# is now disabled by default. You may still return to upstream Conduit +# but you have to run conduwuit at least once with this set to true and +# allow the media_startup_check to take place before shutting +# down to return to Conduit. # -# Defaults to 1 as generally the same open connection can be re-used -#federation_idle_per_host = 1 +#media_compat_file_link = false -# Federation sender request timeout -# The time it takes for the remote server to process sent transactions can take a while. +# Prunes missing media from the database as part of the media startup +# checks. This means if you delete files from the media directory the +# corresponding entries will be removed from the database. This is +# disabled by default because if the media directory is accidentally moved +# or inaccessible, the metadata entries in the database will be lost with +# sadness. # -# Defaults to 180 seconds -#sender_timeout = 180 +#prune_missing_media = false -# Federation sender idle connection pool timeout +# Vector list of servers that conduwuit will refuse to download remote +# media from. # -# Defaults to 180 seconds -#sender_idle_timeout = 180 +#prevent_media_downloads_from = [] -# Federation sender transaction retry backoff limit +# List of forbidden server names that we will block incoming AND outgoing +# federation with, and block client room joins / remote user invites. # -# Defaults to 86400 seconds -#sender_retry_backoff_limit = 86400 +# This check is applied on the room ID, room alias, sender server name, +# sender user's server name, inbound federation X-Matrix origin, and +# outbound federation handler. +# +# Basically "global" ACLs. +# +#forbidden_remote_server_names = [] -# Appservice URL request connection timeout +# List of forbidden server names that we will block all outgoing federated +# room directory requests for. 
Useful for preventing our users from +# wandering into bad servers or spaces. # -# Defaults to 35 seconds as generally appservices are hosted within the same network -#appservice_timeout = 35 +#forbidden_remote_room_directory_server_names = [] -# Appservice URL idle connection pool timeout +# Vector list of IPv4 and IPv6 CIDR ranges / subnets *in quotes* that you +# do not want conduwuit to send outbound requests to. Defaults to +# RFC1918, unroutable, loopback, multicast, and testnet addresses for +# security. # -# Defaults to 300 seconds -#appservice_idle_timeout = 300 +# Please be aware that this is *not* a guarantee. You should be using a +# firewall with zones as doing this on the application layer may have +# bypasses. +# +# Currently this does not account for proxies in use like Synapse does. +# +# To disable, set this to be an empty vector (`[]`). +# +# "192.168.0.0/16", "100.64.0.0/10", "192.0.0.0/24", "169.254.0.0/16", +# "192.88.99.0/24", "198.18.0.0/15", "192.0.2.0/24", "198.51.100.0/24", +# "203.0.113.0/24", "224.0.0.0/4", "::1/128", "fe80::/10", "fc00::/7", +# "2001:db8::/32", "ff00::/8", "fec0::/10"] +# +#ip_range_denylist = ["127.0.0.0/8", "10.0.0.0/8", "172.16.0.0/12", -# Notification gateway pusher idle connection pool timeout +# Vector list of domains allowed to send requests to for URL previews. +# Defaults to none. Note: this is a *contains* match, not an explicit +# match. Putting "google.com" will match "https://google.com" and +# "http://mymaliciousdomainexamplegoogle.com" Setting this to "*" will +# allow all URL previews. Please note that this opens up significant +# attack surface to your server, you are expected to be aware of the +# risks by doing so. # -# Defaults to 15 seconds -#pusher_idle_timeout = 15 +#url_preview_domain_contains_allowlist = [] +# Vector list of explicit domains allowed to send requests to for URL +# previews. Defaults to none. Note: This is an *explicit* match, not a +# contains match. 
Putting "google.com" will match "https://google.com", +# "http://google.com", but not +# "https://mymaliciousdomainexamplegoogle.com". Setting this to "*" will +# allow all URL previews. Please note that this opens up significant +# attack surface to your server, you are expected to be aware of the +# risks by doing so. +# +#url_preview_domain_explicit_allowlist = [] -### Presence / Typing Indicators / Read Receipts +# Vector list of explicit domains not allowed to send requests to for URL +# previews. Defaults to none. Note: This is an *explicit* match, not a +# contains match. Putting "google.com" will match "https://google.com", +# "http://google.com", but not +# "https://mymaliciousdomainexamplegoogle.com". The denylist is checked +# first before allowlist. Setting this to "*" will not do anything. +# +#url_preview_domain_explicit_denylist = [] -# Config option to control local (your server only) presence updates/requests. Defaults to true. -# Note that presence on conduwuit is very fast unlike Synapse's. -# If using outgoing presence, this MUST be enabled. +# Vector list of URLs allowed to send requests to for URL previews. +# Defaults to none. Note that this is a *contains* match, not an +# explicit match. Putting "google.com" will match +# "https://google.com/", +# "https://google.com/url?q=https://mymaliciousdomainexample.com", and +# "https://mymaliciousdomainexample.com/hi/google.com" Setting this to +# "*" will allow all URL previews. Please note that this opens up +# significant attack surface to your server, you are expected to be +# aware of the risks by doing so. # -#allow_local_presence = true +#url_preview_url_contains_allowlist = [] -# Config option to control incoming federated presence updates/requests. Defaults to true. -# This option receives presence updates from other servers, but does not send any unless `allow_outgoing_presence` is true. -# Note that presence on conduwuit is very fast unlike Synapse's. 
+# Maximum amount of bytes allowed in a URL preview body size when +# spidering. Defaults to 384KB in bytes. # -#allow_incoming_presence = true +#url_preview_max_spider_size = 384000 -# Config option to control outgoing presence updates/requests. Defaults to true. -# This option sends presence updates to other servers, but does not receive any unless `allow_incoming_presence` is true. -# Note that presence on conduwuit is very fast unlike Synapse's. -# If using outgoing presence, you MUST enable `allow_local_presence` as well. +# Option to decide whether you would like to run the domain allowlist +# checks (contains and explicit) on the root domain or not. Does not apply +# to URL contains allowlist. Defaults to false. # -#allow_outgoing_presence = true +# Example usecase: If this is +# enabled and you have "wikipedia.org" allowed in the explicit and/or +# contains domain allowlist, it will allow all subdomains under +# "wikipedia.org" such as "en.m.wikipedia.org" as the root domain is +# checked and matched. Useful if the domain contains allowlist is still +# too broad for you but you still want to allow all the subdomains under a +# root domain. +# +#url_preview_check_root_domain = false -# Config option to enable the presence idle timer for remote users. Disabling is offered as an optimization for -# servers participating in many large rooms or when resources are limited. Disabling it may cause incorrect -# presence states (i.e. stuck online) to be seen for some remote users. Defaults to true. -#presence_timeout_remote_users = true +# List of forbidden room aliases and room IDs as strings of regex +# patterns. +# +# Regex can be used or explicit contains matches can be done by +# just specifying the words (see example). +# +# This is checked upon room alias creation, custom room ID creation if +# used, and startup as warnings if any room aliases in your database have +# a forbidden room alias/ID. 
+# +# example: ["19dollarfortnitecards", "b[4a]droom"] +# +#forbidden_alias_names = [] -# Config option to control how many seconds before presence updates that you are idle. Defaults to 5 minutes. -#presence_idle_timeout_s = 300 +# List of forbidden username patterns/strings. +# +# Regex can be used or explicit contains matches can be done by just +# specifying the words (see example). +# +# This is checked upon username availability check, registration, and +# startup as warnings if any local users in your database have a forbidden +# username. +# +# example: ["administrator", "b[a4]dusernam[3e]"] +# +#forbidden_usernames = [] -# Config option to control how many seconds before presence updates that you are offline. Defaults to 30 minutes. -#presence_offline_timeout_s = 1800 +# Retry failed and incomplete messages to remote servers immediately upon +# startup. This is called bursting. If this is disabled, said messages +# may not be delivered until more messages are queued for that server. Do +# not change this option unless server resources are extremely limited or +# the scale of the server's deployment is huge. Do not disable this +# unless you know what you are doing. +# +#startup_netburst = true -# Config option to control whether we should receive remote incoming read receipts. -# Defaults to true. -#allow_incoming_read_receipts = true +# messages are dropped and not reattempted. The `startup_netburst` option +# must be enabled for this value to have any effect. Do not change this +# value unless you know what you are doing. Set this value to -1 to +# reattempt every message without trimming the queues; this may consume +# significant disk. Set this value to 0 to drop all messages without any +# attempt at redelivery. +# +#startup_netburst_keep = 50 -# Config option to control whether we should send read receipts to remote servers. -# Defaults to true. 
-#allow_outgoing_read_receipts = true +# controls whether non-admin local users are forbidden from sending room +# invites (local and remote), and if non-admin users can receive remote +# room invites. admins are always allowed to send and receive all room +# invites. +# +#block_non_admin_invites = false -# Config option to control outgoing typing updates to federation. Defaults to true. -#allow_outgoing_typing = true +# Allows admins to enter commands in rooms other than "#admins" (admin +# room) by prefixing your message with "\!admin" or "\\!admin" followed +# up a normal conduwuit admin command. The reply will be publicly visible +# to the room, originating from the sender. +# +# example: \\!admin debug ping puppygock.gay +# +#admin_escape_commands = true -# Config option to control incoming typing updates from federation. Defaults to true. -#allow_incoming_typing = true +# Controls whether the conduwuit admin room console / CLI will immediately +# activate on startup. This option can also be enabled with `--console` +# conduwuit argument. +# +#admin_console_automatic = false -# Config option to control maximum time federation user can indicate typing. -#typing_federation_timeout_s = 30 +# Controls what admin commands will be executed on startup. This is a +# vector list of strings of admin commands to run. +# +# +# This option can also be configured with the `--execute` conduwuit +# argument and can take standard shell commands and environment variables +# +# Such example could be: `./conduwuit --execute "server admin-notice +# conduwuit has started up at $(date)"` +# +# example: admin_execute = ["debug ping puppygock.gay", "debug echo hi"]` +# +#admin_execute = [] -# Config option to control minimum time local client can indicate typing. This does not override -# a client's request to stop typing. It only enforces a minimum value in case of no stop request. 
-#typing_client_timeout_min_s = 15 +# Controls whether conduwuit should error and fail to start if an admin +# execute command (`--execute` / `admin_execute`) fails. +# +#admin_execute_errors_ignore = false -# Config option to control maximum time local client can indicate typing. -#typing_client_timeout_max_s = 45 +# Controls the max log level for admin command log captures (logs +# generated from running admin commands). Defaults to "info" on release +# builds, else "debug" on debug builds. +# +#admin_log_capture = "info" +# The default room tag to apply on the admin room. +# +# On some clients like Element, the room tag "m.server_notice" is a +# special pinned room at the very bottom of your room list. The conduwuit +# admin room can be pinned here so you always have an easy-to-access +# shortcut dedicated to your admin room. +# +#admin_room_tag = "m.server_notice" -### TURN / VoIP +# Sentry.io crash/panic reporting, performance monitoring/metrics, etc. +# This is NOT enabled by default. conduwuit's default Sentry reporting +# endpoint is o4506996327251968.ingest.us.sentry.io +# +#sentry = false -# vector list of TURN URIs/servers to use +# Sentry reporting URL if a custom one is desired # -# replace "example.turn.uri" with your TURN domain, such as the coturn "realm". -# if using TURN over TLS, replace "turn:" with "turns:" +#sentry_endpoint = "https://fe2eb4536aa04949e28eff3128d64757@o4506996327251968.ingest.us.sentry.io/4506996334657536" + +# Report your conduwuit server_name in Sentry.io crash reports and metrics # -# No default -#turn_uris = ["turn:example.turn.uri?transport=udp", "turn:example.turn.uri?transport=tcp"] +#sentry_send_server_name = false -# TURN secret to use that's read from the file path specified +# Performance monitoring/tracing sample rate for Sentry.io # -# this takes priority over "turn_secret" first, and falls back to "turn_secret" if invalid or -# failed to open. 
+# Note that too high values may impact performance, and can be disabled by +# setting it to 0.0 (0%) This value is read as a percentage to Sentry, +# represented as a decimal. Defaults to 15% of traces (0.15) # -# no default -#turn_secret_file = "/path/to/secret.txt" +#sentry_traces_sample_rate = 0.15 -# TURN secret to use for generating the HMAC-SHA1 hash apart of username and password generation +# Whether to attach a stacktrace to Sentry reports. # -# this is more secure, but if needed you can use traditional username/password below. +#sentry_attach_stacktrace = false + +# Send panics to sentry. This is true by default, but sentry has to be +# enabled. The global "sentry" config option must be enabled to send any +# data. # -# no default -#turn_secret = "" +#sentry_send_panic = true -# TURN username to provide the client +# Send errors to sentry. This is true by default, but sentry has to be +# enabled. This option is only effective in release-mode; forced to false +# in debug-mode. # -# no default -#turn_username = "" +#sentry_send_error = true -# TURN password to provide the client +# Controls the tracing log level for Sentry to send things like +# breadcrumbs and transactions # -# no default -#turn_password = "" +#sentry_filter = "info" -# TURN TTL +# Enable the tokio-console. This option is only relevant to developers. +# See https://conduwuit.puppyirl.gay/development.html#debugging-with-tokio-console for more information. # -# Default is 86400 seconds -#turn_ttl = 86400 +#tokio_console = false -# allow guests/unauthenticated users to access TURN credentials +# This item is undocumented. Please contribute documentation for it. # -# this is the equivalent of Synapse's `turn_allow_guests` config option. this allows -# any unauthenticated user to call `/_matrix/client/v3/voip/turnServer`. 
+#test = false + +# Controls whether admin room notices like account registrations, password +# changes, account deactivations, room directory publications, etc will +# be sent to the admin room. Update notices and normal admin command +# responses will still be sent. # -# defaults to false -#turn_allow_guests = false +#admin_room_notices = true +[global.tls] -# Other options not in [global]: +# Path to a valid TLS certificate file. +# +# example: "/path/to/my/certificate.crt" # +#certs = + +# Path to a valid TLS certificate private key. # -# Enables running conduwuit with direct TLS support -# It is strongly recommended you use a reverse proxy instead. This is primarily relevant for test suites like complement that require a private CA setup. -# [global.tls] -# certs = "/path/to/my/certificate.crt" -# key = "/path/to/my/private_key.key" +# example: "/path/to/my/certificate.key" # +#key = + # Whether to listen and allow for HTTP and HTTPS connections (insecure!) -# This config option is only available if conduwuit was built with `axum_dual_protocol` feature (not default feature) -# Defaults to false +# #dual_protocol = false +[global.well_known] + +# The server base domain of the URL with a specific port that the server +# well-known file will serve. This should contain a port at the end, and +# should not be a URL. +# +# example: "matrix.example.com:443" +# +#server = -# If you are using delegation via well-known files and you cannot serve them from your reverse proxy, you can -# uncomment these to serve them directly from conduwuit. This requires proxying all requests to conduwuit, not just `/_matrix` to work. +# The server URL that the client well-known file will serve. This should +# not contain a port, and should just be a valid HTTPS URL. 
# -#[global.well_known] -#server = "matrix.example.com:443" -#client = "https://matrix.example.com" +# example: "https://matrix.example.com" # -# A single contact and/or support page for /.well-known/matrix/support -# All options here are strings. Currently only supports 1 single contact. -# No default. +#client = + +# This item is undocumented. Please contribute documentation for it. +# +#support_page = + +# This item is undocumented. Please contribute documentation for it. +# +#support_role = + +# This item is undocumented. Please contribute documentation for it. +# +#support_email = + +# This item is undocumented. Please contribute documentation for it. # -#support_page = "" -#support_role = "" -#support_email = "" -#support_mxid = "" +#support_mxid = diff --git a/src/api/server/make_knock.rs b/src/api/server/make_knock.rs new file mode 100644 index 000000000..c1875a1f8 --- /dev/null +++ b/src/api/server/make_knock.rs @@ -0,0 +1,107 @@ +use axum::extract::State; +use conduit::Err; +use ruma::{ + api::{client::error::ErrorKind, federation::knock::create_knock_event_template}, + events::room::member::{MembershipState, RoomMemberEventContent}, + RoomVersionId, +}; +use serde_json::value::to_raw_value; +use tracing::warn; +use RoomVersionId::*; + +use crate::{service::pdu::PduBuilder, Error, Result, Ruma}; + +/// # `GET /_matrix/federation/v1/make_knock/{roomId}/{userId}` +/// +/// Creates a knock template. 
+pub(crate) async fn create_knock_event_template_route( + State(services): State, body: Ruma, +) -> Result { + if !services.rooms.metadata.exists(&body.room_id).await { + return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); + } + + if body.user_id.server_name() != body.origin() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Not allowed to knock on behalf of another server/user", + )); + } + + // ACL check origin server + services + .rooms + .event_handler + .acl_check(body.origin(), &body.room_id) + .await?; + + if services + .globals + .config + .forbidden_remote_server_names + .contains(body.origin()) + { + warn!( + "Server {} for remote user {} tried knocking room ID {} which has a server name that is globally \ + forbidden. Rejecting.", + body.origin(), + &body.user_id, + &body.room_id, + ); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); + } + + if let Some(server) = body.room_id.server_name() { + if services + .globals + .config + .forbidden_remote_server_names + .contains(&server.to_owned()) + { + return Err!(Request(Forbidden("Server is banned on this homeserver."))); + } + } + + let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; + + if matches!(room_version_id, V1 | V2 | V3 | V4 | V5 | V6) { + return Err(Error::BadRequest( + ErrorKind::IncompatibleRoomVersion { + room_version: room_version_id, + }, + "Room version does not support knocking.", + )); + } + + if !body.ver.contains(&room_version_id) { + return Err(Error::BadRequest( + ErrorKind::IncompatibleRoomVersion { + room_version: room_version_id, + }, + "Your homeserver does not support the features required to knock on this room.", + )); + } + + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; + + let (_pdu, mut pdu_json) = services + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder::state(body.user_id.to_string(), 
&RoomMemberEventContent::new(MembershipState::Knock)), + &body.user_id, + &body.room_id, + &state_lock, + ) + .await?; + + drop(state_lock); + + // room v3 and above removed the "event_id" field from remote PDU format + super::maybe_strip_event_id(&mut pdu_json, &room_version_id)?; + + Ok(create_knock_event_template::v1::Response { + room_version: room_version_id, + event: to_raw_value(&pdu_json).expect("CanonicalJson can be serialized to JSON"), + }) +} diff --git a/src/api/server/send_knock.rs b/src/api/server/send_knock.rs new file mode 100644 index 000000000..c57998aec --- /dev/null +++ b/src/api/server/send_knock.rs @@ -0,0 +1,190 @@ +use axum::extract::State; +use conduit::{err, pdu::gen_event_id_canonical_json, warn, Err, Error, PduEvent, Result}; +use ruma::{ + api::{client::error::ErrorKind, federation::knock::send_knock}, + events::{ + room::member::{MembershipState, RoomMemberEventContent}, + StateEventType, + }, + serde::JsonObject, + OwnedServerName, OwnedUserId, + RoomVersionId::*, +}; + +use crate::Ruma; + +/// # `PUT /_matrix/federation/v1/send_knock/{roomId}/{eventId}` +/// +/// Submits a signed knock event. +pub(crate) async fn create_knock_event_v1_route( + State(services): State, body: Ruma, +) -> Result { + if services + .globals + .config + .forbidden_remote_server_names + .contains(body.origin()) + { + warn!( + "Server {} tried knocking room ID {} who has a server name that is globally forbidden. Rejecting.", + body.origin(), + &body.room_id, + ); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); + } + + if let Some(server) = body.room_id.server_name() { + if services + .globals + .config + .forbidden_remote_server_names + .contains(&server.to_owned()) + { + warn!( + "Server {} tried knocking room ID {} which has a server name that is globally forbidden. 
Rejecting.", + body.origin(), + &body.room_id, + ); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); + } + } + + if !services.rooms.metadata.exists(&body.room_id).await { + return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); + } + + // ACL check origin server + services + .rooms + .event_handler + .acl_check(body.origin(), &body.room_id) + .await?; + + let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; + + if matches!(room_version_id, V1 | V2 | V3 | V4 | V5 | V6) { + return Err!(Request(Forbidden("Room version does not support knocking."))); + } + + let Ok((event_id, value)) = gen_event_id_canonical_json(&body.pdu, &room_version_id) else { + // Event could not be converted to canonical json + return Err!(Request(InvalidParam("Could not convert event to canonical json."))); + }; + + let event_type: StateEventType = serde_json::from_value( + value + .get("type") + .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing type property."))? + .clone() + .into(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event has invalid event type."))?; + + if event_type != StateEventType::RoomMember { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Not allowed to send non-membership state event to knock endpoint.", + )); + } + + let content: RoomMemberEventContent = serde_json::from_value( + value + .get("content") + .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing content property"))? 
+ .clone() + .into(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event content is empty or invalid"))?; + + if content.membership != MembershipState::Knock { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Not allowed to send a non-knock membership event to knock endpoint.", + )); + } + + // ACL check sender server name + let sender: OwnedUserId = serde_json::from_value( + value + .get("sender") + .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing sender property."))? + .clone() + .into(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "sender is not a valid user ID."))?; + + services + .rooms + .event_handler + .acl_check(sender.server_name(), &body.room_id) + .await?; + + // check if origin server is trying to send for another server + if sender.server_name() != body.origin() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Not allowed to knock on behalf of another server.", + )); + } + + let state_key: OwnedUserId = serde_json::from_value( + value + .get("state_key") + .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing state_key property."))? 
+ .clone() + .into(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "state_key is invalid or not a user ID."))?; + + if state_key != sender { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "State key does not match sender user", + )); + }; + + let origin: OwnedServerName = serde_json::from_value( + serde_json::to_value( + value + .get("origin") + .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing origin property."))?, + ) + .expect("CanonicalJson is valid json value"), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "origin is not a server name."))?; + + let mut event: JsonObject = serde_json::from_str(body.pdu.get()) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid knock event PDU."))?; + + event.insert("event_id".to_owned(), "$placeholder".into()); + + let pdu: PduEvent = serde_json::from_value(event.into()) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid knock event PDU."))?; + + let mutex_lock = services + .rooms + .event_handler + .mutex_federation + .lock(&body.room_id) + .await; + + let pdu_id = services + .rooms + .event_handler + .handle_incoming_pdu(&origin, &body.room_id, &event_id, value.clone(), true) + .await? + .ok_or_else(|| err!(Request(InvalidParam("Could not accept as timeline event."))))?; + + drop(mutex_lock); + + let knock_room_state = services.rooms.state.summary_stripped(&pdu).await; + + services + .sending + .send_pdu_room(&body.room_id, &pdu_id) + .await?; + + Ok(send_knock::v1::Response { + knock_room_state, + }) +} diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index eddab2fe7..4bba14554 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -36,45 +36,51 @@ use crate::{err, error::Error, utils::sys, Result}; filename = "conduwuit-example.toml", section = "global", undocumented = "# This item is undocumented. 
Please contribute documentation for it.", - header = "### Conduwuit Configuration\n###\n### THIS FILE IS GENERATED. YOUR CHANGES WILL BE OVERWRITTEN!\n### \ - You should rename this file before configuring your server. Changes\n### to documentation and defaults \ - can be contributed in sourcecode at\n### src/core/config/mod.rs. This file is generated when \ - building.\n###\n", + header = "### conduwuit Configuration\n###\n### THIS FILE IS GENERATED. CHANGES/CONTRIBUTIONS IN THE REPO WILL\n### BE \ + OVERWRITTEN!\n###\n### You should rename this file before configuring your server. Changes\n### to \ + documentation and defaults can be contributed in source code at\n### src/core/config/mod.rs. This file \ + is generated when building.\n###\n### Any values pre-populated are the default values for said config \ + option.\n###\n### At the minimum, you MUST edit all the config options to your environment\n### that say \ + \"YOU NEED TO EDIT THIS\".\n### See https://conduwuit.puppyirl.gay/configuration.html for ways to\n### configure conduwuit\n", ignore = "catchall well_known tls" )] pub struct Config { /// The server_name is the pretty name of this server. It is used as a - /// suffix for user and room ids. Examples: matrix.org, conduit.rs + /// suffix for user and room IDs/aliases. /// - /// The Conduit server needs all /_matrix/ requests to be reachable at - /// https://your.server.name/ on port 443 (client-server) and 8448 (federation). + /// See the docs for reverse proxying and delegation: https://conduwuit.puppyirl.gay/deploying/generic.html#setting-up-the-reverse-proxy + /// Also see the `[global.well_known]` config section at the very bottom. /// - /// If that's not possible for you, you can create /.well-known files to - /// redirect requests (delegation). See - /// https://spec.matrix.org/latest/client-server-api/#getwell-knownmatrixclient - /// and - /// https://spec.matrix.org/v1.9/server-server-api/#getwell-knownmatrixserver - /// for more information. 
+ /// Examples of delegation: + /// - https://puppygock.gay/.well-known/matrix/server + /// - https://puppygock.gay/.well-known/matrix/client /// - /// YOU NEED TO EDIT THIS + /// YOU NEED TO EDIT THIS. THIS CANNOT BE CHANGED AFTER WITHOUT A DATABASE + /// WIPE. + /// + /// example: "conduwuit.woof" pub server_name: OwnedServerName, - /// default address (IPv4 or IPv6) conduwuit will listen on. Generally you - /// want this to be localhost (127.0.0.1 / ::1). If you are using Docker or - /// a container NAT networking setup, you likely need this to be 0.0.0.0. - /// To listen multiple addresses, specify a vector e.g. ["127.0.0.1", "::1"] - /// Default if unspecified is both IPv4 and IPv6 localhost. + /// default address (IPv4 or IPv6) conduwuit will listen on. + /// + /// If you are using Docker or a container NAT networking setup, this must + /// be "0.0.0.0". + /// + /// To listen on multiple addresses, specify a vector e.g. ["127.0.0.1", + /// "::1"] /// /// default: ["127.0.0.1", "::1"] #[serde(default = "default_address")] address: ListeningAddr, - /// The port(s) conduwuit will be running on. You need to set up a reverse - /// proxy such as Caddy or Nginx so all requests to /_matrix on port 443 - /// and 8448 will be forwarded to the conduwuit instance running on this - /// port Docker users: Don't change this, you'll need to map an external - /// port to this. To listen on multiple ports, specify a vector e.g. [8080, - /// 8448] + /// The port(s) conduwuit will be running on. + /// + /// See https://conduwuit.puppyirl.gay/deploying/generic.html#setting-up-the-reverse-proxy for reverse proxying. + /// + /// Docker users: Don't change this, you'll need to map an external port to + /// this. + /// + /// To listen on multiple ports, specify a vector e.g. 
[8080, 8448] /// /// default: 8008 #[serde(default = "default_port")] @@ -84,108 +90,155 @@ pub struct Config { pub tls: Option, /// Uncomment unix_socket_path to listen on a UNIX socket at the specified - /// path. If listening on a UNIX socket, you must remove/comment the - /// 'address' key if defined and add your reverse proxy to the 'conduwuit' + /// path. If listening on a UNIX socket, you MUST remove/comment the + /// 'address' key if defined, AND add your reverse proxy to the 'conduwuit' /// group, unless world RW permissions are specified with unix_socket_perms /// (666 minimum). + /// + /// example: "/run/conduwuit/conduwuit.sock" pub unix_socket_path: Option, + /// The default permissions (in octal) to create the UNIX socket with. + /// /// default: 660 #[serde(default = "default_unix_socket_perms")] pub unix_socket_perms: u32, - /// Database backend: Only rocksdb is supported. - /// - /// default: rocksdb - #[serde(default = "default_database_backend")] - pub database_backend: String, - /// This is the only directory where conduwuit will save its data, including - /// media. Note: this was previously "/var/lib/matrix-conduit" + /// media. + /// Note: this was previously "/var/lib/matrix-conduit" + /// + /// YOU NEED TO EDIT THIS. + /// + /// example: "/var/lib/conduwuit" pub database_path: PathBuf, + /// conduwuit supports online database backups using RocksDB's Backup engine + /// API. To use this, set a database backup path that conduwuit can write + /// to. + /// + /// See https://conduwuit.puppyirl.gay/maintenance.html#backups for more information. + /// + /// example: "/opt/conduwuit-db-backups" pub database_backup_path: Option, + /// The amount of online RocksDB database backups to keep/retain, if using + /// "database_backup_path", before deleting the oldest one.
+ /// + /// default: 1 #[serde(default = "default_database_backups_to_keep")] pub database_backups_to_keep: i16, /// Set this to any float value in megabytes for conduwuit to tell the /// database engine that this much memory is available for database-related - /// caches. May be useful if you have significant memory to spare to - /// increase performance. + /// caches. + /// + /// May be useful if you have significant memory to spare to increase + /// performance. + /// + /// Similar to the individual LRU caches, this is scaled up with your CPU + /// core count. /// - /// default: 256.0 + /// This defaults to 128.0 + (64.0 * CPU core count) #[serde(default = "default_db_cache_capacity_mb")] pub db_cache_capacity_mb: f64, /// Option to control adding arbitrary text to the end of the user's /// displayname upon registration with a space before the text. This was the /// lightning bolt emoji option, just replaced with support for adding your - /// own custom text or emojis. To disable, set this to "" (an empty string) - /// Defaults to "🏳️⚧️" (trans pride flag) + /// own custom text or emojis. To disable, set this to "" (an empty string). /// - /// default: 🏳️⚧️ + /// The default is the trans pride flag. + /// + /// example: "🏳️⚧️" + /// + /// default: "🏳️⚧️" #[serde(default = "default_new_user_displayname_suffix")] pub new_user_displayname_suffix: String, /// If enabled, conduwuit will send a simple GET request periodically to /// `https://pupbrain.dev/check-for-updates/stable` for any new /// announcements made. Despite the name, this is not an update check - /// endpoint, it is simply an announcement check endpoint. Defaults to - /// false. + /// endpoint, it is simply an announcement check endpoint. + /// + /// This is disabled by default as this is rarely used except for security + /// updates or major updates. 
#[serde(default)] pub allow_check_for_updates: bool, - #[serde(default = "default_pdu_cache_capacity")] - pub pdu_cache_capacity: u32, - /// Set this to any float value to multiply conduwuit's in-memory LRU caches - /// with. May be useful if you have significant memory to spare to increase + /// with such as "auth_chain_cache_capacity". + /// + /// May be useful if you have significant memory to spare to increase /// performance. This was previously called /// `conduit_cache_capacity_modifier`. /// - /// default: 1.0. + /// If you have low memory, reducing this may be viable. + /// + /// By default, the individual caches such as "auth_chain_cache_capacity" + /// are scaled by your CPU core count. + /// + /// default: 1.0 #[serde(default = "default_cache_capacity_modifier", alias = "conduit_cache_capacity_modifier")] pub cache_capacity_modifier: f64, + /// default: varies by system + #[serde(default = "default_pdu_cache_capacity")] + pub pdu_cache_capacity: u32, + + /// default: varies by system #[serde(default = "default_auth_chain_cache_capacity")] pub auth_chain_cache_capacity: u32, + /// default: varies by system #[serde(default = "default_shorteventid_cache_capacity")] pub shorteventid_cache_capacity: u32, + /// default: varies by system #[serde(default = "default_eventidshort_cache_capacity")] pub eventidshort_cache_capacity: u32, + /// default: varies by system #[serde(default = "default_shortstatekey_cache_capacity")] pub shortstatekey_cache_capacity: u32, + /// default: varies by system #[serde(default = "default_statekeyshort_cache_capacity")] pub statekeyshort_cache_capacity: u32, + /// default: varies by system #[serde(default = "default_server_visibility_cache_capacity")] pub server_visibility_cache_capacity: u32, + /// default: varies by system #[serde(default = "default_user_visibility_cache_capacity")] pub user_visibility_cache_capacity: u32, + /// default: varies by system #[serde(default = "default_stateinfo_cache_capacity")] pub 
stateinfo_cache_capacity: u32, + /// default: varies by system #[serde(default = "default_roomid_spacehierarchy_cache_capacity")] pub roomid_spacehierarchy_cache_capacity: u32, /// Maximum entries stored in DNS memory-cache. The size of an entry may /// vary so please take care if raising this value excessively. Only /// decrease this when using an external DNS cache. Please note - /// that systemd does *not* count as an external cache, even when configured - /// to do so. + /// that systemd-resolved does *not* count as an external cache, even when + /// configured to do so. + /// + /// default: 32768 #[serde(default = "default_dns_cache_entries")] pub dns_cache_entries: u32, /// Minimum time-to-live in seconds for entries in the DNS cache. The - /// default may appear high to most administrators; this is by design. Only - /// decrease this if you are using an external DNS cache. + /// default may appear high to most administrators; this is by design as the + /// majority of NXDOMAINs are correct for a long time (e.g. the server is no + /// longer running Matrix). Only decrease this if you are using an external + /// DNS cache. + /// + /// default: 259200 #[serde(default = "default_dns_min_ttl")] pub dns_min_ttl: u64, @@ -195,16 +248,23 @@ pub struct Config { /// and aggressively cached rather than constantly rechecked. /// /// Defaults to 3 days as these are *very rarely* false negatives. + /// + /// default: 259200 #[serde(default = "default_dns_min_ttl_nxdomain")] pub dns_min_ttl_nxdomain: u64, /// Number of retries after a timeout. + /// + /// default: 10 #[serde(default = "default_dns_attempts")] pub dns_attempts: u16, /// The number of seconds to wait for a reply to a DNS query. Please note /// that recursive queries can take up to several seconds for some domains, - /// so this value should not be too low. + /// so this value should not be too low, especially on slower hardware or + /// resolvers.
+ /// + /// default: 10 #[serde(default = "default_dns_timeout")] pub dns_timeout: u64, @@ -223,8 +283,7 @@ pub struct Config { /// Enables using *only* TCP for querying your specified nameservers instead /// of UDP. /// - /// You very likely do *not* want this. hickory-resolver already falls back - /// to TCP on UDP errors. Defaults to false + /// If you are running conduwuit in a container environment, this config option may need to be enabled. See https://conduwuit.puppyirl.gay/troubleshooting.html#potential-dns-issues-when-using-docker for more details. #[serde(default)] pub query_over_tcp_only: bool, @@ -232,30 +291,34 @@ pub struct Config { /// /// Takes a number of one of the following options: /// 1 - Ipv4Only (Only query for A records, no AAAA/IPv6) + /// /// 2 - Ipv6Only (Only query for AAAA records, no A/IPv4) + /// /// 3 - Ipv4AndIpv6 (Query for A and AAAA records in parallel, uses whatever - /// returns a successful response first) 4 - Ipv6thenIpv4 (Query for AAAA - /// record, if that fails then query the A record) 5 - Ipv4thenIpv6 (Query - /// for A record, if that fails then query the AAAA record) + /// returns a successful response first) + /// + /// 4 - Ipv6thenIpv4 (Query for AAAA record, if that fails then query the A + /// record) /// - /// If you don't have IPv6 networking, then for better performance it may be - /// suitable to set this to Ipv4Only (1) as you will never ever use the - /// AAAA record contents even if the AAAA record is successful instead of - /// the A record. + /// 5 - Ipv4thenIpv6 (Query for A record, if that fails then query the AAAA + /// record) /// - /// Defaults to 5 - Ipv4ThenIpv6 as this is the most compatible and IPv4 - /// networking is currently the most prevalent. + /// If you don't have IPv6 networking, then for better DNS performance it + /// may be suitable to set this to Ipv4Only (1) as you will never ever use + /// the AAAA record contents even if the AAAA record is successful instead + /// of the A record. 
/// /// default: 5 #[serde(default = "default_ip_lookup_strategy")] pub ip_lookup_strategy: u8, - /// Max request size for file uploads + /// Max request size for file uploads in bytes. Defaults to 20MB. /// /// default: 20971520 #[serde(default = "default_max_request_size")] pub max_request_size: usize, + /// default: 192 #[serde(default = "default_max_fetch_prev_events")] pub max_fetch_prev_events: u16, @@ -365,7 +428,7 @@ pub struct Config { /// Notification gateway pusher idle connection pool timeout /// - /// Defaults to 15 seconds + /// default: 15 #[serde(default = "default_pusher_idle_timeout")] pub pusher_idle_timeout: u64, @@ -373,7 +436,7 @@ pub struct Config { /// server. /// /// If set to true without a token configured, users can register with no - /// form of 2nd- step only if you set + /// form of 2nd-step only if you set /// `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` to /// true in your config. /// @@ -387,21 +450,27 @@ pub struct Config { /// A static registration token that new users will have to provide when /// creating an account. If unset and `allow_registration` is true, - /// registration is open without any condition. YOU NEED TO EDIT THIS. + /// registration is open without any condition. + /// + /// YOU NEED TO EDIT THIS OR USE registration_token_file. + /// + /// example: "o&^uCtes4HPf0Vu@F20jQeeWE7" pub registration_token: Option, - /// Path to a file on the system that gets read for the registration token + /// Path to a file on the system that gets read for the registration token. + /// this config option takes precedence/priority over "registration_token". /// /// conduwuit must be able to access the file, and it must not be empty /// - /// no default + /// example: "/etc/conduwuit/.reg_token" pub registration_token_file: Option, /// Controls whether encrypted rooms and events are allowed. #[serde(default = "true_fn")] pub allow_encryption: bool, - /// Controls whether federation is allowed or not. 
+ /// Controls whether federation is allowed or not. It is not recommended to + /// disable this after the fact due to potential federation breakage. #[serde(default = "true_fn")] pub allow_federation: bool, @@ -433,25 +502,25 @@ pub struct Config { /// allow guests/unauthenticated users to access TURN credentials /// /// this is the equivalent of Synapse's `turn_allow_guests` config option. - /// this allows any unauthenticated user to call + /// this allows any unauthenticated user to call the endpoint /// `/_matrix/client/v3/voip/turnServer`. /// - /// defaults to false + /// It is unlikely you need to enable this as all major clients support + /// authentication for this endpoint and prevents misuse of your TURN server + /// from potential bots. #[serde(default)] pub turn_allow_guests: bool, /// Set this to true to lock down your server's public room directory and /// only allow admins to publish rooms to the room directory. Unpublishing /// is still allowed by all users with this enabled. - /// - /// Defaults to false #[serde(default)] pub lockdown_public_room_directory: bool, /// Set this to true to allow federating device display names / allow /// external users to see your device display name. If federation is /// disabled entirely (`allow_federation`), this is inherently false. For - /// privacy, this is best disabled. + /// privacy reasons, this is best left disabled. #[serde(default)] pub allow_device_name_federation: bool, @@ -464,25 +533,29 @@ pub struct Config { /// try to invite you to a DM or room. Also can protect against profile /// spiders. /// - /// Defaults to true. - /// /// This is inherently false if `allow_federation` is disabled #[serde(default = "true_fn", alias = "allow_profile_lookup_federation_requests")] pub allow_inbound_profile_lookup_federation_requests: bool, - /// controls whether users are allowed to create rooms. 
- /// appservices and admins are always allowed to create rooms - /// defaults to true + /// controls whether standard users are allowed to create rooms. appservices + /// and admins are always allowed to create rooms #[serde(default = "true_fn")] pub allow_room_creation: bool, /// Set to false to disable users from joining or creating room versions /// that aren't 100% officially supported by conduwuit. - /// conduwuit officially supports room versions 6 - 10. conduwuit has - /// experimental/unstable support for 3 - 5, and 11. Defaults to true. + /// + /// conduwuit officially supports room versions 6 - 11. + /// + /// conduwuit has slightly experimental (though works fine in practice) + /// support for versions 3 - 5 #[serde(default = "true_fn")] pub allow_unstable_room_versions: bool, + /// default room version conduwuit will create rooms with. + /// + /// per spec, room version 10 is the default. + /// /// default: 10 #[serde(default = "default_default_room_version")] pub default_room_version: RoomVersionId, @@ -498,10 +571,12 @@ pub struct Config { #[serde(default = "default_jaeger_filter")] pub jaeger_filter: String, - /// If the 'perf_measurements' feature is enabled, enables collecting folded - /// stack trace profile of tracing spans using tracing_flame. The resulting - /// profile can be visualized with inferno[1], speedscope[2], or a number of - /// other tools. [1]: https://github.com/jonhoo/inferno + /// If the 'perf_measurements' compile-time feature is enabled, enables + /// collecting folded stack trace profile of tracing spans using + /// tracing_flame. The resulting profile can be visualized with inferno[1], + /// speedscope[2], or a number of other tools. + /// + /// [1]: https://github.com/jonhoo/inferno /// [2]: www.speedscope.app #[serde(default)] pub tracing_flame: bool, @@ -546,8 +621,10 @@ pub struct Config { /// Servers listed here will be used to gather public keys of other servers /// (notary trusted key servers). 
/// - /// (Currently, conduwuit doesn't support batched key requests, so this list - /// should only contain other Synapse servers) Defaults to `matrix.org` + /// Currently, conduwuit doesn't support inbound batched key requests, so + /// this list should only contain other Synapse servers + /// + /// example: ["matrix.org", "constellatory.net", "tchncs.de"] /// /// default: ["matrix.org"] #[serde(default = "default_trusted_servers")] @@ -556,9 +633,10 @@ pub struct Config { /// Whether to query the servers listed in trusted_servers first or query /// the origin server first. For best security, querying the origin server /// first is advised to minimize the exposure to a compromised trusted - /// server. For maximum performance this can be set to true, however other - /// options exist to query trusted servers first under specific high-load - /// circumstances and should be evaluated before setting this to true. + /// server. For maximum federation/join performance this can be set to true, + /// however other options exist to query trusted servers first under + /// specific high-load circumstances and should be evaluated before setting + /// this to true. #[serde(default)] pub query_trusted_key_servers_first: bool, @@ -582,7 +660,7 @@ pub struct Config { #[serde(default)] pub only_query_trusted_key_servers: bool, - /// Maximum number of keys to request in each trusted server query. + /// Maximum number of keys to request in each trusted server batch query. /// /// default: 1024 #[serde(default = "default_trusted_server_batch_size")] @@ -590,6 +668,7 @@ pub struct Config { /// max log level for conduwuit. 
allows debug, info, warn, or error /// see also: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives + /// /// **Caveat**: /// For release builds, the tracing crate is configured to only implement /// levels higher than error to avoid unnecessary overhead in the compiled @@ -601,8 +680,6 @@ pub struct Config { pub log: String, /// controls whether logs will be outputted with ANSI colours - /// - /// default: true #[serde(default = "true_fn", alias = "log_colours")] pub log_colors: bool, @@ -615,40 +692,43 @@ pub struct Config { /// OpenID token expiration/TTL in seconds /// /// These are the OpenID tokens that are primarily used for Matrix account - /// integrations, *not* OIDC/OpenID Connect/etc + /// integrations (e.g. Vector Integrations in Element), *not* OIDC/OpenID + /// Connect/etc /// /// default: 3600 #[serde(default = "default_openid_token_ttl")] pub openid_token_ttl: u64, - /// TURN username to provide the client - /// - /// no default + /// static TURN username to provide the client if not using a shared secret + /// ("turn_secret"), It is recommended to use a shared secret over static + /// credentials. #[serde(default)] pub turn_username: String, - /// TURN password to provide the client - /// - /// no default + /// static TURN password to provide the client if not using a shared secret + /// ("turn_secret"). It is recommended to use a shared secret over static + /// credentials. #[serde(default)] pub turn_password: String, /// vector list of TURN URIs/servers to use /// /// replace "example.turn.uri" with your TURN domain, such as the coturn - /// "realm". if using TURN over TLS, replace "turn:" with "turns:" + /// "realm" config option. 
if using TURN over TLS, replace the URI prefix + /// "turn:" with "turns:" /// - /// No default - #[serde(default = "Vec::new")] + /// example: ["turn:example.turn.uri?transport=udp", + /// "turn:example.turn.uri?transport=tcp"] + /// + /// default: [] + #[serde(default)] pub turn_uris: Vec, /// TURN secret to use for generating the HMAC-SHA1 hash apart of username /// and password generation /// /// this is more secure, but if needed you can use traditional - /// username/password below. - /// - /// no default + /// static username/password credentials. #[serde(default)] pub turn_secret: String, @@ -657,7 +737,7 @@ pub struct Config { /// this takes priority over "turn_secret" first, and falls back to /// "turn_secret" if invalid or failed to open. /// - /// no default + /// example: "/etc/conduwuit/.turn_secret" pub turn_secret_file: Option, /// TURN TTL in seconds @@ -670,7 +750,10 @@ pub struct Config { /// registered users join. The rooms specified must be rooms that you /// have joined at least once on the server, and must be public. /// - /// No default. + /// example: ["#conduwuit:puppygock.gay", + /// "!eoIzvAvVwY23LPDay8:puppygock.gay"] + /// + /// default: [] #[serde(default = "Vec::new")] pub auto_join_rooms: Vec, @@ -695,15 +778,18 @@ pub struct Config { /// RocksDB log level. This is not the same as conduwuit's log level. This /// is the log level for the RocksDB engine/library which show up in your - /// database folder/path as `LOG` files. Defaults to error. conduwuit will - /// typically log RocksDB errors as normal. + /// database folder/path as `LOG` files. conduwuit will log RocksDB errors + /// as normal through tracing. + /// + /// default: "error" #[serde(default = "default_rocksdb_log_level")] pub rocksdb_log_level: String, #[serde(default)] pub rocksdb_log_stderr: bool, - /// Max RocksDB `LOG` file size before rotating in bytes. Defaults to 4MB. + /// Max RocksDB `LOG` file size before rotating in bytes. Defaults to 4MB in + /// bytes. 
/// /// default: 4194304 #[serde(default = "default_rocksdb_max_log_file_size")] @@ -727,13 +813,19 @@ pub struct Config { /// but it cannot account for all setups. If you experience any weird /// RocksDB issues, try enabling this option as it turns off Direct IO and /// feel free to report in the conduwuit Matrix room if this option fixes - /// your DB issues. See https://github.com/facebook/rocksdb/wiki/Direct-IO for more information. + /// your DB issues. + /// + /// See https://github.com/facebook/rocksdb/wiki/Direct-IO for more information. #[serde(default)] pub rocksdb_optimize_for_spinning_disks: bool, - /// Enables direct-io to increase database performance. This is enabled by - /// default. Set this option to false if the database resides on a - /// filesystem which does not support direct-io. + /// Enables direct-io to increase database performance via unbuffered I/O. + /// + /// See https://github.com/facebook/rocksdb/wiki/Direct-IO for more details about Direct IO and RocksDB. + /// + /// Set this option to false if the database resides on a filesystem which + /// does not support direct-io like FUSE, or any form of complex filesystem + /// setup such as possibly ZFS. #[serde(default = "true_fn")] pub rocksdb_direct_io: bool, @@ -746,14 +838,17 @@ pub struct Config { pub rocksdb_parallelism_threads: usize, /// Maximum number of LOG files RocksDB will keep. This must *not* be set to - /// 0. It must be at least 1. Defaults to 3 as these are not very useful. + /// 0. It must be at least 1. Defaults to 3 as these are not very useful + /// unless troubleshooting/debugging a RocksDB bug. /// /// default: 3 #[serde(default = "default_rocksdb_max_log_files")] pub rocksdb_max_log_files: usize, /// Type of RocksDB database compression to use. + /// /// Available options are "zstd", "zlib", "bz2", "lz4", or "none" + /// /// It is best to use ZSTD as an overall good balance between /// speed/performance, storage, IO amplification, and CPU usage. 
/// For more performance but less compression (more storage used) and less @@ -766,10 +861,14 @@ pub struct Config { pub rocksdb_compression_algo: String, /// Level of compression the specified compression algorithm for RocksDB to - /// use. Default is 32767, which is internally read by RocksDB as the + /// use. + /// + /// Default is 32767, which is internally read by RocksDB as the /// default magic number and translated to the library's default /// compression level as they all differ. /// See their `kDefaultCompressionLevel`. + /// + /// default: 32767 #[serde(default = "default_rocksdb_compression_level")] pub rocksdb_compression_level: i32, @@ -783,15 +882,19 @@ pub struct Config { /// it may be desirable to have a very high compression level here as it's /// lesss likely for this data to be used. Research your chosen compression /// algorithm. + /// + /// default: 32767 #[serde(default = "default_rocksdb_bottommost_compression_level")] pub rocksdb_bottommost_compression_level: i32, - /// Whether to enable RocksDB "bottommost_compression". + /// Whether to enable RocksDB's "bottommost_compression". + /// /// At the expense of more CPU usage, this will further compress the /// database to reduce more storage. It is recommended to use ZSTD - /// compression with this for best compression results. See https://github.com/facebook/rocksdb/wiki/Compression for more details. + /// compression with this for best compression results. This may be useful + /// if you're trying to reduce storage usage from the database. /// - /// Defaults to false as this uses more CPU when compressing. + /// See https://github.com/facebook/rocksdb/wiki/Compression for more details. 
#[serde(default)] pub rocksdb_bottommost_compression: bool, @@ -822,9 +925,9 @@ pub struct Config { /// 2 = PointInTime (use me if trying to recover) /// 3 = SkipAnyCorruptedRecord (you now voided your Conduwuit warranty) /// - /// See https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes for more information + /// See https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes for more information on these modes. /// - /// Defaults to 1 (TolerateCorruptedTailRecords) + /// See https://conduwuit.puppyirl.gay/troubleshooting.html#database-corruption for more details on recovering a corrupt database. /// /// default: 1 #[serde(default = "default_rocksdb_recovery_mode")] @@ -841,6 +944,8 @@ pub struct Config { /// repair. /// - Disabling repair mode and restarting the server is recommended after /// running the repair. + /// + /// See https://conduwuit.puppyirl.gay/troubleshooting.html#database-corruption for more details on recovering a corrupt database. #[serde(default)] pub rocksdb_repair: bool, @@ -862,6 +967,14 @@ pub struct Config { #[serde(default = "true_fn")] pub rocksdb_compaction_ioprio_idle: bool, + /// Config option to disable RocksDB compaction. You should never ever have + /// to disable this. If you for some reason find yourself needing to disable + /// this as part of troubleshooting or a bug, please reach out to us in the + /// conduwuit Matrix room with information and details. + /// + /// Disabling compaction will lead to a significantly bloated and + /// explosively large database, gradually poor performance, unnecessarily + /// excessive disk read/writes, and slower shutdowns and startups. #[serde(default = "true_fn")] pub rocksdb_compaction: bool, @@ -876,33 +989,45 @@ pub struct Config { /// 3 to 5 = Statistics with possible performance impact. /// 6 = All statistics. 
/// - /// Defaults to 1 (No statistics, except in debug-mode) - /// /// default: 1 #[serde(default = "default_rocksdb_stats_level")] pub rocksdb_stats_level: u8, + /// This is a password that can be configured that will let you login to the + /// server bot account (currently `@conduit`) for emergency troubleshooting + /// purposes such as recovering/recreating your admin room, or inviting + /// yourself back. + /// + /// See https://conduwuit.puppyirl.gay/troubleshooting.html#lost-access-to-admin-room for other ways to get back into your admin room. + /// + /// Once this password is unset, all sessions will be logged out for + /// security purposes. + /// + /// example: "F670$2CP@Hw8mG7RY1$%!#Ic7YA" pub emergency_password: Option, + /// default: "/_matrix/push/v1/notify" #[serde(default = "default_notification_push_path")] pub notification_push_path: String, /// Config option to control local (your server only) presence - /// updates/requests. Defaults to true. Note that presence on conduwuit is + /// updates/requests. Note that presence on conduwuit is /// very fast unlike Synapse's. If using outgoing presence, this MUST be /// enabled. #[serde(default = "true_fn")] pub allow_local_presence: bool, /// Config option to control incoming federated presence updates/requests. - /// Defaults to true. This option receives presence updates from other + /// + /// This option receives presence updates from other /// servers, but does not send any unless `allow_outgoing_presence` is true. /// Note that presence on conduwuit is very fast unlike Synapse's. #[serde(default = "true_fn")] pub allow_incoming_presence: bool, - /// Config option to control outgoing presence updates/requests. Defaults to - /// true. This option sends presence updates to other servers, but does not + /// Config option to control outgoing presence updates/requests. + /// + /// This option sends presence updates to other servers, but does not /// receive any unless `allow_incoming_presence` is true. 
/// Note that presence on conduwuit is very fast unlike Synapse's. /// If using outgoing presence, you MUST enable `allow_local_presence` as @@ -986,6 +1111,9 @@ pub struct Config { /// compression may weaken TLS. Most users should not need to enable this. /// See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH before /// deciding to enable this. + /// + /// If you are in a large amount of rooms, you may find that enabling this + /// is necessary to reduce the significantly large response bodies. #[serde(default)] pub gzip_compression: bool, @@ -998,18 +1126,18 @@ pub struct Config { #[serde(default)] pub brotli_compression: bool, - /// Set to true to allow user type "guest" registrations. Element attempts - /// to register guest users automatically. Defaults to false. + /// Set to true to allow user type "guest" registrations. Some clients like + /// Element attempt to register guest users automatically. #[serde(default)] pub allow_guest_registration: bool, - /// Set to true to log guest registrations in the admin room. - /// Defaults to false as it may be noisy or unnecessary. + /// Set to true to log guest registrations in the admin room. Note that + /// these may be noisy or unnecessary if you're a public homeserver. #[serde(default)] pub log_guest_registrations: bool, /// Set to true to allow guest registrations/users to auto join any rooms - /// specified in `auto_join_rooms` Defaults to false. + /// specified in `auto_join_rooms`. #[serde(default)] pub allow_guests_auto_join_rooms: bool, @@ -1033,7 +1161,7 @@ pub struct Config { /// Checks consistency of the media directory at startup: /// 1. When `media_compat_file_link` is enbled, this check will upgrade - /// media when switching back and forth between Conduit and Conduwuit. + /// media when switching back and forth between Conduit and conduwuit. /// Both options must be enabled to handle this. /// 2. 
When media is deleted from the directory, this check will also delete /// its database entry. @@ -1041,8 +1169,6 @@ pub struct Config { /// If none of these checks apply to your use cases, and your media /// directory is significantly large setting this to false may reduce /// startup time. - /// - /// Enabled by default. #[serde(default = "true_fn")] pub media_startup_check: bool, @@ -1051,9 +1177,9 @@ pub struct Config { /// Conduit again. Otherwise setting this to false reduces filesystem /// clutter and overhead for managing these symlinks in the directory. This /// is now disabled by default. You may still return to upstream Conduit - /// but you have to run Conduwuit at least once with this set to true and + /// but you have to run conduwuit at least once with this set to true and /// allow the media_startup_check to take place before shutting - /// down to return to Conduit. Disabled by default. + /// down to return to Conduit. #[serde(default)] pub media_compat_file_link: bool, @@ -1061,14 +1187,16 @@ pub struct Config { /// checks. This means if you delete files from the media directory the /// corresponding entries will be removed from the database. This is /// disabled by default because if the media directory is accidentally moved - /// or inaccessible the metadata entries in the database will be lost with - /// sadness. Disabled by default. + /// or inaccessible, the metadata entries in the database will be lost with + /// sadness. #[serde(default)] pub prune_missing_media: bool, /// Vector list of servers that conduwuit will refuse to download remote - /// media from. No default. - #[serde(default = "HashSet::new")] + /// media from. + /// + /// default: [] + #[serde(default)] pub prevent_media_downloads_from: HashSet, /// List of forbidden server names that we will block incoming AND outgoing @@ -1078,13 +1206,17 @@ pub struct Config { /// sender user's server name, inbound federation X-Matrix origin, and /// outbound federation handler. 
/// - /// Basically "global" ACLs. No default. - #[serde(default = "HashSet::new")] + /// Basically "global" ACLs. + /// + /// default: [] + #[serde(default)] pub forbidden_remote_server_names: HashSet, /// List of forbidden server names that we will block all outgoing federated /// room directory requests for. Useful for preventing our users from - /// wandering into bad servers or spaces. No default. + /// wandering into bad servers or spaces. + /// + /// default: [] #[serde(default = "HashSet::new")] pub forbidden_remote_room_directory_server_names: HashSet, @@ -1100,28 +1232,12 @@ pub struct Config { /// Currently this does not account for proxies in use like Synapse does. /// /// To disable, set this to be an empty vector (`[]`). - /// The default is: - /// [ - /// "127.0.0.0/8", - /// "10.0.0.0/8", - /// "172.16.0.0/12", - /// "192.168.0.0/16", - /// "100.64.0.0/10", - /// "192.0.0.0/24", - /// "169.254.0.0/16", - /// "192.88.99.0/24", - /// "198.18.0.0/15", - /// "192.0.2.0/24", - /// "198.51.100.0/24", - /// "203.0.113.0/24", - /// "224.0.0.0/4", - /// "::1/128", - /// "fe80::/10", - /// "fc00::/7", - /// "2001:db8::/32", - /// "ff00::/8", - /// "fec0::/10", - /// ] + /// + /// default: ["127.0.0.0/8", "10.0.0.0/8", "172.16.0.0/12", + /// "192.168.0.0/16", "100.64.0.0/10", "192.0.0.0/24", "169.254.0.0/16", + /// "192.88.99.0/24", "198.18.0.0/15", "192.0.2.0/24", "198.51.100.0/24", + /// "203.0.113.0/24", "224.0.0.0/4", "::1/128", "fe80::/10", "fc00::/7", + /// "2001:db8::/32", "ff00::/8", "fec0::/10"] #[serde(default = "default_ip_range_denylist")] pub ip_range_denylist: Vec, @@ -1132,7 +1248,9 @@ pub struct Config { /// allow all URL previews. Please note that this opens up significant /// attack surface to your server, you are expected to be aware of the /// risks by doing so. 
- #[serde(default = "Vec::new")] + /// + /// default: [] + #[serde(default)] pub url_preview_domain_contains_allowlist: Vec, /// Vector list of explicit domains allowed to send requests to for URL @@ -1143,7 +1261,9 @@ pub struct Config { /// allow all URL previews. Please note that this opens up significant /// attack surface to your server, you are expected to be aware of the /// risks by doing so. - #[serde(default = "Vec::new")] + /// + /// default: [] + #[serde(default)] pub url_preview_domain_explicit_allowlist: Vec, /// Vector list of explicit domains not allowed to send requests to for URL @@ -1152,7 +1272,9 @@ pub struct Config { /// "http://google.com", but not /// "https://mymaliciousdomainexamplegoogle.com". The denylist is checked /// first before allowlist. Setting this to "*" will not do anything. - #[serde(default = "Vec::new")] + /// + /// default: [] + #[serde(default)] pub url_preview_domain_explicit_denylist: Vec, /// Vector list of URLs allowed to send requests to for URL previews. @@ -1164,19 +1286,23 @@ pub struct Config { /// "*" will allow all URL previews. Please note that this opens up /// significant attack surface to your server, you are expected to be /// aware of the risks by doing so. - #[serde(default = "Vec::new")] + /// + /// default: [] + #[serde(default)] pub url_preview_url_contains_allowlist: Vec, /// Maximum amount of bytes allowed in a URL preview body size when - /// spidering. Defaults to 384KB. + /// spidering. Defaults to 384KB in bytes. /// - /// defaukt: 384000 + /// default: 384000 #[serde(default = "default_url_preview_max_spider_size")] pub url_preview_max_spider_size: usize, /// Option to decide whether you would like to run the domain allowlist /// checks (contains and explicit) on the root domain or not. Does not apply - /// to URL contains allowlist. Defaults to false. Example: If this is + /// to URL contains allowlist. Defaults to false. 
+ /// + /// Example usecase: If this is /// enabled and you have "wikipedia.org" allowed in the explicit and/or /// contains domain allowlist, it will allow all subdomains under /// "wikipedia.org" such as "en.m.wikipedia.org" as the root domain is @@ -1186,21 +1312,36 @@ pub struct Config { #[serde(default)] pub url_preview_check_root_domain: bool, - /// List of forbidden room aliases and room IDs as patterns/strings. Values - /// in this list are matched as *contains*. This is checked upon room alias - /// creation, custom room ID creation if used, and startup as warnings if - /// any room aliases in your database have a forbidden room alias/ID. - /// No default. - #[serde(default = "RegexSet::empty")] + /// List of forbidden room aliases and room IDs as strings of regex + /// patterns. + /// + /// Regex can be used or explicit contains matches can be done by + /// just specifying the words (see example). + /// + /// This is checked upon room alias creation, custom room ID creation if + /// used, and startup as warnings if any room aliases in your database have + /// a forbidden room alias/ID. + /// + /// example: ["19dollarfortnitecards", "b[4a]droom"] + /// + /// default: [] + #[serde(default)] #[serde(with = "serde_regex")] pub forbidden_alias_names: RegexSet, - /// List of forbidden username patterns/strings. Values in this list are - /// matched as *contains*. This is checked upon username availability - /// check, registration, and startup as warnings if any local users in your - /// database have a forbidden username. - /// No default. - #[serde(default = "RegexSet::empty")] + /// List of forbidden username patterns/strings. + /// + /// Regex can be used or explicit contains matches can be done by just + /// specifying the words (see example). + /// + /// This is checked upon username availability check, registration, and + /// startup as warnings if any local users in your database have a forbidden + /// username. 
+ /// + /// example: ["administrator", "b[a4]dusernam[3e]"] + /// + /// default: [] + #[serde(default)] #[serde(with = "serde_regex")] pub forbidden_usernames: RegexSet, @@ -1231,9 +1372,12 @@ pub struct Config { #[serde(default)] pub block_non_admin_invites: bool, - /// Allows admins to enter commands in rooms other than #admins by prefixing - /// with \!admin. The reply will be publicly visible to the room, - /// originating from the sender. + /// Allows admins to enter commands in rooms other than "#admins" (admin + /// room) by prefixing your message with "\!admin" or "\\!admin" followed + /// by a normal conduwuit admin command. The reply will be publicly visible + /// to the room, originating from the sender. + /// + /// example: \\!admin debug ping puppygock.gay #[serde(default = "true_fn")] pub admin_escape_commands: bool, @@ -1246,8 +1390,6 @@ pub struct Config { /// Controls what admin commands will be executed on startup. This is a /// vector list of strings of admin commands to run. /// - /// An example of this can be: `admin_execute = ["debug ping puppygock.gay", - /// "debug echo hi"]` /// /// This option can also be configured with the `--execute` conduwuit /// argument and can take standard shell commands and environment variables /// Such example could be: `./conduwuit --execute "server admin-notice /// conduwuit has started up at $(date)"` /// + /// example: admin_execute = ["debug ping puppygock.gay", "debug echo hi"] /// + /// default: [] #[serde(default)] pub admin_execute: Vec, @@ -1272,6 +1416,14 @@ pub struct Config { #[serde(default = "default_admin_log_capture")] pub admin_log_capture: String, + /// The default room tag to apply on the admin room. + /// + /// On some clients like Element, the room tag "m.server_notice" is a + /// special pinned room at the very bottom of your room list.
The conduwuit + /// admin room can be pinned here so you always have an easy-to-access + /// shortcut dedicated to your admin room. + /// + /// default: "m.server_notice" #[serde(default = "default_admin_room_tag")] pub admin_room_tag: String, @@ -1283,12 +1435,11 @@ pub struct Config { /// Sentry reporting URL if a custom one is desired /// - /// Defaults to conduwuit's default Sentry endpoint: - /// "https://fe2eb4536aa04949e28eff3128d64757@o4506996327251968.ingest.us.sentry.io/4506996334657536" + /// default: "https://fe2eb4536aa04949e28eff3128d64757@o4506996327251968.ingest.us.sentry.io/4506996334657536" #[serde(default = "default_sentry_endpoint")] pub sentry_endpoint: Option, - /// Report your Conduwuit server_name in Sentry.io crash reports and metrics + /// Report your conduwuit server_name in Sentry.io crash reports and metrics #[serde(default)] pub sentry_send_server_name: bool, @@ -1307,7 +1458,8 @@ pub struct Config { pub sentry_attach_stacktrace: bool, /// Send panics to sentry. This is true by default, but sentry has to be - /// enabled. + /// enabled. The global "sentry" config option must be enabled to send any + /// data. #[serde(default = "true_fn")] pub sentry_send_panic: bool, @@ -1318,13 +1470,14 @@ pub struct Config { pub sentry_send_error: bool, /// Controls the tracing log level for Sentry to send things like - /// breadcrumbs and transactions Defaults to "info" + /// breadcrumbs and transactions + /// + /// default: "info" #[serde(default = "default_sentry_filter")] pub sentry_filter: String, /// Enable the tokio-console. This option is only relevant to developers. - /// See: docs/development.md#debugging-with-tokio-console for more - /// information. + /// See https://conduwuit.puppyirl.gay/development.html#debugging-with-tokio-console for more information. 
#[serde(default)] pub tokio_console: bool, @@ -1346,18 +1499,33 @@ pub struct Config { #[derive(Clone, Debug, Deserialize)] #[config_example_generator(filename = "conduwuit-example.toml", section = "global.tls")] pub struct TlsConfig { + /// Path to a valid TLS certificate file. + /// + /// example: "/path/to/my/certificate.crt" pub certs: String, + /// Path to a valid TLS certificate private key. + /// + /// example: "/path/to/my/certificate.key" pub key: String, - #[serde(default)] /// Whether to listen and allow for HTTP and HTTPS connections (insecure!) + #[serde(default)] pub dual_protocol: bool, } #[derive(Clone, Debug, Deserialize, Default)] #[config_example_generator(filename = "conduwuit-example.toml", section = "global.well_known")] pub struct WellKnownConfig { - pub client: Option, + /// The server base domain of the URL with a specific port that the server + /// well-known file will serve. This should contain a port at the end, and + /// should not be a URL. + /// + /// example: "matrix.example.com:443" pub server: Option, + /// The server URL that the client well-known file will serve. This should + /// not contain a port, and should just be a valid HTTPS URL. 
+ /// + /// example: "https://matrix.example.com" + pub client: Option, pub support_page: Option, pub support_role: Option, pub support_email: Option, @@ -1460,7 +1628,6 @@ impl fmt::Display for Config { }; line("Server name", self.server_name.host()); - line("Database backend", &self.database_backend); line("Database path", &self.database_path.to_string_lossy()); line( "Database backup path", @@ -1861,8 +2028,6 @@ fn default_unix_socket_perms() -> u32 { 660 } fn default_database_backups_to_keep() -> i16 { 1 } -fn default_database_backend() -> String { "rocksdb".to_owned() } - fn default_db_cache_capacity_mb() -> f64 { 128.0 + parallelism_scaled_f64(64.0) } fn default_pdu_cache_capacity() -> u32 { parallelism_scaled_u32(10_000).saturating_add(100_000) } diff --git a/src/service/migrations.rs b/src/service/migrations.rs index d6c342f86..4c821fa38 100644 --- a/src/service/migrations.rs +++ b/src/service/migrations.rs @@ -59,7 +59,6 @@ pub(crate) async fn migrations(services: &Services) -> Result<()> { async fn fresh(services: &Services) -> Result<()> { let db = &services.db; - let config = &services.server.config; services .globals @@ -73,10 +72,7 @@ async fn fresh(services: &Services) -> Result<()> { // Create the admin room and server user on first run crate::admin::create_admin_room(services).boxed().await?; - warn!( - "Created new {} database with version {DATABASE_VERSION}", - config.database_backend, - ); + warn!("Created new RocksDB database with version {DATABASE_VERSION}"); Ok(()) } @@ -201,10 +197,7 @@ async fn migrate(services: &Services) -> Result<()> { } } - info!( - "Loaded {} database with schema version {DATABASE_VERSION}", - config.database_backend, - ); + info!("Loaded RocksDB database with schema version {DATABASE_VERSION}"); Ok(()) } From 4fe47903c24f3f9a0f33e871caecd96e7294dc49 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 10 Nov 2024 20:20:27 -0500 Subject: [PATCH 214/245] misc docs changes/improvements from example config Signed-off-by: 
strawberry --- docs/deploying/docker-compose.for-traefik.yml | 1 - docs/deploying/docker-compose.with-caddy.yml | 1 - docs/deploying/docker-compose.yml | 1 - docs/deploying/docker.md | 1 - docs/deploying/generic.md | 14 +++++++++++++- docs/troubleshooting.md | 9 +++++---- 6 files changed, 18 insertions(+), 9 deletions(-) diff --git a/docs/deploying/docker-compose.for-traefik.yml b/docs/deploying/docker-compose.for-traefik.yml index ae93d52fa..b43164269 100644 --- a/docs/deploying/docker-compose.for-traefik.yml +++ b/docs/deploying/docker-compose.for-traefik.yml @@ -14,7 +14,6 @@ services: environment: CONDUWUIT_SERVER_NAME: your.server.name.example # EDIT THIS CONDUWUIT_DATABASE_PATH: /var/lib/conduwuit - CONDUWUIT_DATABASE_BACKEND: rocksdb CONDUWUIT_PORT: 6167 # should match the loadbalancer traefik label CONDUWUIT_MAX_REQUEST_SIZE: 20000000 # in bytes, ~20 MB CONDUWUIT_ALLOW_REGISTRATION: 'true' diff --git a/docs/deploying/docker-compose.with-caddy.yml b/docs/deploying/docker-compose.with-caddy.yml index 369242126..c080293f0 100644 --- a/docs/deploying/docker-compose.with-caddy.yml +++ b/docs/deploying/docker-compose.with-caddy.yml @@ -30,7 +30,6 @@ services: environment: CONDUWUIT_SERVER_NAME: example.com # EDIT THIS CONDUWUIT_DATABASE_PATH: /var/lib/conduwuit - CONDUWUIT_DATABASE_BACKEND: rocksdb CONDUWUIT_PORT: 6167 CONDUWUIT_MAX_REQUEST_SIZE: 20000000 # in bytes, ~20 MB CONDUWUIT_ALLOW_REGISTRATION: 'true' diff --git a/docs/deploying/docker-compose.yml b/docs/deploying/docker-compose.yml index 26145c5ae..3b7d84ed1 100644 --- a/docs/deploying/docker-compose.yml +++ b/docs/deploying/docker-compose.yml @@ -14,7 +14,6 @@ services: environment: CONDUWUIT_SERVER_NAME: your.server.name # EDIT THIS CONDUWUIT_DATABASE_PATH: /var/lib/conduwuit - CONDUWUIT_DATABASE_BACKEND: rocksdb CONDUWUIT_PORT: 6167 CONDUWUIT_MAX_REQUEST_SIZE: 20000000 # in bytes, ~20 MB CONDUWUIT_ALLOW_REGISTRATION: 'true' diff --git a/docs/deploying/docker.md b/docs/deploying/docker.md index 
7b8fd1a2c..e9c49c716 100644 --- a/docs/deploying/docker.md +++ b/docs/deploying/docker.md @@ -40,7 +40,6 @@ When you have the image you can simply run it with docker run -d -p 8448:6167 \ -v db:/var/lib/conduwuit/ \ -e CONDUWUIT_SERVER_NAME="your.server.name" \ - -e CONDUWUIT_DATABASE_BACKEND="rocksdb" \ -e CONDUWUIT_ALLOW_REGISTRATION=false \ --name conduit $LINK ``` diff --git a/docs/deploying/generic.md b/docs/deploying/generic.md index 31dc18456..6fe9709b3 100644 --- a/docs/deploying/generic.md +++ b/docs/deploying/generic.md @@ -42,6 +42,9 @@ replace the binary / container image / etc. this will **NOT** work on conduwuit and you must configure delegation manually. This is not a mistake and no support for this feature will be added. +If you are using SQLite, you **MUST** migrate to RocksDB. You can use this +tool to migrate from SQLite to RocksDB: + See the `[global.well_known]` config section, or configure your web server appropriately to send the delegation responses. @@ -137,11 +140,20 @@ You will need to reverse proxy everything under following routes: You can optionally reverse proxy the following individual routes: - `/.well-known/matrix/client` and `/.well-known/matrix/server` if using -conduwuit to perform delegation +conduwuit to perform delegation (see the `[global.well_known]` config section) - `/.well-known/matrix/support` if using conduwuit to send the homeserver admin contact and support page (formerly known as MSC1929) - `/` if you would like to see `hewwo from conduwuit woof!` at the root +See the following spec pages for more details on these files: +- [`/.well-known/matrix/server`](https://spec.matrix.org/latest/client-server-api/#getwell-knownmatrixserver) +- [`/.well-known/matrix/client`](https://spec.matrix.org/latest/client-server-api/#getwell-knownmatrixclient) +- [`/.well-known/matrix/support`](https://spec.matrix.org/latest/client-server-api/#getwell-knownmatrixsupport) + +Examples of delegation: +- +- + ### Caddy Create 
`/etc/caddy/conf.d/conduwuit_caddyfile` and enter this (substitute for diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index c1499f3a1..74e19de76 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -47,10 +47,11 @@ and communicate with your host's DNS servers (host's `/etc/resolv.conf`) Some filesystems may not like RocksDB using [Direct IO](https://github.com/facebook/rocksdb/wiki/Direct-IO). Direct IO is for -non-buffered I/O which improves conduwuit performance, but at least FUSE is a -filesystem potentially known to not like this. See the [example -config](configuration/examples.md) for disabling it if needed. Issues from -Direct IO on unsupported filesystems are usually shown as startup errors. +non-buffered I/O which improves conduwuit performance and reduces system CPU +usage, but at least FUSE and possibly ZFS are filesystems potentially known +to not like this. See the [example config](configuration/examples.md) for +disabling it if needed. Issues from Direct IO on unsupported filesystems are +usually shown as startup errors. 
#### Database corruption From 4296d7174f97e380f4d8e28e0ebf7d89c26c9c4a Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 2 Nov 2024 21:25:13 -0400 Subject: [PATCH 215/245] add receive_ephemeral check for appservice EDU sending (if it even works) Signed-off-by: strawberry --- src/service/sending/sender.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 464d186b7..f42682931 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -524,8 +524,13 @@ impl Service { } }, SendingEvent::Edu(edu) => { - if let Ok(edu) = serde_json::from_slice(edu) { - edu_jsons.push(edu); + if appservice + .receive_ephemeral + .is_some_and(|receive_edus| receive_edus) + { + if let Ok(edu) = serde_json::from_slice(edu) { + edu_jsons.push(edu); + } } }, SendingEvent::Flush => {}, // flush only; no new content From fd2a0024809d5029fc679ff4008f97794fba4075 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 10 Nov 2024 20:30:34 -0500 Subject: [PATCH 216/245] dont build sentry or perf_measurements features for complement Signed-off-by: strawberry --- nix/pkgs/complement/default.nix | 1 + 1 file changed, 1 insertion(+) diff --git a/nix/pkgs/complement/default.nix b/nix/pkgs/complement/default.nix index 399c4449b..36f124001 100644 --- a/nix/pkgs/complement/default.nix +++ b/nix/pkgs/complement/default.nix @@ -25,6 +25,7 @@ let "tokio_console" # sentry telemetry isn't useful for complement, disabled by default anyways "sentry_telemetry" + "perf_measurements" # the containers don't use or need systemd signal support "systemd" # this is non-functional on nix for some reason From 4f0bdb5194b8b496618fb8744a128f38b14a02b5 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 10 Nov 2024 21:20:38 -0500 Subject: [PATCH 217/245] general misc bug fixes and slight improvements Signed-off-by: strawberry --- src/api/client/membership.rs | 21 +++++----- src/api/router/auth.rs | 48 
++++++++++++---------- src/api/server/invite.rs | 26 +++++++++++- src/api/server/make_join.rs | 18 ++++----- src/api/server/send_leave.rs | 4 +- src/service/rooms/spaces/mod.rs | 53 +++++++++++-------------- src/service/rooms/state/mod.rs | 10 +++-- src/service/rooms/state_accessor/mod.rs | 41 +++++++++++-------- src/service/sending/send.rs | 2 +- src/service/sending/sender.rs | 6 ++- 10 files changed, 128 insertions(+), 101 deletions(-) diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index bde8dee85..10e69f58f 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -1240,8 +1240,8 @@ async fn make_join_request( ) -> Result<(federation::membership::prepare_join_event::v1::Response, OwnedServerName)> { let mut make_join_response_and_server = Err!(BadServerResponse("No server available to assist in joining.")); - let mut make_join_counter: u16 = 0; - let mut incompatible_room_version_count: u8 = 0; + let mut make_join_counter: usize = 0; + let mut incompatible_room_version_count: usize = 0; for remote_server in servers { if services.globals.server_is_ours(remote_server) { @@ -1264,28 +1264,25 @@ async fn make_join_request( make_join_counter = make_join_counter.saturating_add(1); if let Err(ref e) = make_join_response { - trace!("make_join ErrorKind string: {:?}", e.kind().to_string()); - - // converting to a string is necessary (i think) because ruma is forcing us to - // fill in the struct for M_INCOMPATIBLE_ROOM_VERSION - if e.kind().to_string().contains("M_INCOMPATIBLE_ROOM_VERSION") - || e.kind().to_string().contains("M_UNSUPPORTED_ROOM_VERSION") - { + if matches!( + e.kind(), + ErrorKind::IncompatibleRoomVersion { .. 
} | ErrorKind::UnsupportedRoomVersion + ) { incompatible_room_version_count = incompatible_room_version_count.saturating_add(1); } if incompatible_room_version_count > 15 { info!( "15 servers have responded with M_INCOMPATIBLE_ROOM_VERSION or M_UNSUPPORTED_ROOM_VERSION, \ - assuming that Conduwuit does not support the room {room_id}: {e}" + assuming that conduwuit does not support the room {room_id}: {e}" ); make_join_response_and_server = Err!(BadServerResponse("Room version is not supported by Conduwuit")); return make_join_response_and_server; } - if make_join_counter > 50 { + if make_join_counter > 40 { warn!( - "50 servers failed to provide valid make_join response, assuming no server can assist in joining." + "40 servers failed to provide valid make_join response, assuming no server can assist in joining." ); make_join_response_and_server = Err!(BadServerResponse("No server available to assist in joining.")); return make_join_response_and_server; diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 2552ddedc..68abf5e2c 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -13,6 +13,7 @@ use ruma::{ profile::{get_avatar_url, get_display_name, get_profile, get_profile_key, get_timezone_key}, voip::get_turn_server_info, }, + federation::openid::get_openid_userinfo, AuthScheme, IncomingRequest, Metadata, }, server_util::authorization::XMatrix, @@ -102,26 +103,6 @@ pub(super) async fn auth( } match (metadata.authentication, token) { - (_, Token::Invalid) => { - // OpenID endpoint uses a query param with the same name, drop this once query - // params for user auth are removed from the spec. This is required to make - // integration manager work. 
- if request.query.access_token.is_some() && request.parts.uri.path().contains("/openid/") { - Ok(Auth { - origin: None, - sender_user: None, - sender_device: None, - appservice_info: None, - }) - } else { - Err(Error::BadRequest( - ErrorKind::UnknownToken { - soft_logout: false, - }, - "Unknown access token.", - )) - } - }, (AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(services, request, info).await?), (AuthScheme::None | AuthScheme::AccessTokenOptional | AuthScheme::AppserviceToken, Token::Appservice(info)) => { Ok(Auth { @@ -132,7 +113,6 @@ pub(super) async fn auth( }) }, (AuthScheme::AccessToken, Token::None) => match metadata { - // TODO: can we check this better? &get_turn_server_info::v3::Request::METADATA => { if services.globals.config.turn_allow_guests { Ok(Auth { @@ -171,6 +151,32 @@ pub(super) async fn auth( ErrorKind::Unauthorized, "Only appservice access tokens should be used on this endpoint.", )), + (AuthScheme::None, Token::Invalid) => { + // OpenID federation endpoint uses a query param with the same name, drop this + // once query params for user auth are removed from the spec. This is + // required to make integration manager work. 
+ if request.query.access_token.is_some() && metadata == &get_openid_userinfo::v1::Request::METADATA { + Ok(Auth { + origin: None, + sender_user: None, + sender_device: None, + appservice_info: None, + }) + } else { + Err(Error::BadRequest( + ErrorKind::UnknownToken { + soft_logout: false, + }, + "Unknown access token.", + )) + } + }, + (_, Token::Invalid) => Err(Error::BadRequest( + ErrorKind::UnknownToken { + soft_logout: false, + }, + "Unknown access token.", + )), } } diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index b30a1b584..edf80cd69 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -1,5 +1,6 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; +use base64::{engine::general_purpose, Engine as _}; use conduit::{err, utils, warn, Err, Error, PduEvent, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_invite}, @@ -125,8 +126,10 @@ pub(crate) async fn create_invite_route( invite_state.push(pdu.to_stripped_state_event()); - // If we are active in the room, the remote server will notify us about the join - // via /send + // If we are active in the room, the remote server will notify us about the + // join/invite through /send. If we are not in the room, we need to manually + // record the invited state for client /sync through update_membership(), and + // send the invite PDU to the relevant appservices. 
if !services .rooms .state_cache @@ -148,6 +151,25 @@ pub(crate) async fn create_invite_route( .await?; } + for appservice in services.appservice.read().await.values() { + if appservice.is_user_match(&invited_user) { + services + .sending + .send_appservice_request( + appservice.registration.clone(), + ruma::api::appservice::event::push_events::v1::Request { + events: vec![pdu.to_room_event()], + txn_id: general_purpose::URL_SAFE_NO_PAD + .encode(utils::calculate_hash(&[pdu.event_id.as_bytes()])) + .into(), + ephemeral: Vec::new(), + to_device: Vec::new(), + }, + ) + .await?; + } + } + Ok(create_invite::v2::Response { event: services .sending diff --git a/src/api/server/make_join.rs b/src/api/server/make_join.rs index af5700647..d5ea675e9 100644 --- a/src/api/server/make_join.rs +++ b/src/api/server/make_join.rs @@ -80,6 +80,14 @@ pub(crate) async fn create_join_event_template_route( } let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; + if !body.ver.contains(&room_version_id) { + return Err(Error::BadRequest( + ErrorKind::IncompatibleRoomVersion { + room_version: room_version_id, + }, + "Room version not supported.", + )); + } let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; @@ -118,16 +126,6 @@ pub(crate) async fn create_join_event_template_route( None }; - let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; - if !body.ver.contains(&room_version_id) { - return Err(Error::BadRequest( - ErrorKind::IncompatibleRoomVersion { - room_version: room_version_id, - }, - "Room version not supported.", - )); - } - let (_pdu, mut pdu_json) = services .rooms .timeline diff --git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index 448e5de34..e4f41833c 100644 --- a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -157,7 +157,5 @@ async fn create_leave_event( .room_servers(room_id) .ready_filter(|server| !services.globals.server_is_ours(server)); - 
services.sending.send_pdu_servers(servers, &pdu_id).await?; - - Ok(()) + services.sending.send_pdu_servers(servers, &pdu_id).await } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 37272dca8..0ef7ddf56 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -8,7 +8,7 @@ use std::{ }; use conduit::{ - checked, debug, debug_info, err, + checked, debug_info, err, utils::{math::usize_from_f64, IterStream}, Error, Result, }; @@ -234,27 +234,25 @@ impl Service { }); } - Ok( - if let Some(children_pdus) = self.get_stripped_space_child_events(current_room).await? { - let summary = self - .get_room_summary(current_room, children_pdus, &identifier) - .await; - if let Ok(summary) = summary { - self.roomid_spacehierarchy_cache.lock().await.insert( - current_room.clone(), - Some(CachedSpaceHierarchySummary { - summary: summary.clone(), - }), - ); - - Some(SummaryAccessibility::Accessible(Box::new(summary))) - } else { - None - } + if let Some(children_pdus) = self.get_stripped_space_child_events(current_room).await? 
{ + let summary = self + .get_room_summary(current_room, children_pdus, &identifier) + .await; + if let Ok(summary) = summary { + self.roomid_spacehierarchy_cache.lock().await.insert( + current_room.clone(), + Some(CachedSpaceHierarchySummary { + summary: summary.clone(), + }), + ); + + Ok(Some(SummaryAccessibility::Accessible(Box::new(summary)))) } else { - None - }, - ) + Ok(None) + } + } else { + Ok(None) + } } /// Gets the summary of a space using solely federation @@ -393,7 +391,7 @@ impl Service { .is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) .await { - debug!("User is not allowed to see room {room_id}"); + debug_info!("User is not allowed to see room {room_id}"); // This error will be caught later return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room")); } @@ -615,16 +613,13 @@ impl Service { &self, current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>, allowed_room_ids: &Vec, ) -> bool { - // Note: unwrap_or_default for bool means false match identifier { Identifier::ServerName(server_name) => { - let room_id: &RoomId = current_room; - // Checks if ACLs allow for the server to participate if self .services .event_handler - .acl_check(server_name, room_id) + .acl_check(server_name, current_room) .await .is_err() { @@ -645,8 +640,9 @@ impl Service { return true; } }, - } // Takes care of join rules - match join_rule { + } + match &join_rule { + SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true, SpaceRoomJoinRule::Restricted => { for room in allowed_room_ids { match identifier { @@ -664,7 +660,6 @@ impl Service { } false }, - SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true, // Invite only, Private, or Custom join rule _ => false, } diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 71a3900cd..7d8200f09 100644 --- 
a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -295,20 +295,22 @@ impl Service { } #[tracing::instrument(skip_all, level = "debug")] - pub async fn summary_stripped(&self, invite: &PduEvent) -> Vec> { + pub async fn summary_stripped(&self, event: &PduEvent) -> Vec> { let cells = [ (&StateEventType::RoomCreate, ""), (&StateEventType::RoomJoinRules, ""), (&StateEventType::RoomCanonicalAlias, ""), (&StateEventType::RoomName, ""), (&StateEventType::RoomAvatar, ""), - (&StateEventType::RoomMember, invite.sender.as_str()), // Add recommended events + (&StateEventType::RoomMember, event.sender.as_str()), // Add recommended events + (&StateEventType::RoomEncryption, ""), + (&StateEventType::RoomTopic, ""), ]; let fetches = cells.iter().map(|(event_type, state_key)| { self.services .state_accessor - .room_state_get(&invite.room_id, event_type, state_key) + .room_state_get(&event.room_id, event_type, state_key) }); join_all(fetches) @@ -316,7 +318,7 @@ impl Service { .into_iter() .filter_map(Result::ok) .map(|e| e.to_stripped_state_event()) - .chain(once(invite.to_stripped_state_event())) + .chain(once(event.to_stripped_state_event())) .collect() } diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index d51da8af9..4958c4eaf 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -10,7 +10,7 @@ use conduit::{ err, error, pdu::PduBuilder, utils::{math::usize_from_f64, ReadyExt}, - Error, PduEvent, Result, + Err, Error, Event, PduEvent, Result, }; use futures::StreamExt; use lru_cache::LruCache; @@ -29,7 +29,7 @@ use ruma::{ power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, topic::RoomTopicEventContent, }, - StateEventType, + StateEventType, TimelineEventType, }, room::RoomType, space::SpaceRoomJoinRule, @@ -408,34 +408,41 @@ impl Service { pub async fn user_can_redact( &self, redacts: &EventId, sender: &UserId, room_id: &RoomId, federation: bool, ) 
-> Result { - if let Ok(event) = self + let redacting_event = self.services.timeline.get_pdu(redacts).await; + + if redacting_event + .as_ref() + .is_ok_and(|event| event.event_type() == &TimelineEventType::RoomCreate) + { + return Err!(Request(Forbidden("Redacting m.room.create is not safe, forbidding."))); + } + + if let Ok(pl_event_content) = self .room_state_get_content::(room_id, &StateEventType::RoomPowerLevels, "") .await { - let event: RoomPowerLevels = event.into(); - Ok(event.user_can_redact_event_of_other(sender) - || event.user_can_redact_own_event(sender) - && if let Ok(pdu) = self.services.timeline.get_pdu(redacts).await { + let pl_event: RoomPowerLevels = pl_event_content.into(); + Ok(pl_event.user_can_redact_event_of_other(sender) + || pl_event.user_can_redact_own_event(sender) + && if let Ok(redacting_event) = redacting_event { if federation { - pdu.sender.server_name() == sender.server_name() + redacting_event.sender.server_name() == sender.server_name() } else { - pdu.sender == sender + redacting_event.sender == sender } } else { false }) } else { // Falling back on m.room.create to judge power level - if let Ok(pdu) = self + if let Ok(room_create) = self .room_state_get(room_id, &StateEventType::RoomCreate, "") .await { - Ok(pdu.sender == sender - || if let Ok(pdu) = self.services.timeline.get_pdu(redacts).await { - pdu.sender == sender - } else { - false - }) + Ok(room_create.sender == sender + || redacting_event + .as_ref() + .is_ok_and(|redacting_event| redacting_event.sender == sender)) } else { Err(Error::bad_database( "No m.room.power_levels or m.room.create events in database for room", @@ -454,7 +461,7 @@ impl Service { /// Returns an empty vec if not a restricted room pub fn allowed_room_ids(&self, join_rule: JoinRule) -> Vec { - let mut room_ids = vec![]; + let mut room_ids = Vec::with_capacity(1); if let JoinRule::Restricted(r) | JoinRule::KnockRestricted(r) = join_rule { for rule in r.allow { if let 
AllowRule::RoomMembership(RoomMembership { diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 5bf48aaab..6a8f1b1bd 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -39,7 +39,7 @@ impl super::Service { .forbidden_remote_server_names .contains(dest) { - return Err!(Request(Forbidden(debug_warn!("Federation with this {dest} is not allowed.")))); + return Err!(Request(Forbidden(debug_warn!("Federation with {dest} is not allowed.")))); } let actual = self.services.resolver.get_actual_dest(dest).await?; diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index f42682931..f5d875045 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -235,13 +235,15 @@ impl Service { fn select_events_current(&self, dest: Destination, statuses: &mut CurTransactionStatus) -> Result<(bool, bool)> { let (mut allow, mut retry) = (true, false); statuses - .entry(dest) + .entry(dest.clone()) // TODO: can we avoid cloning? 
.and_modify(|e| match e { TransactionStatus::Failed(tries, time) => { // Fail if a request has failed recently (exponential backoff) let min = self.server.config.sender_timeout; let max = self.server.config.sender_retry_backoff_limit; - if continue_exponential_backoff_secs(min, max, time.elapsed(), *tries) { + if continue_exponential_backoff_secs(min, max, time.elapsed(), *tries) + && !matches!(dest, Destination::Appservice(_)) + { allow = false; } else { retry = true; From 72fb8371f9828b5c039883467ea35e3b41b4b42c Mon Sep 17 00:00:00 2001 From: strawberry Date: Wed, 13 Nov 2024 17:08:16 -0500 Subject: [PATCH 218/245] link to migrating from conduit on the README Signed-off-by: strawberry --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 962139d64..4e97f1f00 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Artifacts](https://github.com/girlbossceo/conduwuit/actions/workflows/ci.yml/bad Visit the [conduwuit documentation](https://conduwuit.puppyirl.gay/) for more -information. +information and how to deploy/setup conduwuit. @@ -63,7 +63,9 @@ and we have no plans in stopping or slowing down any time soon! conduwuit is a complete drop-in replacement for Conduit. As long as you are using RocksDB, the only "migration" you need to do is replace the binary or container image. There -is no harm or additional steps required for using conduwuit. +is no harm or additional steps required for using conduwuit. See the +[Migrating from Conduit](https://conduwuit.puppyirl.gay/deploying/generic.html#migrating-from-conduit) section +on the generic deploying guide. 
From 011d44b749bc206732572c826b4ca9ae2a6111e3 Mon Sep 17 00:00:00 2001 From: strawberry Date: Wed, 13 Nov 2024 20:06:25 -0500 Subject: [PATCH 219/245] add missing declared support for MSC3952 Signed-off-by: strawberry --- src/api/client/unversioned.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/api/client/unversioned.rs b/src/api/client/unversioned.rs index d5bb14e5d..3aee30c8b 100644 --- a/src/api/client/unversioned.rs +++ b/src/api/client/unversioned.rs @@ -53,6 +53,7 @@ pub(crate) async fn get_supported_versions_route( ("org.matrix.msc2946".to_owned(), true), /* spaces/hierarchy summaries (https://github.com/matrix-org/matrix-spec-proposals/pull/2946) */ ("org.matrix.msc3026.busy_presence".to_owned(), true), /* busy presence status (https://github.com/matrix-org/matrix-spec-proposals/pull/3026) */ ("org.matrix.msc3827".to_owned(), true), /* filtering of /publicRooms by room type (https://github.com/matrix-org/matrix-spec-proposals/pull/3827) */ + ("org.matrix.msc3952_intentional_mentions".to_owned(), true), /* intentional mentions (https://github.com/matrix-org/matrix-spec-proposals/pull/3952) */ ("org.matrix.msc3575".to_owned(), true), /* sliding sync (https://github.com/matrix-org/matrix-spec-proposals/pull/3575/files#r1588877046) */ ("org.matrix.msc3916.stable".to_owned(), true), /* authenticated media (https://github.com/matrix-org/matrix-spec-proposals/pull/3916) */ ("org.matrix.msc4180".to_owned(), true), /* stable flag for 3916 (https://github.com/matrix-org/matrix-spec-proposals/pull/4180) */ From 44a7ac07036915263f93e43f572dca8446d7ef9f Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 09:40:04 -0500 Subject: [PATCH 220/245] add debug_assert is_sorted for inline content types Signed-off-by: strawberry --- src/core/utils/content_disposition.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/core/utils/content_disposition.rs b/src/core/utils/content_disposition.rs index a2fe923c4..3a264a74f 100644 --- 
a/src/core/utils/content_disposition.rs +++ b/src/core/utils/content_disposition.rs @@ -45,9 +45,10 @@ pub fn content_disposition_type(content_type: Option<&str>) -> ContentDispositio return ContentDispositionType::Attachment; }; - // is_sorted is unstable - /* debug_assert!(ALLOWED_INLINE_CONTENT_TYPES.is_sorted(), - * "ALLOWED_INLINE_CONTENT_TYPES is not sorted"); */ + debug_assert!( + ALLOWED_INLINE_CONTENT_TYPES.is_sorted(), + "ALLOWED_INLINE_CONTENT_TYPES is not sorted" + ); let content_type: Cow<'_, str> = content_type .split(';') From dac1a01216e53c4aca4e54057db487dc7682f30a Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 09:43:58 -0500 Subject: [PATCH 221/245] update generated example config Signed-off-by: strawberry --- conduwuit-example.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/conduwuit-example.toml b/conduwuit-example.toml index aa0d1e5df..2f3da71f4 100644 --- a/conduwuit-example.toml +++ b/conduwuit-example.toml @@ -592,6 +592,10 @@ # #log_colors = true +# configures the span events which will be outputted with the log +# +#log_span_events = "none" + # OpenID token expiration/TTL in seconds # # These are the OpenID tokens that are primarily used for Matrix account From 3f69f2ee73960a8cc0f0a3f7a2e1202ad43928e6 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 09:44:29 -0500 Subject: [PATCH 222/245] replace deprecated sha-1 crate, try to reduce some unnecessary crates/features Signed-off-by: strawberry --- Cargo.lock | 49 +++++------------------------------- Cargo.toml | 62 +++++++++++++++++++++++++++------------------- src/api/Cargo.toml | 2 +- 3 files changed, 43 insertions(+), 70 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0e1845dad..65eab0b52 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -676,7 +676,7 @@ dependencies = [ "serde", "serde_html_form", "serde_json", - "sha-1", + "sha1", "tokio", "tracing", ] @@ -720,7 +720,7 @@ dependencies = [ "serde_json", "serde_regex", "serde_yaml", - 
"thiserror 1.0.69", + "thiserror 2.0.3", "tikv-jemalloc-ctl", "tikv-jemalloc-sys", "tikv-jemallocator", @@ -1753,9 +1753,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.8" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da62f120a8a37763efb0cf8fdf264b884c7b8b9ac8660b900c8661030c00e6ba" +checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" dependencies = [ "bytes", "futures-channel", @@ -1766,7 +1766,6 @@ dependencies = [ "pin-project-lite", "socket2", "tokio", - "tower 0.4.13", "tower-service", "tracing", ] @@ -2080,11 +2079,9 @@ checksum = "b9ae10193d25051e74945f1ea2d0b42e03cc3b890f7e4cc5faa44997d808193f" dependencies = [ "base64 0.21.7", "js-sys", - "pem", "ring", "serde", "serde_json", - "simple_asn1", ] [[package]] @@ -2662,16 +2659,6 @@ dependencies = [ "syn 2.0.87", ] -[[package]] -name = "pem" -version = "3.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e459365e590736a54c3fa561947c84837534b8e9af6fc5bf781307e82658fae" -dependencies = [ - "base64 0.22.1", - "serde", -] - [[package]] name = "percent-encoding" version = "2.3.1" @@ -3535,11 +3522,10 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "sanitize-filename" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ed72fbaf78e6f2d41744923916966c4fbe3d7c74e3037a8ee482f1115572603" +checksum = "bc984f4f9ceb736a7bb755c3e3bd17dc56370af2600c9780dcc48c66453da34d" dependencies = [ - "lazy_static", "regex", ] @@ -3827,17 +3813,6 @@ dependencies = [ "unsafe-libyaml", ] -[[package]] -name = "sha-1" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - [[package]] name = "sha1" version = "0.10.6" @@ 
-3920,18 +3895,6 @@ version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" -[[package]] -name = "simple_asn1" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" -dependencies = [ - "num-bigint", - "num-traits", - "thiserror 1.0.69", - "time", -] - [[package]] name = "siphasher" version = "0.3.11" diff --git a/Cargo.toml b/Cargo.toml index dde005a31..a84ff79ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,20 +46,20 @@ default-features = false features = ["parse"] [workspace.dependencies.sanitize-filename] -version = "0.5.0" +version = "0.6.0" [workspace.dependencies.jsonwebtoken] version = "9.3.0" +default-features = false [workspace.dependencies.base64] version = "0.22.1" +default-features = false # used for TURN server authentication [workspace.dependencies.hmac] version = "0.12.1" - -[workspace.dependencies.sha-1] -version = "0.10.1" +default-features = false # used for checking if an IP is in specific subnets / CIDR ranges easier [workspace.dependencies.ipaddress] @@ -70,16 +70,16 @@ version = "0.8.5" # Used for the http request / response body type for Ruma endpoints used with reqwest [workspace.dependencies.bytes] -version = "1.7.2" +version = "1.8.0" [workspace.dependencies.http-body-util] -version = "0.1.1" +version = "0.1.2" [workspace.dependencies.http] version = "1.1.0" [workspace.dependencies.regex] -version = "1.10.6" +version = "1.11.1" [workspace.dependencies.axum] version = "0.7.5" @@ -95,7 +95,7 @@ features = [ ] [workspace.dependencies.axum-extra] -version = "0.9.3" +version = "0.9.4" default-features = false features = ["typed-header", "tracing"] @@ -116,7 +116,7 @@ default-features = false features = ["util"] [workspace.dependencies.tower-http] -version = "0.6.0" +version = "0.6.1" default-features = false features = [ 
"add-extension", @@ -130,6 +130,8 @@ features = [ [workspace.dependencies.rustls] version = "0.23.16" +default-features = false +features = ["aws_lc_rs"] [workspace.dependencies.reqwest] version = "0.12.9" @@ -147,7 +149,7 @@ default-features = false features = ["rc"] [workspace.dependencies.serde_json] -version = "1.0.124" +version = "1.0.132" default-features = false features = ["raw_value"] @@ -189,9 +191,11 @@ version = "0.1.40" default-features = false [workspace.dependencies.tracing-subscriber] version = "0.3.18" -features = ["env-filter"] +default-features = false +features = ["env-filter", "std", "tracing", "tracing-log", "ansi", "fmt"] [workspace.dependencies.tracing-core] version = "0.1.32" +default-features = false # for URL previews [workspace.dependencies.webpage] @@ -200,12 +204,14 @@ default-features = false # used for conduit's CLI and admin room command parsing [workspace.dependencies.clap] -version = "4.5.20" +version = "4.5.21" default-features = false features = [ "std", "derive", "help", + #"color", Do we need these? 
+ #"unicode", "usage", "error-context", "string", @@ -217,7 +223,7 @@ default-features = false features = ["std", "async-await"] [workspace.dependencies.tokio] -version = "1.40.0" +version = "1.41.1" default-features = false features = [ "fs", @@ -238,7 +244,7 @@ version = "0.8.5" # Validating urls in config, was already a transitive dependency [workspace.dependencies.url] -version = "2.5.0" +version = "2.5.3" default-features = false features = ["serde"] @@ -258,26 +264,23 @@ features = [ ] [workspace.dependencies.hyper-util] -# 0.1.9 and above causes DNS issues -version = "=0.1.8" +version = "0.1.10" default-features = false features = [ - "client", "server-auto", "server-graceful", - "service", "tokio", ] # to support multiple variations of setting a config option [workspace.dependencies.either] -version = "1.11.0" +version = "1.13.0" default-features = false features = ["serde"] # Used for reading the configuration from conduwuit.toml & environment variables [workspace.dependencies.figment] -version = "0.10.18" +version = "0.10.19" default-features = false features = ["env", "toml"] @@ -287,11 +290,13 @@ default-features = false # Used for conduit::Error type [workspace.dependencies.thiserror] -version = "1.0.63" +version = "2.0.3" +default-features = false # Used when hashing the state [workspace.dependencies.ring] version = "0.17.8" +default-features = false # Used to make working with iterators easier, was already a transitive depdendency [workspace.dependencies.itertools] @@ -307,7 +312,7 @@ version = "2.1.1" version = "0.4.0" [workspace.dependencies.async-trait] -version = "0.1.81" +version = "0.1.83" [workspace.dependencies.lru-cache] version = "0.1.2" @@ -363,9 +368,13 @@ features = [ "bzip2", ] -# optional SHA256 media keys feature [workspace.dependencies.sha2] version = "0.10.8" +default-features = false + +[workspace.dependencies.sha1] +version = "0.10.6" +default-features = false # optional opentelemetry, performance measurements, flamegraphs, etc 
for performance measurements and monitoring [workspace.dependencies.opentelemetry] @@ -433,7 +442,8 @@ default-features = false features = ["resource"] [workspace.dependencies.sd-notify] -version = "0.4.1" +version = "0.4.3" +default-features = false [workspace.dependencies.hardened_malloc-rs] version = "0.1.2" @@ -456,12 +466,12 @@ default-features = false version = "0.1" [workspace.dependencies.syn] -version = "2.0.76" +version = "2.0.87" default-features = false features = ["full", "extra-traits"] [workspace.dependencies.quote] -version = "1.0.36" +version = "1.0.37" [workspace.dependencies.proc-macro2] version = "1.0.89" diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index 6e37cb407..a0fc09ded 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -59,7 +59,7 @@ ruma.workspace = true serde_html_form.workspace = true serde_json.workspace = true serde.workspace = true -sha-1.workspace = true +sha1.workspace = true tokio.workspace = true tracing.workspace = true From b4d809c68157a09bbfd2381bcea04407bf3b3dd2 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 09:49:54 -0500 Subject: [PATCH 223/245] add more checks for gh pages deployment workflow Signed-off-by: strawberry --- .github/workflows/documentation.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index 506a87d95..ea720c43c 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -39,7 +39,7 @@ concurrency: jobs: docs: name: Documentation and GitHub Pages - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 permissions: pages: write @@ -57,7 +57,7 @@ jobs: uses: actions/checkout@v4 - name: Setup GitHub Pages - if: github.event_name != 'pull_request' + if: (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') && (github.event_name != 'pull_request') && (github.event.pull_request.user.login == 'girlbossceo') uses: 
actions/configure-pages@v5 - uses: nixbuild/nix-quick-install-action@master @@ -139,12 +139,12 @@ jobs: compression-level: 0 - name: Upload generated documentation (book) as GitHub Pages artifact - if: github.event_name != 'pull_request' + if: (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') && (github.event_name != 'pull_request') && (github.event.pull_request.user.login == 'girlbossceo') uses: actions/upload-pages-artifact@v3 with: path: public - name: Deploy to GitHub Pages - if: github.event_name != 'pull_request' + if: (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') && (github.event_name != 'pull_request') && (github.event.pull_request.user.login == 'girlbossceo') id: deployment uses: actions/deploy-pages@v4 From c1f553cf4f938da469b69bfd37060ca65a827762 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 12:25:47 -0500 Subject: [PATCH 224/245] bump rocksdb to v9.7.4, and ruwuma Signed-off-by: strawberry --- Cargo.lock | 55 ++++++++++++++++++------------------ Cargo.toml | 2 +- deps/rust-rocksdb/Cargo.toml | 2 +- flake.lock | 8 +++--- flake.nix | 2 +- 5 files changed, 34 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 65eab0b52..b56005ff3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -191,9 +191,9 @@ dependencies = [ [[package]] name = "axum" -version = "0.7.7" +version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" +checksum = "49c41b948da08fb481a94546cd874843adc1142278b0af4badf9b1b78599d68d" dependencies = [ "async-trait", "axum-core", @@ -257,9 +257,9 @@ dependencies = [ [[package]] name = "axum-extra" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73c3220b188aea709cf1b6c5f9b01c3bd936bb08bd2b5184a12b35ac8131b1f9" +checksum = "37634d71e9f3c35cfb1c30c87c7cba500d55892f04c2dbe6a99383c664b820b0" 
dependencies = [ "axum", "axum-core", @@ -275,7 +275,6 @@ dependencies = [ "tower 0.5.1", "tower-layer", "tower-service", - "tracing", ] [[package]] @@ -487,9 +486,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1aeb932158bd710538c73702db6945cb68a8fb08c519e6e12706b94263b36db8" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ "jobserver", "libc", @@ -1264,9 +1263,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.34" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" dependencies = [ "crc32fast", "miniz_oxide", @@ -1740,9 +1739,9 @@ dependencies = [ [[package]] name = "hyper-timeout" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3203a961e5c83b6f5498933e78b6b263e208c197b63e9c6c53cc82ffd3f63793" +checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" dependencies = [ "hyper", "hyper-util", @@ -3128,7 +3127,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "assign", "js_int", @@ -3150,7 +3149,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" +source = 
"git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "js_int", "ruma-common", @@ -3162,7 +3161,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "as_variant", "assign", @@ -3185,7 +3184,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "as_variant", "base64 0.22.1", @@ -3215,7 +3214,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "as_variant", "indexmap 2.6.0", @@ -3239,7 +3238,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "bytes", "http", @@ -3257,7 +3256,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = 
"git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "js_int", "thiserror 2.0.3", @@ -3266,7 +3265,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "js_int", "ruma-common", @@ -3276,7 +3275,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "cfg-if", "once_cell", @@ -3292,7 +3291,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "js_int", "ruma-common", @@ -3304,7 +3303,7 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.3.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "headers", "http", @@ -3317,7 
+3316,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -3333,7 +3332,7 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.11.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=67ffedabbf43e1ff6934df0fbf770b21e101406f#67ffedabbf43e1ff6934df0fbf770b21e101406f" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "futures-util", "itertools 0.13.0", @@ -3348,8 +3347,8 @@ dependencies = [ [[package]] name = "rust-librocksdb-sys" -version = "0.28.0+9.7.3" -source = "git+https://github.com/girlbossceo/rust-rocksdb-zaidoon1?rev=c1e5523eae095a893deaf9056128c7dbc2d5fd73#c1e5523eae095a893deaf9056128c7dbc2d5fd73" +version = "0.29.0+9.7.4" +source = "git+https://github.com/girlbossceo/rust-rocksdb-zaidoon1?rev=2bc5495a9f8f75073390c326b47ee5928ab7c7f0#2bc5495a9f8f75073390c326b47ee5928ab7c7f0" dependencies = [ "bindgen", "bzip2-sys", @@ -3365,8 +3364,8 @@ dependencies = [ [[package]] name = "rust-rocksdb" -version = "0.31.0" -source = "git+https://github.com/girlbossceo/rust-rocksdb-zaidoon1?rev=c1e5523eae095a893deaf9056128c7dbc2d5fd73#c1e5523eae095a893deaf9056128c7dbc2d5fd73" +version = "0.33.0" +source = "git+https://github.com/girlbossceo/rust-rocksdb-zaidoon1?rev=2bc5495a9f8f75073390c326b47ee5928ab7c7f0#2bc5495a9f8f75073390c326b47ee5928ab7c7f0" dependencies = [ "libc", "rust-librocksdb-sys", diff --git a/Cargo.toml b/Cargo.toml index a84ff79ff..814a435b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -321,7 +321,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = 
"https://github.com/girlbossceo/ruwuma" #branch = "conduwuit-changes" -rev = "67ffedabbf43e1ff6934df0fbf770b21e101406f" +rev = "2ab432fba19eb8862c594d24af39d8f9f6b4eac6" features = [ "compat", "rand", diff --git a/deps/rust-rocksdb/Cargo.toml b/deps/rust-rocksdb/Cargo.toml index 8c168b24f..908a2911c 100644 --- a/deps/rust-rocksdb/Cargo.toml +++ b/deps/rust-rocksdb/Cargo.toml @@ -27,7 +27,7 @@ malloc-usable-size = ["rust-rocksdb/malloc-usable-size"] [dependencies.rust-rocksdb] git = "https://github.com/girlbossceo/rust-rocksdb-zaidoon1" -rev = "c1e5523eae095a893deaf9056128c7dbc2d5fd73" +rev = "2bc5495a9f8f75073390c326b47ee5928ab7c7f0" #branch = "master" default-features = false diff --git a/flake.lock b/flake.lock index 271a21541..7740e9254 100644 --- a/flake.lock +++ b/flake.lock @@ -922,16 +922,16 @@ "rocksdb": { "flake": false, "locked": { - "lastModified": 1729712930, - "narHash": "sha256-jlp4kPkRTpoJaUdobEoHd8rCGAQNBy4ZHZ6y5zL/ibw=", + "lastModified": 1731690620, + "narHash": "sha256-Xd4TJYqPERMJLXaGa6r6Ny1Wlw8Uy5Cyf/8q7nS58QM=", "owner": "girlbossceo", "repo": "rocksdb", - "rev": "871eda6953c3f399aae39808dcfccdd014885beb", + "rev": "292446aa2bc41699204d817a1e4b091679a886eb", "type": "github" }, "original": { "owner": "girlbossceo", - "ref": "v9.7.3", + "ref": "v9.7.4", "repo": "rocksdb", "type": "github" } diff --git a/flake.nix b/flake.nix index 85b7baa0e..113757a73 100644 --- a/flake.nix +++ b/flake.nix @@ -9,7 +9,7 @@ flake-utils.url = "github:numtide/flake-utils?ref=main"; nix-filter.url = "github:numtide/nix-filter?ref=main"; nixpkgs.url = "github:NixOS/nixpkgs?ref=nixpkgs-unstable"; - rocksdb = { url = "github:girlbossceo/rocksdb?ref=v9.7.3"; flake = false; }; + rocksdb = { url = "github:girlbossceo/rocksdb?ref=v9.7.4"; flake = false; }; liburing = { url = "github:axboe/liburing?ref=master"; flake = false; }; }; From a9c280bd4cd0616c57c9df28a11c3bc48ae8b5ba Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 12:58:02 -0500 Subject: [PATCH 
225/245] document NAT hairpinning/loopback if needed Signed-off-by: strawberry --- docs/deploying/generic.md | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/docs/deploying/generic.md b/docs/deploying/generic.md index 6fe9709b3..f0b85a25c 100644 --- a/docs/deploying/generic.md +++ b/docs/deploying/generic.md @@ -68,13 +68,25 @@ sudo useradd -r --shell /usr/bin/nologin --no-create-home conduwuit ## Forwarding ports in the firewall or the router -conduwuit uses the ports 443 and 8448 both of which need to be open in the -firewall. +Matrix's default federation port is port 8448, and clients must be using port 443. +If you would like to use only port 443, or a different port, you will need to setup +delegation. conduwuit has config options for doing delegation, or you can configure +your reverse proxy to manually serve the necessary JSON files to do delegation +(see the `[global.well_known]` config section). If conduwuit runs behind a router or in a container and has a different public IP address than the host system these public ports need to be forwarded directly or indirectly to the port mentioned in the config. +Note for NAT users; if you have trouble connecting to your server from the inside +of your network, you need to research your router and see if it supports "NAT +hairpinning" or "NAT loopback". + +If your router does not support this feature, you need to research doing local +DNS overrides and force your Matrix DNS records to use your local IP internally. +This can be done at the host level using `/etc/hosts`. If you need this to be +on the network level, consider something like NextDNS or Pi-Hole. 
+ ## Setting up a systemd service The systemd unit for conduwuit can be found From c23786d37f207f45632d8288affbcd51bfb5e5c8 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 12:59:05 -0500 Subject: [PATCH 226/245] dont try to backfill empty, private rooms Signed-off-by: strawberry --- src/api/client/membership.rs | 2 +- src/service/rooms/timeline/mod.rs | 28 ++++++++++++++++------------ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 10e69f58f..c61185a7c 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -1481,7 +1481,7 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, .await { if let Err(e) = remote_leave_room(services, user_id, room_id).await { - warn!("Failed to leave room {user_id} remotely: {e}"); + warn!(%user_id, "Failed to leave room {room_id} remotely: {e}"); // Don't tell the client about this error } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 281879d2f..2faa1c40b 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1033,6 +1033,22 @@ impl Service { #[tracing::instrument(skip(self))] pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> { + if self + .services + .state_cache + .room_joined_count(room_id) + .await + .is_ok_and(|count| count <= 1) + && !self + .services + .state_accessor + .is_world_readable(room_id) + .await + { + // Room is empty (1 user or none), there is no one that can backfill + return Ok(()); + } + let first_pdu = self .all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id) .await? 
@@ -1060,20 +1076,8 @@ impl Service { } }); - let room_alias_servers = self - .services - .alias - .local_aliases_for_room(room_id) - .ready_filter_map(|alias| { - self.services - .globals - .server_is_ours(alias.server_name()) - .then_some(alias.server_name()) - }); - let mut servers = room_mods .stream() - .chain(room_alias_servers) .map(ToOwned::to_owned) .chain( self.services From 9783bc78ba096f75db2c529f3e2a7f6cb76f51fe Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 13:58:44 -0500 Subject: [PATCH 227/245] remove sentry_telemetry from default features Signed-off-by: strawberry --- src/main/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/Cargo.toml b/src/main/Cargo.toml index b3390bfb1..b91229425 100644 --- a/src/main/Cargo.toml +++ b/src/main/Cargo.toml @@ -44,7 +44,6 @@ default = [ "jemalloc", "jemalloc_stats", "release_max_log_level", - "sentry_telemetry", "systemd", "zstd_compression", ] From 666989f74ce8a80b3d24132393c2e2e331fe719c Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 13:19:32 -0500 Subject: [PATCH 228/245] delete trivy as lately its been terribly unreliable Signed-off-by: strawberry --- .github/workflows/trivy.yml | 42 ------------------------------------- docs/differences.md | 3 +-- 2 files changed, 1 insertion(+), 44 deletions(-) delete mode 100644 .github/workflows/trivy.yml diff --git a/.github/workflows/trivy.yml b/.github/workflows/trivy.yml deleted file mode 100644 index 1f0dd7df2..000000000 --- a/.github/workflows/trivy.yml +++ /dev/null @@ -1,42 +0,0 @@ -name: Trivy code and vulnerability scanning - -on: - pull_request: - push: - branches: - - main - tags: - - '*' - schedule: - - cron: '00 12 * * *' - -permissions: - contents: read - -jobs: - trivy-scan: - name: Trivy Scan - runs-on: ubuntu-latest - permissions: - contents: read - security-events: write - actions: read - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Run Trivy code and vulnerability scanner 
on repo - uses: aquasecurity/trivy-action@0.28.0 - with: - scan-type: repo - format: sarif - output: trivy-results.sarif - severity: CRITICAL,HIGH,MEDIUM,LOW - - - name: Run Trivy code and vulnerability scanner on filesystem - uses: aquasecurity/trivy-action@0.28.0 - with: - scan-type: fs - format: sarif - output: trivy-results.sarif - severity: CRITICAL,HIGH,MEDIUM,LOW diff --git a/docs/differences.md b/docs/differences.md index 6815d2485..18ea7a1ff 100644 --- a/docs/differences.md +++ b/docs/differences.md @@ -241,8 +241,7 @@ both new users and power users - Fixed every single clippy (default lints) and rustc warnings, including some that were performance related or potential safety issues / unsoundness - Add a **lot** of other clippy and rustc lints and a rustfmt.toml file -- Repo uses [Renovate](https://docs.renovatebot.com/), -[Trivy](https://github.com/aquasecurity/trivy-action), and keeps ALL +- Repo uses [Renovate](https://docs.renovatebot.com/) and keeps ALL dependencies as up to date as possible - Purge unmaintained/irrelevant/broken database backends (heed, sled, persy) and other unnecessary code or overhead From f897b4daeea8147a96286e96e6d828d364136dd8 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 13:20:47 -0500 Subject: [PATCH 229/245] ci: remove all free runner space steps due to flakiness Signed-off-by: strawberry --- .github/workflows/ci.yml | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2d253f695..f4dcb88f8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -67,21 +67,6 @@ jobs: name: Test runs-on: ubuntu-24.04 steps: - - name: Free Disk Space (Ubuntu) - uses: jlumbroso/free-disk-space@main - - - name: Free up more runner space - run: | - set +o pipefail - # large docker images - sudo docker image prune --all --force || true - # large packages - sudo apt-get purge -y 'php.*' '^mongodb-.*' '^mysql-.*' 
azure-cli google-cloud-cli google-chrome-stable firefox powershell microsoft-edge-stable || true - sudo apt-get clean - # large folders - sudo rm -rf /var/lib/apt/lists/* /usr/local/games /usr/local/sqlpackage /usr/local/share/powershell /usr/local/share/edge_driver /usr/local/share/gecko_driver /usr/local/share/chromium /usr/local/share/chromedriver-linux64 /usr/local/share/vcpkg /usr/local/julia* /opt/mssql-tools /usr/share/vim /usr/share/postgresql /usr/share/apache-maven-* /usr/share/R /usr/share/alsa /usr/share/miniconda /usr/share/grub /usr/share/gradle-* /usr/share/locale /usr/share/texinfo /usr/share/kotlinc /usr/share/swift /usr/share/sbt /usr/share/ri /usr/share/icons /usr/share/java /usr/share/fonts /usr/lib/google-cloud-sdk /usr/lib/jvm /usr/lib/mono /usr/lib/R /usr/lib/postgresql /usr/lib/heroku - set -o pipefail - - name: Sync repository uses: actions/checkout@v4 @@ -238,9 +223,6 @@ jobs: - target: aarch64-linux-musl - target: x86_64-linux-musl steps: - - name: Free Disk Space (Ubuntu) - uses: jlumbroso/free-disk-space@main - - name: Sync repository uses: actions/checkout@v4 @@ -449,6 +431,7 @@ jobs: steps: - name: Sync repository uses: actions/checkout@v4 + - name: Tag comparison check if: ${{ startsWith(github.ref, 'refs/tags/v') && !endsWith(github.ref, '-rc') }} run: | @@ -459,14 +442,17 @@ jobs: echo '# WARNING: Attempting to run this workflow for a tag that is not the latest repo tag. Aborting.' 
>> $GITHUB_STEP_SUMMARY exit 1 fi + # use sccache for Rust - name: Run sccache-cache if: (github.event.pull_request.draft != true) && (vars.DOCKER_USERNAME != '') && (vars.GITLAB_USERNAME != '') && (vars.SCCACHE_ENDPOINT != '') && (github.event.pull_request.user.login != 'renovate[bot]') uses: mozilla-actions/sccache-action@main + # use rust-cache - uses: Swatinem/rust-cache@v2 with: cache-all-crates: "true" + # Nix can't do portable macOS builds yet - name: Build macOS x86_64 binary if: ${{ matrix.os == 'macos-13' }} @@ -474,22 +460,26 @@ jobs: CONDUWUIT_VERSION_EXTRA="$(git rev-parse --short HEAD)" cargo build --release cp -v -f target/release/conduit conduwuit-macos-x86_64 otool -L conduwuit-macos-x86_64 + # quick smoke test of the x86_64 macOS binary - name: Run x86_64 macOS release binary if: ${{ matrix.os == 'macos-13' }} run: | ./conduwuit-macos-x86_64 --version + - name: Build macOS arm64 binary if: ${{ matrix.os == 'macos-latest' }} run: | CONDUWUIT_VERSION_EXTRA="$(git rev-parse --short HEAD)" cargo build --release cp -v -f target/release/conduit conduwuit-macos-arm64 otool -L conduwuit-macos-arm64 + # quick smoke test of the arm64 macOS binary - name: Run arm64 macOS release binary if: ${{ matrix.os == 'macos-latest' }} run: | ./conduwuit-macos-arm64 --version + - name: Upload macOS x86_64 binary if: ${{ matrix.os == 'macos-13' }} uses: actions/upload-artifact@v4 @@ -497,6 +487,7 @@ jobs: name: conduwuit-macos-x86_64 path: conduwuit-macos-x86_64 if-no-files-found: error + - name: Upload macOS arm64 binary if: ${{ matrix.os == 'macos-latest' }} uses: actions/upload-artifact@v4 From 6b1b464abcecacea98e06040aba679e9bdc3cec9 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 14:48:10 -0500 Subject: [PATCH 230/245] add missing knock_restricted room type to /publicRooms Signed-off-by: strawberry --- src/api/client/directory.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/api/client/directory.rs 
b/src/api/client/directory.rs index 6cf7b13f5..6120c7b39 100644 --- a/src/api/client/directory.rs +++ b/src/api/client/directory.rs @@ -407,7 +407,8 @@ async fn public_rooms_chunk(services: &Services, room_id: OwnedRoomId) -> Public .room_state_get_content(&room_id, &StateEventType::RoomJoinRules, "") .map_ok(|c: RoomJoinRulesEventContent| match c.join_rule { JoinRule::Public => PublicRoomJoinRule::Public, - JoinRule::Knock => PublicRoomJoinRule::Knock, + JoinRule::Knock => "knock".into(), + JoinRule::KnockRestricted(_) => "knock_restricted".into(), _ => "invite".into(), }) .await From 9c95a74d56bf80fb1f09984a9c05b16d25d320da Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 16:35:58 -0500 Subject: [PATCH 231/245] fix getting canonical alias server for backfill Signed-off-by: strawberry --- src/service/rooms/timeline/mod.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 2faa1c40b..59fc8e930 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -4,6 +4,7 @@ use std::{ cmp, collections::{BTreeMap, HashSet}, fmt::Write, + iter::once, sync::Arc, }; @@ -1076,9 +1077,20 @@ impl Service { } }); + let canonical_room_alias_server = once( + self.services + .state_accessor + .get_canonical_alias(room_id) + .await, + ) + .filter_map(Result::ok) + .map(|alias| alias.server_name().to_owned()) + .stream(); + let mut servers = room_mods .stream() .map(ToOwned::to_owned) + .chain(canonical_room_alias_server) .chain( self.services .server From be5a04f47cf1da5629239804601978ddcb2a1db1 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 17:09:36 -0500 Subject: [PATCH 232/245] ci: install liburing-dev Signed-off-by: strawberry --- .github/workflows/ci.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f4dcb88f8..96a1188b6 100644 --- a/.github/workflows/ci.yml +++ 
b/.github/workflows/ci.yml @@ -67,6 +67,10 @@ jobs: name: Test runs-on: ubuntu-24.04 steps: + - name: Install liburing + run: | + sudo apt install liburing-dev -y + - name: Sync repository uses: actions/checkout@v4 From 4b652f5236f3482316fc8dfc522f705c3b85b586 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 17:50:39 -0500 Subject: [PATCH 233/245] ok cargo doc Signed-off-by: strawberry --- src/core/config/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 4bba14554..cb9d087bb 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -1524,7 +1524,7 @@ pub struct WellKnownConfig { /// The server URL that the client well-known file will serve. This should /// not contain a port, and should just be a valid HTTPS URL. /// - /// example: "https://matrix.example.com" + /// example: "" pub client: Option, pub support_page: Option, pub support_role: Option, From 59834a4b05784c6e5e9ba12c5c5cc06f5ba98825 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 14 Nov 2024 22:43:18 +0000 Subject: [PATCH 234/245] add is_read_only()/is_secondary() to Engine Signed-off-by: Jason Volk --- src/database/database.rs | 4 ++-- src/database/engine.rs | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/database/database.rs b/src/database/database.rs index bf8c88555..40aec3123 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -38,11 +38,11 @@ impl Database { #[inline] #[must_use] - pub fn is_read_only(&self) -> bool { self.db.secondary || self.db.read_only } + pub fn is_read_only(&self) -> bool { self.db.is_read_only() } #[inline] #[must_use] - pub fn is_secondary(&self) -> bool { self.db.secondary } + pub fn is_secondary(&self) -> bool { self.db.is_secondary() } } impl Index<&str> for Database { diff --git a/src/database/engine.rs b/src/database/engine.rs index 99d971ed6..b57fd75e5 100644 --- a/src/database/engine.rs +++ 
b/src/database/engine.rs @@ -274,6 +274,14 @@ impl Engine { result(self.db.property_value_cf(cf, name)) .and_then(|val| val.map_or_else(|| Err!("Property {name:?} not found."), Ok)) } + + #[inline] + #[must_use] + pub fn is_read_only(&self) -> bool { self.secondary || self.read_only } + + #[inline] + #[must_use] + pub fn is_secondary(&self) -> bool { self.secondary } } pub(crate) fn repair(db_opts: &Options, path: &PathBuf) -> Result<()> { From 20836cc3dbc2e22c6d7da99ec199930b6d4c7ad4 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 14 Nov 2024 22:44:18 +0000 Subject: [PATCH 235/245] flush=false for database-backup in read-only/secondary modes; improve error Signed-off-by: Jason Volk --- src/admin/server/commands.rs | 2 +- src/database/engine.rs | 14 ++++++++------ src/service/globals/data.rs | 2 +- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/admin/server/commands.rs b/src/admin/server/commands.rs index f5879b037..94f695ceb 100644 --- a/src/admin/server/commands.rs +++ b/src/admin/server/commands.rs @@ -107,7 +107,7 @@ pub(super) async fn backup_database(&self) -> Result { .runtime() .spawn_blocking(move || match globals.db.backup() { Ok(()) => String::new(), - Err(e) => (*e).to_string(), + Err(e) => e.to_string(), }) .await?; diff --git a/src/database/engine.rs b/src/database/engine.rs index b57fd75e5..1fa53b012 100644 --- a/src/database/engine.rs +++ b/src/database/engine.rs @@ -17,6 +17,7 @@ use rocksdb::{ use crate::{ opts::{cf_options, db_options}, or_else, result, + util::map_err, }; pub struct Engine { @@ -183,19 +184,20 @@ impl Engine { } #[tracing::instrument(skip(self))] - pub fn backup(&self) -> Result<(), Box> { + pub fn backup(&self) -> Result { let config = &self.server.config; let path = config.database_backup_path.as_ref(); if path.is_none() || path.is_some_and(|path| path.as_os_str().is_empty()) { return Ok(()); } - let options = BackupEngineOptions::new(path.expect("valid database backup path"))?; - let mut engine = 
BackupEngine::open(&options, &self.env)?; + let options = BackupEngineOptions::new(path.expect("valid database backup path")).map_err(map_err)?; + let mut engine = BackupEngine::open(&options, &self.env).map_err(map_err)?; if config.database_backups_to_keep > 0 { - if let Err(e) = engine.create_new_backup_flush(&self.db, true) { - return Err(Box::new(e)); - } + let flush = !self.is_read_only(); + engine + .create_new_backup_flush(&self.db, flush) + .map_err(map_err)?; let engine_info = engine.get_backup_info(); let info = &engine_info.last().expect("backup engine info is not empty"); diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index bcfe101ef..f715e944a 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -73,7 +73,7 @@ impl Data { } #[inline] - pub fn backup(&self) -> Result<(), Box> { self.db.db.backup() } + pub fn backup(&self) -> Result { self.db.db.backup() } #[inline] pub fn backup_list(&self) -> Result { self.db.db.backup_list() } From 5f625216aa027aa000037c26fee461684e25689c Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 14 Nov 2024 23:35:53 +0000 Subject: [PATCH 236/245] slight optimizations for statediff calculate with_capacity for set/get_statediff() etc Signed-off-by: Jason Volk --- src/service/rooms/state_compressor/mod.rs | 50 +++++++++++++++-------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 6b520ad3d..06054f0d3 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -5,6 +5,7 @@ use std::{ sync::{Arc, Mutex}, }; +use arrayvec::ArrayVec; use conduit::{ at, checked, err, expected, utils, utils::{bytes, math::usize_from_f64}, @@ -37,7 +38,7 @@ struct Data { #[derive(Clone)] struct StateDiff { - parent: Option, + parent: Option, added: Arc, removed: Arc, } @@ -165,17 +166,20 @@ impl Service { } pub async fn compress_state_event(&self, 
shortstatekey: ShortStateKey, event_id: &EventId) -> CompressedStateEvent { - let mut v = shortstatekey.to_be_bytes().to_vec(); - v.extend_from_slice( - &self - .services - .short - .get_or_create_shorteventid(event_id) - .await - .to_be_bytes(), - ); + const SIZE: usize = size_of::(); + + let shorteventid = self + .services + .short + .get_or_create_shorteventid(event_id) + .await; - v.try_into().expect("we checked the size above") + let mut v = ArrayVec::::new(); + v.extend(shortstatekey.to_be_bytes()); + v.extend(shorteventid.to_be_bytes()); + v.as_ref() + .try_into() + .expect("failed to create CompressedStateEvent") } /// Returns shortstatekey, event id @@ -185,11 +189,12 @@ impl Service { ) -> Result<(ShortStateKey, Arc)> { use utils::u64_from_u8; - let shortstatekey = u64_from_u8(&compressed_event[0..size_of::()]); + let shortstatekey = u64_from_u8(&compressed_event[0..size_of::()]); + let shorteventid = u64_from_u8(&compressed_event[size_of::()..]); let event_id = self .services .short - .get_eventid_from_short(u64_from_u8(&compressed_event[size_of::()..])) + .get_eventid_from_short(shorteventid) .await?; Ok((shortstatekey, event_id)) @@ -415,9 +420,12 @@ impl Service { .ok() .take_if(|parent| *parent != 0); + debug_assert!(value.len() % STRIDE == 0, "value not aligned to stride"); + let num_values = value.len() / STRIDE; + let mut add_mode = true; - let mut added = HashSet::new(); - let mut removed = HashSet::new(); + let mut added = HashSet::with_capacity(num_values); + let mut removed = HashSet::with_capacity(num_values); let mut i = STRIDE; while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) { @@ -434,6 +442,8 @@ impl Service { i = expected!(i + 2 * STRIDE); } + added.shrink_to_fit(); + removed.shrink_to_fit(); Ok(StateDiff { parent, added: Arc::new(added), @@ -442,7 +452,15 @@ impl Service { } fn save_statediff(&self, shortstatehash: ShortStateHash, diff: &StateDiff) { - let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); + let mut 
value = Vec::::with_capacity( + 2_usize + .saturating_add(diff.added.len()) + .saturating_add(diff.removed.len()), + ); + + let parent = diff.parent.unwrap_or(0_u64); + value.extend_from_slice(&parent.to_be_bytes()); + for new in diff.added.iter() { value.extend_from_slice(&new[..]); } From 9f7a4a012b38c4e3a59ea72bbb5d291ecdbf37f4 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 15 Nov 2024 03:41:08 +0000 Subject: [PATCH 237/245] improve tracing/logging for state_compressor Signed-off-by: Jason Volk --- src/service/rooms/state_compressor/mod.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 06054f0d3..f0c851de9 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -7,7 +7,7 @@ use std::{ use arrayvec::ArrayVec; use conduit::{ - at, checked, err, expected, utils, + at, checked, debug, err, expected, utils, utils::{bytes, math::usize_from_f64}, Result, }; @@ -157,6 +157,13 @@ impl Service { }] }; + debug!( + ?parent, + ?shortstatehash, + vec_len = %response.len(), + "cache update" + ); + self.stateinfo_cache .lock() .expect("locked") @@ -218,7 +225,6 @@ impl Service { /// for this layer /// * `parent_states` - A stack with info on shortstatehash, full state, /// added diff and removed diff for each parent layer - #[tracing::instrument(skip_all, level = "debug")] pub fn save_state_from_diff( &self, shortstatehash: ShortStateHash, statediffnew: Arc>, statediffremoved: Arc>, diff_to_sibling: usize, @@ -335,6 +341,7 @@ impl Service { /// Returns the new shortstatehash, and the state diff from the previous /// room state + #[tracing::instrument(skip(self, new_state_ids_compressed), level = "debug")] pub async fn save_state( &self, room_id: &RoomId, new_state_ids_compressed: Arc>, ) -> Result { @@ -405,6 +412,7 @@ impl Service { }) } + #[tracing::instrument(skip(self), level = "debug", name = 
"get")] async fn get_statediff(&self, shortstatehash: ShortStateHash) -> Result { const BUFSIZE: usize = size_of::(); const STRIDE: usize = size_of::(); From 14e3b242dfafbb31b9a9c58b586989a828a022f3 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 15 Nov 2024 03:44:04 +0000 Subject: [PATCH 238/245] add database get_batch stream wrapper Signed-off-by: Jason Volk --- src/database/map/get.rs | 22 +++++++++++++++------- src/service/rooms/auth_chain/mod.rs | 12 ++++++------ src/service/rooms/short/mod.rs | 13 +++++++------ 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/database/map/get.rs b/src/database/map/get.rs index 72382e367..2f7df0318 100644 --- a/src/database/map/get.rs +++ b/src/database/map/get.rs @@ -1,8 +1,8 @@ use std::{convert::AsRef, fmt::Debug, future::Future, io::Write}; use arrayvec::ArrayVec; -use conduit::{err, implement, Result}; -use futures::future::ready; +use conduit::{err, implement, utils::IterStream, Result}; +use futures::{future::ready, Stream}; use rocksdb::DBPinnableSlice; use serde::Serialize; @@ -50,6 +50,7 @@ where /// Fetch a value from the database into cache, returning a reference-handle /// asynchronously. The key is referenced directly to perform the query. #[implement(super::Map)] +#[tracing::instrument(skip(self, key), fields(%self), level = "trace")] pub fn get(&self, key: &K) -> impl Future>> + Send where K: AsRef<[u8]> + ?Sized + Debug, @@ -61,10 +62,9 @@ where /// The key is referenced directly to perform the query. This is a thread- /// blocking call. 
#[implement(super::Map)] -#[tracing::instrument(skip(self, key), fields(%self), level = "trace")] pub fn get_blocking(&self, key: &K) -> Result> where - K: AsRef<[u8]> + ?Sized + Debug, + K: AsRef<[u8]> + ?Sized, { let res = self .db @@ -76,10 +76,19 @@ where #[implement(super::Map)] #[tracing::instrument(skip(self, keys), fields(%self), level = "trace")] -pub fn get_batch_blocking<'a, I, K>(&self, keys: I) -> Vec>> +pub fn get_batch<'a, I, K>(&self, keys: I) -> impl Stream>> where I: Iterator + ExactSizeIterator + Send + Debug, - K: AsRef<[u8]> + Sized + Debug + 'a, + K: AsRef<[u8]> + Send + Sync + Sized + Debug + 'a, +{ + self.get_batch_blocking(keys).stream() +} + +#[implement(super::Map)] +pub fn get_batch_blocking<'a, I, K>(&self, keys: I) -> impl Iterator>> +where + I: Iterator + ExactSizeIterator + Send, + K: AsRef<[u8]> + Sized + 'a, { // Optimization can be `true` if key vector is pre-sorted **by the column // comparator**. @@ -91,7 +100,6 @@ where .batched_multi_get_cf_opt(&self.cf(), keys, SORTED, read_options) .into_iter() .map(into_result_handle) - .collect() } fn into_result_handle(result: RocksdbResult<'_>) -> Result> { diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index c22732c24..cabb6f0ca 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -6,7 +6,7 @@ use std::{ }; use conduit::{debug, debug_error, trace, utils::IterStream, validated, warn, Err, Result}; -use futures::Stream; +use futures::{Stream, StreamExt}; use ruma::{EventId, RoomId}; use self::data::Data; @@ -69,15 +69,15 @@ impl Service { const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new(); let started = std::time::Instant::now(); - let mut buckets = [BUCKET; NUM_BUCKETS]; - for (i, &short) in self + let mut starting_ids = self .services .short .multi_get_or_create_shorteventid(starting_events) - .await - .iter() .enumerate() - { + .boxed(); + + let mut buckets = [BUCKET; NUM_BUCKETS]; + while let 
Some((i, short)) = starting_ids.next().await { let bucket: usize = short.try_into()?; let bucket: usize = validated!(bucket % NUM_BUCKETS); buckets[bucket].insert((short, starting_events[i])); diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index e8b00d9bd..703df796a 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -3,6 +3,7 @@ use std::{mem::size_of_val, sync::Arc}; pub use conduit::pdu::{ShortEventId, ShortId, ShortRoomId}; use conduit::{err, implement, utils, Result}; use database::{Deserialized, Map}; +use futures::{Stream, StreamExt}; use ruma::{events::StateEventType, EventId, RoomId}; use crate::{globals, Dep}; @@ -71,11 +72,12 @@ pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> ShortEvent } #[implement(Service)] -pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec { +pub fn multi_get_or_create_shorteventid<'a>( + &'a self, event_ids: &'a [&EventId], +) -> impl Stream + Send + 'a { self.db .eventid_shorteventid - .get_batch_blocking(event_ids.iter()) - .into_iter() + .get_batch(event_ids.iter()) .enumerate() .map(|(i, result)| match result { Ok(ref short) => utils::u64_from_u8(short), @@ -95,7 +97,6 @@ pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> short }, }) - .collect() } #[implement(Service)] @@ -163,10 +164,10 @@ pub async fn multi_get_eventid_from_short(&self, shorteventid: &[ShortEventId]) self.db .shorteventid_eventid - .get_batch_blocking(keys.iter()) - .into_iter() + .get_batch(keys.iter()) .map(Deserialized::deserialized) .collect() + .await } #[implement(Service)] From 887ae84f1e3b3e0254e04afe011083e692af7e00 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 15 Nov 2024 22:20:28 +0000 Subject: [PATCH 239/245] optimize sha256 interface gather/vector inputs Signed-off-by: Jason Volk --- src/api/server/invite.rs | 4 +- src/core/utils/hash.rs | 9 +-- src/core/utils/hash/sha256.rs | 69 
++++++++++++++++++++--- src/core/utils/mod.rs | 2 +- src/service/rooms/state/mod.rs | 7 +-- src/service/rooms/state_compressor/mod.rs | 7 +-- src/service/sending/sender.rs | 39 ++++++------- 7 files changed, 86 insertions(+), 51 deletions(-) diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index edf80cd69..0ceb914fc 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -1,7 +1,7 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; use base64::{engine::general_purpose, Engine as _}; -use conduit::{err, utils, warn, Err, Error, PduEvent, Result}; +use conduit::{err, utils, utils::hash::sha256, warn, Err, Error, PduEvent, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_invite}, events::room::member::{MembershipState, RoomMemberEventContent}, @@ -160,7 +160,7 @@ pub(crate) async fn create_invite_route( ruma::api::appservice::event::push_events::v1::Request { events: vec![pdu.to_room_event()], txn_id: general_purpose::URL_SAFE_NO_PAD - .encode(utils::calculate_hash(&[pdu.event_id.as_bytes()])) + .encode(sha256::hash(pdu.event_id.as_bytes())) .into(), ephemeral: Vec::new(), to_device: Vec::new(), diff --git a/src/core/utils/hash.rs b/src/core/utils/hash.rs index 5a3664cb6..c12d4f663 100644 --- a/src/core/utils/hash.rs +++ b/src/core/utils/hash.rs @@ -1,13 +1,10 @@ mod argon; -mod sha256; +pub mod sha256; use crate::Result; -pub fn password(password: &str) -> Result { argon::password(password) } - -pub fn verify_password(password: &str, password_hash: &str) -> Result<()> { +pub fn verify_password(password: &str, password_hash: &str) -> Result { argon::verify_password(password, password_hash) } -#[must_use] -pub fn calculate_hash(keys: &[&[u8]]) -> Vec { sha256::hash(keys) } +pub fn password(password: &str) -> Result { argon::password(password) } diff --git a/src/core/utils/hash/sha256.rs b/src/core/utils/hash/sha256.rs index b2e5a94c2..06e210a7e 100644 --- a/src/core/utils/hash/sha256.rs 
+++ b/src/core/utils/hash/sha256.rs @@ -1,9 +1,62 @@ -use ring::{digest, digest::SHA256}; - -#[tracing::instrument(skip_all, level = "debug")] -pub(super) fn hash(keys: &[&[u8]]) -> Vec { - // We only hash the pdu's event ids, not the whole pdu - let bytes = keys.join(&0xFF); - let hash = digest::digest(&SHA256, &bytes); - hash.as_ref().to_owned() +use ring::{ + digest, + digest::{Context, SHA256, SHA256_OUTPUT_LEN}, +}; + +pub type Digest = [u8; SHA256_OUTPUT_LEN]; + +/// Sha256 hash (input gather joined by 0xFF bytes) +#[must_use] +#[tracing::instrument(skip(inputs), level = "trace")] +pub fn delimited<'a, T, I>(mut inputs: I) -> Digest +where + I: Iterator + 'a, + T: AsRef<[u8]> + 'a, +{ + let mut ctx = Context::new(&SHA256); + if let Some(input) = inputs.next() { + ctx.update(input.as_ref()); + for input in inputs { + ctx.update(b"\xFF"); + ctx.update(input.as_ref()); + } + } + + ctx.finish() + .as_ref() + .try_into() + .expect("failed to return Digest buffer") +} + +/// Sha256 hash (input gather) +#[must_use] +#[tracing::instrument(skip(inputs), level = "trace")] +pub fn concat<'a, T, I>(inputs: I) -> Digest +where + I: Iterator + 'a, + T: AsRef<[u8]> + 'a, +{ + inputs + .fold(Context::new(&SHA256), |mut ctx, input| { + ctx.update(input.as_ref()); + ctx + }) + .finish() + .as_ref() + .try_into() + .expect("failed to return Digest buffer") +} + +/// Sha256 hash +#[inline] +#[must_use] +#[tracing::instrument(skip(input), level = "trace")] +pub fn hash(input: T) -> Digest +where + T: AsRef<[u8]>, +{ + digest::digest(&SHA256, input.as_ref()) + .as_ref() + .try_into() + .expect("failed to return Digest buffer") } diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index b8640f3af..18c2dd6f3 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -28,7 +28,7 @@ pub use self::{ bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8}, debug::slice_truncated as debug_slice_truncated, future::TryExtExt as TryFutureExtExt, - hash::calculate_hash, 
+ hash::sha256::delimited as calculate_hash, html::Escape as HtmlEscape, json::{deserialize_from_str, to_canonical_object}, math::clamp, diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 7d8200f09..29ffedfce 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -157,12 +157,7 @@ impl Service { let previous_shortstatehash = self.get_room_shortstatehash(room_id).await; - let state_hash = calculate_hash( - &state_ids_compressed - .iter() - .map(|s| &s[..]) - .collect::>(), - ); + let state_hash = calculate_hash(state_ids_compressed.iter().map(|s| &s[..])); let (shortstatehash, already_existed) = self .services diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index f0c851de9..0466fb125 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -352,12 +352,7 @@ impl Service { .await .ok(); - let state_hash = utils::calculate_hash( - &new_state_ids_compressed - .iter() - .map(|bytes| &bytes[..]) - .collect::>(), - ); + let state_hash = utils::calculate_hash(new_state_ids_compressed.iter().map(|bytes| &bytes[..])); let (new_shortstatehash, already_existed) = self .services diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index f5d875045..ee8182895 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -539,16 +539,13 @@ impl Service { } } - let txn_id = &*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( - &events - .iter() - .map(|e| match e { - SendingEvent::Edu(b) => &**b, - SendingEvent::Pdu(b) => b.as_ref(), - SendingEvent::Flush => &[], - }) - .collect::>(), - )); + let txn_hash = calculate_hash(events.iter().filter_map(|e| match e { + SendingEvent::Edu(b) => Some(&**b), + SendingEvent::Pdu(b) => Some(b.as_ref()), + SendingEvent::Flush => None, + })); + + let txn_id = &*general_purpose::URL_SAFE_NO_PAD.encode(txn_hash); 
//debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty // transaction"); @@ -664,23 +661,21 @@ impl Service { //debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty // transaction"); - let transaction_id = &*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( - &events - .iter() - .map(|e| match e { - SendingEvent::Edu(b) => &**b, - SendingEvent::Pdu(b) => b.as_ref(), - SendingEvent::Flush => &[], - }) - .collect::>(), - )); + + let txn_hash = calculate_hash(events.iter().filter_map(|e| match e { + SendingEvent::Edu(b) => Some(&**b), + SendingEvent::Pdu(b) => Some(b.as_ref()), + SendingEvent::Flush => None, + })); + + let txn_id = &*general_purpose::URL_SAFE_NO_PAD.encode(txn_hash); let request = send_transaction_message::v1::Request { origin: self.server.config.server_name.clone(), pdus: pdu_jsons, edus: edu_jsons, origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - transaction_id: transaction_id.into(), + transaction_id: txn_id.into(), }; let client = &self.services.client.sender; @@ -692,7 +687,7 @@ impl Service { .iter() .filter(|(_, res)| res.is_err()) .for_each( - |(pdu_id, res)| warn!(%transaction_id, %server, "error sending PDU {pdu_id} to remote server: {res:?}"), + |(pdu_id, res)| warn!(%txn_id, %server, "error sending PDU {pdu_id} to remote server: {res:?}"), ); }) .map(|_| dest.clone()) From cd2c473bfe627389f7e24f822ca4a19e696ca555 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 21:00:26 -0500 Subject: [PATCH 240/245] add missing fix_referencedevents_missing_sep key on fresh db creations Signed-off-by: strawberry --- src/service/migrations.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/service/migrations.rs b/src/service/migrations.rs index 4c821fa38..126d3c7ef 100644 --- a/src/service/migrations.rs +++ b/src/service/migrations.rs @@ -68,6 +68,7 @@ async fn fresh(services: &Services) -> Result<()> { db["global"].insert(b"feat_sha256_media", []); 
db["global"].insert(b"fix_bad_double_separator_in_state_cache", []); db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []); + db["global"].insert(b"fix_referencedevents_missing_sep", []); // Create the admin room and server user on first run crate::admin::create_admin_room(services).boxed().await?; From ead9d667970f77c6d4e1c9747e607d7d711a57e0 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 21:28:08 -0500 Subject: [PATCH 241/245] send the actual unsupported room version in join errors Signed-off-by: strawberry --- src/api/client/membership.rs | 69 ++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index c61185a7c..9478e383d 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -702,18 +702,20 @@ async fn join_room_by_id_helper_remote( info!("make_join finished"); - let room_version_id = match make_join_response.room_version { - Some(room_version) - if services - .globals - .supported_room_versions() - .contains(&room_version) => - { - room_version - }, - _ => return Err!(BadServerResponse("Room version is not supported")), + let Some(room_version_id) = make_join_response.room_version else { + return Err!(BadServerResponse("Remote room version is not supported by conduwuit")); }; + if !services + .globals + .supported_room_versions() + .contains(&room_version_id) + { + return Err!(BadServerResponse( + "Remote room version {room_version_id} is not supported by conduwuit" + )); + } + let mut join_event_stub: CanonicalJsonObject = serde_json::from_str(make_join_response.event.get()) .map_err(|e| err!(BadServerResponse("Invalid make_join event json received from server: {e:?}")))?; @@ -1116,17 +1118,20 @@ async fn join_room_by_id_helper_local( warn!("We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements"); let (make_join_response, remote_server) = 
make_join_request(services, sender_user, room_id, servers).await?; - let room_version_id = match make_join_response.room_version { - Some(room_version_id) - if services - .globals - .supported_room_versions() - .contains(&room_version_id) => - { - room_version_id - }, - _ => return Err!(BadServerResponse("Room version is not supported")), + let Some(room_version_id) = make_join_response.room_version else { + return Err!(BadServerResponse("Remote room version is not supported by conduwuit")); }; + + if !services + .globals + .supported_room_versions() + .contains(&room_version_id) + { + return Err!(BadServerResponse( + "Remote room version {room_version_id} is not supported by conduwuit" + )); + } + let mut join_event_stub: CanonicalJsonObject = serde_json::from_str(make_join_response.event.get()) .map_err(|e| err!(BadServerResponse("Invalid make_join event json received from server: {e:?}")))?; let join_authorized_via_users_server = join_event_stub @@ -1274,7 +1279,7 @@ async fn make_join_request( if incompatible_room_version_count > 15 { info!( "15 servers have responded with M_INCOMPATIBLE_ROOM_VERSION or M_UNSUPPORTED_ROOM_VERSION, \ - assuming that conduwuit does not support the room {room_id}: {e}" + assuming that conduwuit does not support the room version {room_id}: {e}" ); make_join_response_and_server = Err!(BadServerResponse("Room version is not supported by Conduwuit")); return make_join_response_and_server; @@ -1607,18 +1612,20 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room let (make_leave_response, remote_server) = make_leave_response_and_server?; - let room_version_id = match make_leave_response.room_version { - Some(version) - if services - .globals - .supported_room_versions() - .contains(&version) => - { - version - }, - _ => return Err!(BadServerResponse("Room version is not supported")), + let Some(room_version_id) = make_leave_response.room_version else { + return Err!(BadServerResponse("Remote room version 
is not supported by conduwuit")); }; + if !services + .globals + .supported_room_versions() + .contains(&room_version_id) + { + return Err!(BadServerResponse( + "Remote room version {room_version_id} is not supported by conduwuit" + )); + } + let mut leave_event_stub = serde_json::from_str::(make_leave_response.event.get()) .map_err(|e| err!(BadServerResponse("Invalid make_leave event json received from server: {e:?}")))?; From 63193840729ab1dea026514d807aa39a98b3f1b1 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 21:40:13 -0500 Subject: [PATCH 242/245] implement `GET /_matrix/client/v3/pushrules/global/` Signed-off-by: strawberry --- src/api/client/push.rs | 69 +++++++++++++++++++++++++++++++++++++++++- src/api/router.rs | 1 + 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/src/api/client/push.rs b/src/api/client/push.rs index de280b32f..97243ab45 100644 --- a/src/api/client/push.rs +++ b/src/api/client/push.rs @@ -5,7 +5,7 @@ use ruma::{ error::ErrorKind, push::{ delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled, get_pushrules_all, - set_pusher, set_pushrule, set_pushrule_actions, set_pushrule_enabled, + get_pushrules_global_scope, set_pusher, set_pushrule, set_pushrule_actions, set_pushrule_enabled, }, }, events::{ @@ -67,6 +67,73 @@ pub(crate) async fn get_pushrules_all_route( }) } +/// # `GET /_matrix/client/r0/pushrules/global/` +/// +/// Retrieves the push rules event for this user. +/// +/// This appears to be the exact same as `GET /_matrix/client/r0/pushrules/`. 
+pub(crate) async fn get_pushrules_global_route( + State(services): State, body: Ruma, +) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + let Some(content_value) = services + .account_data + .get_global::(sender_user, GlobalAccountDataEventType::PushRules) + .await + .ok() + .and_then(|event| event.get("content").cloned()) + .filter(CanonicalJsonValue::is_object) + else { + // user somehow has non-existent push rule event. recreate it and return server + // default silently + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(PushRulesEvent { + content: PushRulesEventContent { + global: Ruleset::server_default(sender_user), + }, + }) + .expect("to json always works"), + ) + .await?; + + return Ok(get_pushrules_global_scope::v3::Response { + global: Ruleset::server_default(sender_user), + }); + }; + + let account_data_content = serde_json::from_value::(content_value.into()) + .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; + + let mut global_ruleset = account_data_content.global; + + // remove old deprecated mentions push rules as per MSC4210 + #[allow(deprecated)] + { + use ruma::push::RuleKind::*; + + global_ruleset + .remove(Override, PredefinedOverrideRuleId::ContainsDisplayName) + .ok(); + global_ruleset + .remove(Override, PredefinedOverrideRuleId::RoomNotif) + .ok(); + + global_ruleset + .remove(Content, PredefinedContentRuleId::ContainsUserName) + .ok(); + }; + + Ok(get_pushrules_global_scope::v3::Response { + global: global_ruleset, + }) +} + /// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}` /// /// Retrieves a single specified push rule for this user. 
diff --git a/src/api/router.rs b/src/api/router.rs index ddd91d11f..1df4342fe 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -45,6 +45,7 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(&client::check_registration_token_validity) .ruma_route(&client::get_capabilities_route) .ruma_route(&client::get_pushrules_all_route) + .ruma_route(&client::get_pushrules_global_route) .ruma_route(&client::set_pushrule_route) .ruma_route(&client::get_pushrule_route) .ruma_route(&client::set_pushrule_enabled_route) From b92b4e043c03dfcc6f0163af92a219251b745351 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 22:16:11 -0500 Subject: [PATCH 243/245] drop hyper-util back down to 0.1.8 due to DNS issues Signed-off-by: strawberry --- Cargo.lock | 17 +++++++++-------- Cargo.toml | 3 ++- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b56005ff3..3a95f83a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -164,9 +164,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-lc-rs" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdd82dba44d209fddb11c190e0a94b78651f95299598e472215667417a03ff1d" +checksum = "fe7c2840b66236045acd2607d5866e274380afd87ef99d6226e961e2cb47df45" dependencies = [ "aws-lc-sys", "mirai-annotations", @@ -176,9 +176,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.22.0" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df7a4168111d7eb622a31b214057b8509c0a7e1794f44c546d742330dc793972" +checksum = "ad3a619a9de81e1d7de1f1186dcba4506ed661a0e483d84410fdef0ee87b2f96" dependencies = [ "bindgen", "cc", @@ -1739,9 +1739,9 @@ dependencies = [ [[package]] name = "hyper-timeout" -version = "0.5.2" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" +checksum = "3203a961e5c83b6f5498933e78b6b263e208c197b63e9c6c53cc82ffd3f63793" dependencies = [ "hyper", "hyper-util", @@ -1752,9 +1752,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.10" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +checksum = "da62f120a8a37763efb0cf8fdf264b884c7b8b9ac8660b900c8661030c00e6ba" dependencies = [ "bytes", "futures-channel", @@ -1765,6 +1765,7 @@ dependencies = [ "pin-project-lite", "socket2", "tokio", + "tower 0.4.13", "tower-service", "tracing", ] diff --git a/Cargo.toml b/Cargo.toml index 814a435b2..68c87c572 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -264,7 +264,8 @@ features = [ ] [workspace.dependencies.hyper-util] -version = "0.1.10" +# hyper-util >=0.1.9 seems to have DNS issues +version = "=0.1.8" default-features = false features = [ "server-auto", From 7f96b2f92ad11e9a357c5c0607f7be403f7c8a85 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 23:18:12 -0500 Subject: [PATCH 244/245] nix: remove libllvm, libgcc, and llvm from OCI images as well aarch64 OCI images love llvm?? 
Signed-off-by: strawberry --- nix/pkgs/main/default.nix | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nix/pkgs/main/default.nix b/nix/pkgs/main/default.nix index 1088b03cd..d11c36cc5 100644 --- a/nix/pkgs/main/default.nix +++ b/nix/pkgs/main/default.nix @@ -176,7 +176,7 @@ commonAttrs = { # # postInstall = with pkgsBuildHost; '' - find "$out" -type f -exec remove-references-to -t ${stdenv.cc} -t ${gcc} -t ${rustc.unwrapped} -t ${rustc} -t ${libidn2} -t ${libunistring} '{}' + + find "$out" -type f -exec remove-references-to -t ${stdenv.cc} -t ${gcc} -t ${libgcc} -t ${llvm} -t ${libllvm} -t ${rustc.unwrapped} -t ${rustc} -t ${libidn2} -t ${libunistring} '{}' + ''; }; in From 8f140485287adb8534d299ce553d15994ba1fab7 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 15 Nov 2024 23:48:55 -0500 Subject: [PATCH 245/245] ci: free up a bit of runner space safely (again) Signed-off-by: strawberry --- .github/workflows/ci.yml | 9 +++++++++ .github/workflows/documentation.yml | 10 ++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 96a1188b6..9385c5e3b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -71,6 +71,15 @@ jobs: run: | sudo apt install liburing-dev -y + - name: Free up a bit of runner space + run: | + set +o pipefail + sudo docker image prune --all --force || true + sudo apt purge -y 'php.*' '^mongodb-.*' '^mysql-.*' azure-cli google-cloud-cli google-chrome-stable firefox powershell microsoft-edge-stable || true + sudo apt clean + sudo rm -v -rf /usr/local/games /usr/local/sqlpackage /usr/local/share/powershell /usr/local/share/edge_driver /usr/local/share/gecko_driver /usr/local/share/chromium /usr/local/share/chromedriver-linux64 /usr/lib/google-cloud-sdk /usr/lib/jvm /usr/lib/mono /usr/lib/heroku + set -o pipefail + - name: Sync repository uses: actions/checkout@v4 diff --git a/.github/workflows/documentation.yml 
b/.github/workflows/documentation.yml index ea720c43c..17b1f9c17 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -50,8 +50,14 @@ jobs: url: ${{ steps.deployment.outputs.page_url }} steps: - - name: Free Disk Space (Ubuntu) - uses: jlumbroso/free-disk-space@main + - name: Free up a bit of runner space + run: | + set +o pipefail + sudo docker image prune --all --force || true + sudo apt purge -y 'php.*' '^mongodb-.*' '^mysql-.*' azure-cli google-cloud-cli google-chrome-stable firefox powershell microsoft-edge-stable || true + sudo apt clean + sudo rm -v -rf /usr/local/games /usr/local/sqlpackage /usr/local/share/powershell /usr/local/share/edge_driver /usr/local/share/gecko_driver /usr/local/share/chromium /usr/local/share/chromedriver-linux64 /usr/lib/google-cloud-sdk /usr/lib/jvm /usr/lib/mono /usr/lib/heroku + set -o pipefail - name: Sync repository uses: actions/checkout@v4