Call cache directly from publisher
OlivierHecart committed Dec 5, 2024
1 parent 7537985 commit f850eaa
Showing 5 changed files with 283 additions and 193 deletions.
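In short: `AdvancedCache` no longer declares its own subscriber and background task to capture publications. The cache becomes a plain `Arc<RwLock<VecDeque<Sample>>>` that the publisher fills synchronously through a new `cache_sample` method, while the cache's queryable answers queries from the same queue via a callback. The standalone sketch below distills that shared-queue pattern; the `Sample` stand-in and the function names are illustrative, not the zenoh API:

use std::{
    collections::VecDeque,
    sync::{Arc, RwLock},
};

// Stand-in for zenoh's Sample type; the real one carries a key expression,
// payload, timestamp, etc.
#[derive(Clone, Debug)]
struct Sample(&'static str);

// Publisher side: append a sample, evicting the oldest once `max_samples`
// is reached. This mirrors the new `AdvancedCache::cache_sample` method.
fn cache_sample(cache: &Arc<RwLock<VecDeque<Sample>>>, max_samples: usize, sample: Sample) {
    if let Ok(mut queue) = cache.write() {
        if queue.len() >= max_samples {
            queue.pop_front();
        }
        queue.push_back(sample);
    }
}

// Query side: the queryable callback clones the Arc and replies from a read lock.
fn make_query_callback(cache: Arc<RwLock<VecDeque<Sample>>>) -> impl Fn() -> Vec<Sample> {
    move || {
        cache
            .read()
            .map(|queue| queue.iter().cloned().collect())
            .unwrap_or_default()
    }
}

fn main() {
    let cache = Arc::new(RwLock::new(VecDeque::new()));
    let on_query = make_query_callback(cache.clone());

    cache_sample(&cache, 2, Sample("a"));
    cache_sample(&cache, 2, Sample("b"));
    cache_sample(&cache, 2, Sample("c")); // evicts "a"

    assert_eq!(on_query().len(), 2); // replies with "b" and "c"
}

This also replaces the old per-resource `HashMap` bookkeeping with a single bounded queue per publisher, which is why `resources_limit` and `max_total_samples` disappear from `CacheConfig` in the diff below.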
280 changes: 103 additions & 177 deletions zenoh-ext/src/advanced_cache.rs
@@ -12,25 +12,20 @@
 // ZettaScale Zenoh Team, <zenoh@zettascale.tech>
 //
 use std::{
-    borrow::Borrow,
-    collections::{HashMap, VecDeque},
+    collections::VecDeque,
     future::{IntoFuture, Ready},
     sync::{Arc, RwLock},
 };
 
-use flume::{bounded, Sender};
-use futures::{select, FutureExt};
-use tokio::task;
 use zenoh::{
-    handlers::FifoChannelHandler,
     internal::{bail, traits::QoSBuilderTrait},
     key_expr::{
         format::{ke, kedefine},
-        keyexpr, KeyExpr, OwnedKeyExpr,
+        keyexpr, KeyExpr,
     },
     liveliness::LivelinessToken,
-    pubsub::Subscriber,
     qos::{CongestionControl, Priority},
-    query::{Query, Queryable, ZenohParameters},
+    query::{Queryable, ZenohParameters},
     sample::{Locality, Sample, SampleBuilder},
     Resolvable, Result as ZResult, Session, Wait, KE_ADV_PREFIX, KE_AT, KE_STARSTAR,
 };
@@ -82,16 +77,14 @@ impl QoSBuilderTrait for QoS
 #[derive(Debug, Clone)]
 /// Configure an [`AdvancedPublisher`](crate::AdvancedPublisher) cache.
 pub struct CacheConfig {
-    sample_depth: usize,
-    resources_limit: Option<usize>,
+    max_samples: usize,
     replies_qos: QoS,
 }
 
 impl Default for CacheConfig {
     fn default() -> Self {
         Self {
-            sample_depth: 1,
-            resources_limit: None,
+            max_samples: 1,
             replies_qos: QoS::default(),
         }
     }
@@ -100,15 +93,7 @@ impl Default for CacheConfig
 impl CacheConfig {
     /// Specify how many samples to keep for each resource.
     pub fn max_samples(mut self, depth: usize) -> Self {
-        self.sample_depth = depth;
-        self
-    }
-
-    // TODO pub fn max_age(mut self, depth: Duration) -> Self
-
-    /// Specify the maximum total number of samples to keep.
-    pub fn max_total_samples(mut self, limit: usize) -> Self {
-        self.resources_limit = Some(limit);
+        self.max_samples = depth;
         self
     }
 
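With `sample_depth` and `resources_limit` collapsed into a single `max_samples` field, configuring the cache reduces to picking a queue depth. A minimal sketch, assuming `CacheConfig` is re-exported at the zenoh-ext crate root (the import path is illustrative):

use zenoh_ext::CacheConfig; // illustrative import path

fn main() {
    // Keep the 10 most recent samples; replies keep the default QoS.
    let _history = CacheConfig::default().max_samples(10);
}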
@@ -124,7 +109,6 @@ pub struct AdvancedCacheBuilder<'a, 'b, 'c> {
     session: &'a Session,
     pub_key_expr: ZResult<KeyExpr<'b>>,
     queryable_prefix: Option<ZResult<KeyExpr<'c>>>,
-    subscriber_origin: Locality,
     queryable_origin: Locality,
     history: CacheConfig,
     liveliness: bool,
@@ -139,7 +123,6 @@ impl<'a, 'b, 'c> AdvancedCacheBuilder<'a, 'b, 'c> {
             session,
             pub_key_expr,
             queryable_prefix: Some(Ok((KE_ADV_PREFIX / KE_STARSTAR / KE_AT).into())),
-            subscriber_origin: Locality::default(),
             queryable_origin: Locality::default(),
             history: CacheConfig::default(),
             liveliness: false,
@@ -156,14 +139,6 @@ impl<'a, 'b, 'c> AdvancedCacheBuilder<'a, 'b, 'c> {
         self
     }
 
-    /// Restrict the matching publications that will be cached by this [`AdvancedCache`]
-    /// to the ones that have the given [`Locality`](zenoh::sample::Locality).
-    #[inline]
-    pub fn subscriber_allowed_origin(mut self, origin: Locality) -> Self {
-        self.subscriber_origin = origin;
-        self
-    }
-
     /// Change the history size for each resource.
     pub fn history(mut self, history: CacheConfig) -> Self {
         self.history = history;
@@ -213,174 +188,114 @@ fn sample_in_range(sample: &Sample, start: Option<u32>, end: Option<u32>) -> bool
 }
 
 pub struct AdvancedCache {
-    _sub: Subscriber<FifoChannelHandler<Sample>>,
-    _queryable: Queryable<FifoChannelHandler<Query>>,
+    cache: Arc<RwLock<VecDeque<Sample>>>,
+    max_samples: usize,
+    _queryable: Queryable<()>,
     _token: Option<LivelinessToken>,
-    _stoptx: Sender<bool>,
 }
 
 impl AdvancedCache {
     fn new(conf: AdvancedCacheBuilder<'_, '_, '_>) -> ZResult<AdvancedCache> {
-        let key_expr = conf.pub_key_expr?;
+        let key_expr = conf.pub_key_expr?.into_owned();
         // the queryable_prefix (optional), and the key_expr for AdvancedCache's queryable ("[<queryable_prefix>]/<pub_key_expr>")
-        let (queryable_prefix, queryable_key_expr): (Option<OwnedKeyExpr>, KeyExpr) =
-            match conf.queryable_prefix {
-                None => (None, key_expr.clone()),
-                Some(Ok(ke)) => {
-                    let queryable_key_expr = (&ke) / &key_expr;
-                    (Some(ke.into()), queryable_key_expr)
-                }
-                Some(Err(e)) => bail!("Invalid key expression for queryable_prefix: {}", e),
-            };
+        let queryable_key_expr = match conf.queryable_prefix {
+            None => key_expr.clone(),
+            Some(Ok(ke)) => (&ke) / &key_expr,
+            Some(Err(e)) => bail!("Invalid key expression for queryable_prefix: {}", e),
+        };
         tracing::debug!(
-            "Create AdvancedCache on {} with history={:?}",
+            "Create AdvancedCache on {} with max_samples={:?}",
             &key_expr,
             conf.history,
         );
 
-        // declare the local subscriber that will store the local publications
-        let sub = conf
-            .session
-            .declare_subscriber(&key_expr)
-            .allowed_origin(conf.subscriber_origin)
-            .wait()?;
+        let cache = Arc::new(RwLock::new(VecDeque::new()));
 
         // declare the queryable that will answer to queries on cache
         let queryable = conf
             .session
             .declare_queryable(&queryable_key_expr)
             .allowed_origin(conf.queryable_origin)
-            .wait()?;
-
-        // take local ownership of stuff to be moved into task
-        let sub_recv = sub.handler().clone();
-        let quer_recv = queryable.handler().clone();
-        let pub_key_expr = key_expr.into_owned();
-        let history = conf.history;
-
-        let (stoptx, stoprx) = bounded::<bool>(1);
-        task::spawn(async move {
-            async fn process_queue(
-                queue: &VecDeque<Sample>,
-                query: &Query,
-                start: Option<u32>,
-                end: Option<u32>,
-                max: Option<u32>,
-                qos: &QoS,
-            ) {
-                if let Some(max) = max {
-                    let mut samples = VecDeque::new();
-                    for sample in queue {
-                        if sample_in_range(sample, start, end) {
-                            if let (Some(Ok(time_range)), Some(timestamp)) =
-                                (query.parameters().time_range(), sample.timestamp())
-                            {
-                                if !time_range.contains(timestamp.get_time().to_system_time()) {
-                                    continue;
-                                }
-                            }
-                            samples.push_front(sample);
-                            samples.truncate(max as usize);
-                        }
-                    }
-                    for sample in samples.drain(..).rev() {
-                        if let Err(e) = query
-                            .reply_sample(
-                                SampleBuilder::from(sample.clone())
-                                    .congestion_control(qos.congestion_control)
-                                    .priority(qos.priority)
-                                    .express(qos.is_express)
-                                    .into(),
-                            )
-                            .await
-                        {
-                            tracing::warn!("Error replying to query: {}", e);
-                        }
-                    }
-                } else {
-                    for sample in queue {
-                        if sample_in_range(sample, start, end) {
-                            if let (Some(Ok(time_range)), Some(timestamp)) =
-                                (query.parameters().time_range(), sample.timestamp())
-                            {
-                                if !time_range.contains(timestamp.get_time().to_system_time()) {
-                                    continue;
-                                }
-                            }
-                            if let Err(e) = query
-                                .reply_sample(
-                                    SampleBuilder::from(sample.clone())
-                                        .congestion_control(qos.congestion_control)
-                                        .priority(qos.priority)
-                                        .express(qos.is_express)
-                                        .into(),
-                                )
-                                .await
-                            {
-                                tracing::warn!("Error replying to query: {}", e);
-                            }
-                        }
-                    }
-                }
-            }
-
-            let mut cache: HashMap<OwnedKeyExpr, VecDeque<Sample>> =
-                HashMap::with_capacity(history.resources_limit.unwrap_or(32));
-            let limit = history.resources_limit.unwrap_or(usize::MAX);
-
-            loop {
-                select!(
-                    // on publication received by the local subscriber, store it
-                    sample = sub_recv.recv_async() => {
-                        if let Ok(sample) = sample {
-                            let queryable_key_expr: KeyExpr<'_> = if let Some(prefix) = &queryable_prefix {
-                                prefix.join(&sample.key_expr()).unwrap().into()
-                            } else {
-                                sample.key_expr().clone()
-                            };
-
-                            if let Some(queue) = cache.get_mut(queryable_key_expr.as_keyexpr()) {
-                                if queue.len() >= history.sample_depth {
-                                    queue.pop_front();
-                                }
-                                queue.push_back(sample);
-                            } else if cache.len() >= limit {
-                                tracing::error!("AdvancedCache on {}: resource_limit exceeded - can't cache publication for a new resource",
-                                    pub_key_expr);
-                            } else {
-                                let mut queue: VecDeque<Sample> = VecDeque::new();
-                                queue.push_back(sample);
-                                cache.insert(queryable_key_expr.into(), queue);
-                            }
-                        }
-                    },
-
-                    // on query, reply with cache content
-                    query = quer_recv.recv_async() => {
-                        if let Ok(query) = query {
-                            let (start, end) = query.parameters().get("_sn").map(decode_range).unwrap_or((None, None));
-                            let max = query.parameters().get("_max").and_then(|s| s.parse::<u32>().ok());
-                            if !query.selector().key_expr().as_str().contains('*') {
-                                if let Some(queue) = cache.get(query.selector().key_expr().as_keyexpr()) {
-                                    process_queue(queue, &query, start, end, max, &history.replies_qos).await;
-                                }
-                            } else {
-                                for (key_expr, queue) in cache.iter() {
-                                    if query.selector().key_expr().intersects(key_expr.borrow()) {
-                                        process_queue(queue, &query, start, end, max, &history.replies_qos).await;
-                                    }
-                                }
-                            }
-                        }
-                    },
-
-                    // When stoptx is dropped, stop the task
-                    _ = stoprx.recv_async().fuse() => {
-                        return
-                    }
-                );
-            }
-        });
+            .callback({
+                let cache = cache.clone();
+                move |query| {
+                    let (start, end) = query
+                        .parameters()
+                        .get("_sn")
+                        .map(decode_range)
+                        .unwrap_or((None, None));
+                    let max = query
+                        .parameters()
+                        .get("_max")
+                        .and_then(|s| s.parse::<u32>().ok());
+                    if let Ok(queue) = cache.read() {
+                        if let Some(max) = max {
+                            let mut samples = VecDeque::new();
+                            for sample in queue.iter() {
+                                if sample_in_range(sample, start, end) {
+                                    if let (Some(Ok(time_range)), Some(timestamp)) =
+                                        (query.parameters().time_range(), sample.timestamp())
+                                    {
+                                        if !time_range
+                                            .contains(timestamp.get_time().to_system_time())
+                                        {
+                                            continue;
+                                        }
+                                    }
+                                    samples.push_front(sample);
+                                    samples.truncate(max as usize);
+                                }
+                            }
+                            for sample in samples.drain(..).rev() {
+                                if let Err(e) = query
+                                    .reply_sample(
+                                        SampleBuilder::from(sample.clone())
+                                            .congestion_control(
+                                                conf.history.replies_qos.congestion_control,
+                                            )
+                                            .priority(conf.history.replies_qos.priority)
+                                            .express(conf.history.replies_qos.is_express)
+                                            .into(),
+                                    )
+                                    .wait()
+                                {
+                                    tracing::warn!("Error replying to query: {}", e);
+                                }
+                            }
+                        } else {
+                            for sample in queue.iter() {
+                                if sample_in_range(sample, start, end) {
+                                    if let (Some(Ok(time_range)), Some(timestamp)) =
+                                        (query.parameters().time_range(), sample.timestamp())
+                                    {
+                                        if !time_range
+                                            .contains(timestamp.get_time().to_system_time())
+                                        {
+                                            continue;
+                                        }
+                                    }
+                                    if let Err(e) = query
+                                        .reply_sample(
+                                            SampleBuilder::from(sample.clone())
+                                                .congestion_control(
+                                                    conf.history.replies_qos.congestion_control,
+                                                )
+                                                .priority(conf.history.replies_qos.priority)
+                                                .express(conf.history.replies_qos.is_express)
+                                                .into(),
+                                        )
+                                        .wait()
+                                    {
+                                        tracing::warn!("Error replying to query: {}", e);
+                                    }
+                                }
+                            }
+                        }
+                    } else {
+                        tracing::error!("Unable to take AdvancedPublisher cache read lock");
+                    }
+                }
+            })
+            .wait()?;
 
         let token = if conf.liveliness {
             Some(
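The new callback keeps the old task's wire contract: the `_sn` parameter is decoded by `decode_range` (presumably defined elsewhere in this file; not shown in this diff) into optional start/end sequence numbers, `_max` caps the number of replies, and the standard `_time` parameter is enforced through `ZenohParameters::time_range()`. A client-side sketch of querying such a cache with a reply cap; the key expression is illustrative, and a `session: zenoh::Session` inside a function returning `ZResult<()>` is assumed:

// Ask the cache queryable for at most 5 matching cached samples.
let replies = session.get("adv_prefix/**/@/demo/example?_max=5").wait()?;
while let Ok(reply) = replies.recv() {
    if let Ok(sample) = reply.result() {
        println!("cached: {} -> {:?}", sample.key_expr(), sample.payload());
    }
}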
@@ -394,10 +309,21 @@ impl AdvancedCache {
         };
 
         Ok(AdvancedCache {
-            _sub: sub,
+            cache,
+            max_samples: conf.history.max_samples,
             _queryable: queryable,
             _token: token,
-            _stoptx: stoptx,
         })
     }
+
+    pub fn cache_sample(&self, sample: Sample) {
+        if let Ok(mut queue) = self.cache.write() {
+            if queue.len() >= self.max_samples {
+                queue.pop_front();
+            }
+            queue.push_back(sample);
+        } else {
+            tracing::error!("Unable to take AdvancedPublisher cache write lock");
+        }
+    }
 }
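Nothing in this file feeds the cache anymore; per the commit title, the publisher side now calls `cache_sample` directly for each sample it sends. That wiring lives in the other changed files of this commit, which are not shown on this page; a plausible shape of the call site, with illustrative names:

// Hypothetical publisher-side flow (names illustrative; the real change is in
// the AdvancedPublisher code among this commit's other files):
fn put_and_cache(cache: &AdvancedCache, sample: Sample) {
    // Store the sample before sending it on the wire, so late-joining or
    // recovering subscribers can query it back from the cache.
    cache.cache_sample(sample.clone());
    // ... then publish `sample` through the underlying zenoh publisher ...
}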