-
Notifications
You must be signed in to change notification settings - Fork 171
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7537985
commit f850eaa
Showing
5 changed files
with
283 additions
and
193 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,25 +12,20 @@ | |
// ZettaScale Zenoh Team, <[email protected]> | ||
// | ||
use std::{ | ||
borrow::Borrow, | ||
collections::{HashMap, VecDeque}, | ||
collections::VecDeque, | ||
future::{IntoFuture, Ready}, | ||
sync::{Arc, RwLock}, | ||
}; | ||
|
||
use flume::{bounded, Sender}; | ||
use futures::{select, FutureExt}; | ||
use tokio::task; | ||
use zenoh::{ | ||
handlers::FifoChannelHandler, | ||
internal::{bail, traits::QoSBuilderTrait}, | ||
key_expr::{ | ||
format::{ke, kedefine}, | ||
keyexpr, KeyExpr, OwnedKeyExpr, | ||
keyexpr, KeyExpr, | ||
}, | ||
liveliness::LivelinessToken, | ||
pubsub::Subscriber, | ||
qos::{CongestionControl, Priority}, | ||
query::{Query, Queryable, ZenohParameters}, | ||
query::{Queryable, ZenohParameters}, | ||
sample::{Locality, Sample, SampleBuilder}, | ||
Resolvable, Result as ZResult, Session, Wait, KE_ADV_PREFIX, KE_AT, KE_STARSTAR, | ||
}; | ||
|
@@ -82,16 +77,14 @@ impl QoSBuilderTrait for QoS { | |
#[derive(Debug, Clone)] | ||
/// Configure an [`AdvancedPublisher`](crate::AdvancedPublisher) cache. | ||
pub struct CacheConfig { | ||
sample_depth: usize, | ||
resources_limit: Option<usize>, | ||
max_samples: usize, | ||
replies_qos: QoS, | ||
} | ||
|
||
impl Default for CacheConfig { | ||
fn default() -> Self { | ||
Self { | ||
sample_depth: 1, | ||
resources_limit: None, | ||
max_samples: 1, | ||
replies_qos: QoS::default(), | ||
} | ||
} | ||
|
@@ -100,15 +93,7 @@ impl Default for CacheConfig { | |
impl CacheConfig { | ||
/// Specify how many samples to keep for each resource. | ||
pub fn max_samples(mut self, depth: usize) -> Self { | ||
self.sample_depth = depth; | ||
self | ||
} | ||
|
||
// TODO pub fn max_age(mut self, depth: Duration) -> Self | ||
|
||
/// Specify the maximum total number of samples to keep. | ||
pub fn max_total_samples(mut self, limit: usize) -> Self { | ||
self.resources_limit = Some(limit); | ||
self.max_samples = depth; | ||
self | ||
} | ||
|
||
|
@@ -124,7 +109,6 @@ pub struct AdvancedCacheBuilder<'a, 'b, 'c> { | |
session: &'a Session, | ||
pub_key_expr: ZResult<KeyExpr<'b>>, | ||
queryable_prefix: Option<ZResult<KeyExpr<'c>>>, | ||
subscriber_origin: Locality, | ||
queryable_origin: Locality, | ||
history: CacheConfig, | ||
liveliness: bool, | ||
|
@@ -139,7 +123,6 @@ impl<'a, 'b, 'c> AdvancedCacheBuilder<'a, 'b, 'c> { | |
session, | ||
pub_key_expr, | ||
queryable_prefix: Some(Ok((KE_ADV_PREFIX / KE_STARSTAR / KE_AT).into())), | ||
subscriber_origin: Locality::default(), | ||
queryable_origin: Locality::default(), | ||
history: CacheConfig::default(), | ||
liveliness: false, | ||
|
@@ -156,14 +139,6 @@ impl<'a, 'b, 'c> AdvancedCacheBuilder<'a, 'b, 'c> { | |
self | ||
} | ||
|
||
/// Restrict the matching publications that will be cached by this [`AdvancedCache`] | ||
/// to the ones that have the given [`Locality`](zenoh::sample::Locality). | ||
#[inline] | ||
pub fn subscriber_allowed_origin(mut self, origin: Locality) -> Self { | ||
self.subscriber_origin = origin; | ||
self | ||
} | ||
|
||
/// Change the history size for each resource. | ||
pub fn history(mut self, history: CacheConfig) -> Self { | ||
self.history = history; | ||
|
@@ -213,174 +188,114 @@ fn sample_in_range(sample: &Sample, start: Option<u32>, end: Option<u32>) -> boo | |
} | ||
|
||
pub struct AdvancedCache { | ||
_sub: Subscriber<FifoChannelHandler<Sample>>, | ||
_queryable: Queryable<FifoChannelHandler<Query>>, | ||
cache: Arc<RwLock<VecDeque<Sample>>>, | ||
max_samples: usize, | ||
_queryable: Queryable<()>, | ||
_token: Option<LivelinessToken>, | ||
_stoptx: Sender<bool>, | ||
} | ||
|
||
impl AdvancedCache { | ||
fn new(conf: AdvancedCacheBuilder<'_, '_, '_>) -> ZResult<AdvancedCache> { | ||
let key_expr = conf.pub_key_expr?; | ||
let key_expr = conf.pub_key_expr?.into_owned(); | ||
// the queryable_prefix (optional), and the key_expr for AdvancedCache's queryable ("[<queryable_prefix>]/<pub_key_expr>") | ||
let (queryable_prefix, queryable_key_expr): (Option<OwnedKeyExpr>, KeyExpr) = | ||
match conf.queryable_prefix { | ||
None => (None, key_expr.clone()), | ||
Some(Ok(ke)) => { | ||
let queryable_key_expr = (&ke) / &key_expr; | ||
(Some(ke.into()), queryable_key_expr) | ||
} | ||
Some(Err(e)) => bail!("Invalid key expression for queryable_prefix: {}", e), | ||
}; | ||
let queryable_key_expr = match conf.queryable_prefix { | ||
None => key_expr.clone(), | ||
Some(Ok(ke)) => (&ke) / &key_expr, | ||
Some(Err(e)) => bail!("Invalid key expression for queryable_prefix: {}", e), | ||
}; | ||
tracing::debug!( | ||
"Create AdvancedCache on {} with history={:?}", | ||
"Create AdvancedCache on {} with max_samples={:?}", | ||
&key_expr, | ||
conf.history, | ||
); | ||
|
||
// declare the local subscriber that will store the local publications | ||
let sub = conf | ||
.session | ||
.declare_subscriber(&key_expr) | ||
.allowed_origin(conf.subscriber_origin) | ||
.wait()?; | ||
let cache = Arc::new(RwLock::new(VecDeque::new())); | ||
|
||
// declare the queryable that will answer to queries on cache | ||
let queryable = conf | ||
.session | ||
.declare_queryable(&queryable_key_expr) | ||
.allowed_origin(conf.queryable_origin) | ||
.wait()?; | ||
|
||
// take local ownership of stuff to be moved into task | ||
let sub_recv = sub.handler().clone(); | ||
let quer_recv = queryable.handler().clone(); | ||
let pub_key_expr = key_expr.into_owned(); | ||
let history = conf.history; | ||
|
||
let (stoptx, stoprx) = bounded::<bool>(1); | ||
task::spawn(async move { | ||
async fn process_queue( | ||
queue: &VecDeque<Sample>, | ||
query: &Query, | ||
start: Option<u32>, | ||
end: Option<u32>, | ||
max: Option<u32>, | ||
qos: &QoS, | ||
) { | ||
if let Some(max) = max { | ||
let mut samples = VecDeque::new(); | ||
for sample in queue { | ||
if sample_in_range(sample, start, end) { | ||
if let (Some(Ok(time_range)), Some(timestamp)) = | ||
(query.parameters().time_range(), sample.timestamp()) | ||
{ | ||
if !time_range.contains(timestamp.get_time().to_system_time()) { | ||
continue; | ||
.callback({ | ||
let cache = cache.clone(); | ||
move |query| { | ||
let (start, end) = query | ||
.parameters() | ||
.get("_sn") | ||
.map(decode_range) | ||
.unwrap_or((None, None)); | ||
let max = query | ||
.parameters() | ||
.get("_max") | ||
.and_then(|s| s.parse::<u32>().ok()); | ||
if let Ok(queue) = cache.read() { | ||
if let Some(max) = max { | ||
let mut samples = VecDeque::new(); | ||
for sample in queue.iter() { | ||
if sample_in_range(sample, start, end) { | ||
if let (Some(Ok(time_range)), Some(timestamp)) = | ||
(query.parameters().time_range(), sample.timestamp()) | ||
{ | ||
if !time_range | ||
.contains(timestamp.get_time().to_system_time()) | ||
{ | ||
continue; | ||
} | ||
} | ||
samples.push_front(sample); | ||
samples.truncate(max as usize); | ||
} | ||
} | ||
samples.push_front(sample); | ||
samples.truncate(max as usize); | ||
} | ||
} | ||
for sample in samples.drain(..).rev() { | ||
if let Err(e) = query | ||
.reply_sample( | ||
SampleBuilder::from(sample.clone()) | ||
.congestion_control(qos.congestion_control) | ||
.priority(qos.priority) | ||
.express(qos.is_express) | ||
.into(), | ||
) | ||
.await | ||
{ | ||
tracing::warn!("Error replying to query: {}", e); | ||
} | ||
} | ||
} else { | ||
for sample in queue { | ||
if sample_in_range(sample, start, end) { | ||
if let (Some(Ok(time_range)), Some(timestamp)) = | ||
(query.parameters().time_range(), sample.timestamp()) | ||
{ | ||
if !time_range.contains(timestamp.get_time().to_system_time()) { | ||
continue; | ||
for sample in samples.drain(..).rev() { | ||
if let Err(e) = query | ||
.reply_sample( | ||
SampleBuilder::from(sample.clone()) | ||
.congestion_control( | ||
conf.history.replies_qos.congestion_control, | ||
) | ||
.priority(conf.history.replies_qos.priority) | ||
.express(conf.history.replies_qos.is_express) | ||
.into(), | ||
) | ||
.wait() | ||
{ | ||
tracing::warn!("Error replying to query: {}", e); | ||
} | ||
} | ||
if let Err(e) = query | ||
.reply_sample( | ||
SampleBuilder::from(sample.clone()) | ||
.congestion_control(qos.congestion_control) | ||
.priority(qos.priority) | ||
.express(qos.is_express) | ||
.into(), | ||
) | ||
.await | ||
{ | ||
tracing::warn!("Error replying to query: {}", e); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
let mut cache: HashMap<OwnedKeyExpr, VecDeque<Sample>> = | ||
HashMap::with_capacity(history.resources_limit.unwrap_or(32)); | ||
let limit = history.resources_limit.unwrap_or(usize::MAX); | ||
|
||
loop { | ||
select!( | ||
// on publication received by the local subscriber, store it | ||
sample = sub_recv.recv_async() => { | ||
if let Ok(sample) = sample { | ||
let queryable_key_expr: KeyExpr<'_> = if let Some(prefix) = &queryable_prefix { | ||
prefix.join(&sample.key_expr()).unwrap().into() | ||
} else { | ||
sample.key_expr().clone() | ||
}; | ||
|
||
if let Some(queue) = cache.get_mut(queryable_key_expr.as_keyexpr()) { | ||
if queue.len() >= history.sample_depth { | ||
queue.pop_front(); | ||
} | ||
queue.push_back(sample); | ||
} else if cache.len() >= limit { | ||
tracing::error!("AdvancedCache on {}: resource_limit exceeded - can't cache publication for a new resource", | ||
pub_key_expr); | ||
} else { | ||
let mut queue: VecDeque<Sample> = VecDeque::new(); | ||
queue.push_back(sample); | ||
cache.insert(queryable_key_expr.into(), queue); | ||
} | ||
} | ||
}, | ||
|
||
// on query, reply with cache content | ||
query = quer_recv.recv_async() => { | ||
if let Ok(query) = query { | ||
let (start, end) = query.parameters().get("_sn").map(decode_range).unwrap_or((None, None)); | ||
let max = query.parameters().get("_max").and_then(|s| s.parse::<u32>().ok()); | ||
if !query.selector().key_expr().as_str().contains('*') { | ||
if let Some(queue) = cache.get(query.selector().key_expr().as_keyexpr()) { | ||
process_queue(queue, &query, start, end, max, &history.replies_qos).await; | ||
} | ||
} else { | ||
for (key_expr, queue) in cache.iter() { | ||
if query.selector().key_expr().intersects(key_expr.borrow()) { | ||
process_queue(queue, &query, start, end, max, &history.replies_qos).await; | ||
} else { | ||
for sample in queue.iter() { | ||
if sample_in_range(sample, start, end) { | ||
if let (Some(Ok(time_range)), Some(timestamp)) = | ||
(query.parameters().time_range(), sample.timestamp()) | ||
{ | ||
if !time_range | ||
.contains(timestamp.get_time().to_system_time()) | ||
{ | ||
continue; | ||
} | ||
} | ||
if let Err(e) = query | ||
.reply_sample( | ||
SampleBuilder::from(sample.clone()) | ||
.congestion_control( | ||
conf.history.replies_qos.congestion_control, | ||
) | ||
.priority(conf.history.replies_qos.priority) | ||
.express(conf.history.replies_qos.is_express) | ||
.into(), | ||
) | ||
.wait() | ||
{ | ||
tracing::warn!("Error replying to query: {}", e); | ||
} | ||
} | ||
} | ||
} | ||
}, | ||
|
||
// When stoptx is dropped, stop the task | ||
_ = stoprx.recv_async().fuse() => { | ||
return | ||
} else { | ||
tracing::error!("Unable to take AdvancedPublisher cache read lock"); | ||
} | ||
); | ||
} | ||
}); | ||
} | ||
}) | ||
.wait()?; | ||
|
||
let token = if conf.liveliness { | ||
Some( | ||
|
@@ -394,10 +309,21 @@ impl AdvancedCache { | |
}; | ||
|
||
Ok(AdvancedCache { | ||
_sub: sub, | ||
cache, | ||
max_samples: conf.history.max_samples, | ||
_queryable: queryable, | ||
_token: token, | ||
_stoptx: stoptx, | ||
}) | ||
} | ||
|
||
pub fn cache_sample(&self, sample: Sample) { | ||
if let Ok(mut queue) = self.cache.write() { | ||
if queue.len() >= self.max_samples { | ||
queue.pop_front(); | ||
} | ||
queue.push_back(sample); | ||
} else { | ||
tracing::error!("Unable to take AdvancedPublisher cache write lock"); | ||
} | ||
} | ||
} |
Oops, something went wrong.