Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support manual cache eviction #1

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
39 changes: 38 additions & 1 deletion aws_secretsmanager_agent/src/cache_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,45 @@ impl CacheManager {
}
}
}
}

/// Evict a secret from the cache.
///
/// # Arguments
///
/// * `secret_id` - The name of the secret to evict.
/// * `version` - The version of the secret to evict.
/// * `label` - The label of the secret to evict.
///
/// # Returns
///
/// * `Ok(String)` - Success message.
/// * `Err((u16, String))` - The error code and message.
///
/// # Errors
///
/// * `HttpError` - The error returned from the SDK.
///
/// # Example
///
/// ```
/// let cache_manager = CacheManager::new().await.unwrap();
/// let message = cache_manager.evict("my-secret", None, None).unwrap();
/// ```
pub async fn evict(
&self,
secret_id: &str,
version: Option<&str>,
label: Option<&str>
) -> Result<String, HttpError> {
match self.0.remove_secret_value(secret_id, version, label).await {
Ok(_) => Ok("Secret successfully evicted".to_string()),
Err(e) => {
error!("Failed to evict secret {}: {:?}", secret_id, e);
Err(HttpError(500, err_response("EvictionError", "Failed to evict secret")))
}
}
}
}
/// Private helper to format in internal service error response.
#[doc(hidden)]
fn int_err() -> HttpError {
Expand Down
4 changes: 2 additions & 2 deletions aws_secretsmanager_agent/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -842,9 +842,9 @@ mod tests {

// Verify requests using the wrong verbs fail with 405.
#[tokio::test]
async fn get_only() {
async fn get_and_post_only() {
for verb in [
"POST", "PUT", "PATCH", "DELETE", "HEAD", "CONNECT", "OPTIONS", "TRACE",
"PUT", "PATCH", "DELETE", "HEAD", "CONNECT", "OPTIONS", "TRACE",
] {
let (status, _) =
run_requests_with_verb(vec![(verb, "/secretsmanager/get?secretId=MyTest")])
Expand Down
60 changes: 34 additions & 26 deletions aws_secretsmanager_agent/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ pub struct Server {

/// Handle incoming HTTP requests.
///
/// Implements the HTTP handler. Each incomming request is handled in its own
/// Implements the HTTP handler. Each incoming request is handled in its own
/// thread.
impl Server {
/// Create a server instance.
///
/// # Arguments
///
/// * `listener` - The TcpListener to use to accept incomming requests.
/// * `listener` - The TcpListener to use to accept incoming requests.
/// * `cfg` - The config object to use for options such header names.
///
/// # Returns
Expand Down Expand Up @@ -87,11 +87,11 @@ impl Server {
Ok(())
}

/// Private helper to process the incomming request body and format a response.
/// Private helper to process the incoming request body and format a response.
///
/// # Arguments
///
/// * `req` - The incomming HTTP request.
/// * `req` - The incoming HTTP request.
/// * `count` - The number of concurrent requets being handled.
///
/// # Returns
Expand All @@ -118,11 +118,11 @@ impl Server {
}
}

/// Parse an incomming request and provide the response data.
/// Parse an incoming request and provide the response data.
///
/// # Arguments
///
/// * `req` - The incomming HTTP request.
/// * `req` - The incoming HTTP request.
/// * `count` - The number of concurrent requets being handled.
///
/// # Returns
Expand All @@ -137,13 +137,11 @@ impl Server {
) -> Result<String, HttpError> {
self.validate_max_conn(req, count)?; // Verify connection limits are not exceeded
self.validate_token(req)?; // Check for a valid SSRF token
self.validate_method(req)?; // Allow only GET requests
self.validate_method(req)?; // Allow only GET and POST requests

match req.uri().path() {
"/ping" => Ok("healthy".into()), // Standard health check

// Lambda extension style query
"/secretsmanager/get" => {
match (req.method(), req.uri().path()) {
(&Method::GET, "/ping") => Ok("healthy".into()), // Standard health check
(&Method::GET, "/secretsmanager/get") => { // Lambda extension style query
let qry = GSVQuery::try_from_query(&req.uri().to_string())?;
Ok(self
.cache_mgr
Expand All @@ -154,9 +152,8 @@ impl Server {
)
.await?)
}

// Path style request
path if path.starts_with(self.path_prefix.as_str()) => {
(&Method::GET, path) if path.starts_with(self.path_prefix.as_str()) => {
let qry = GSVQuery::try_from_path_query(&req.uri().to_string(), &self.path_prefix)?;
Ok(self
.cache_mgr
Expand All @@ -167,17 +164,29 @@ impl Server {
)
.await?)
}
(&Method::POST, "/secretsmanager/evict") => {
let qry = GSVQuery::try_from_query(&req.uri().to_string())?;
Ok(self
.cache_mgr
.evict(
&qry.secret_id,
qry.version_id.as_deref(),
qry.version_stage.as_deref(),
)
.await?)
}
_ => Err(HttpError(404, "Not found".into())),
}
}

/// Verify the incomming request does not exceed the maximum connection limit.

/// Verify the incoming request does not exceed the maximum connection limit.
///
/// The limit is not enforced for ping/health checks.
///
/// # Arguments
///
/// * `req` - The incomming HTTP request.
/// * `req` - The incoming HTTP request.
/// * `count` - The number of concurrent requets being handled.
///
/// # Returns
Expand Down Expand Up @@ -209,7 +218,7 @@ impl Server {
///
/// # Arguments
///
/// * `req` - The incomming HTTP request.
/// * `req` - The incoming HTTP request.
///
/// # Returns
///
Expand Down Expand Up @@ -241,22 +250,21 @@ impl Server {
Err(HttpError(403, "Bad Token".into()))
}

/// Verify the request is using the GET HTTP verb.
/// Verify the request is using the GET or POST HTTP verb.
///
/// # Arguments
///
/// * `req` - The incomming HTTP request.
/// * `req` - The incoming HTTP request.
///
/// # Returns
///
/// * `Ok(())` - If the GET verb/method is use.
/// * `Err((u16, String))` - A 405 error codde and message when GET is not used.
/// * `Ok(())` - If the GET or POST verb/method is use.
/// * `Err((u16, String))` - A 405 error codde and message when GET or POST is not used.
#[doc(hidden)]
fn validate_method(&self, req: &Request<IncomingBody>) -> Result<(), HttpError> {
if *req.method() == Method::GET {
return Ok(());
match *req.method() {
Method::GET | Method::POST => Ok(()),
_ => Err(HttpError(405, "Method not allowed".into())),
}

Err(HttpError(405, "Not allowed".into()))
}
}
}
18 changes: 18 additions & 0 deletions aws_secretsmanager_caching/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,24 @@ impl SecretsManagerCachingClient {

Ok(false)
}

/// Removes a secret in the cache, forcing a refresh on the next retrieval.
///
/// # Arguments
///
/// * `secret_id` - The ARN or name of the secret to remove.
/// * `version_id` - The version id of the secret version to remove.
/// * `version_stage` - The staging label of the version of the secret to remove.
pub async fn remove_secret_value(
&self,
secret_id: &str,
version_id: Option<&str>,
version_stage: Option<&str>,
) -> Result<(), Box<dyn Error>> {
let mut write_lock = self.store.write().await;
write_lock.remove_secret_value(secret_id, version_id, version_stage)?;
Ok(())
}
}

#[cfg(test)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,17 @@ impl<K: Hash + Eq, V> Cache<K, V> {
{
self.entries.get(key)
}
}

/// Removes a key-value pair from the cache.
/// Returns the value if the key was present in the cache, None otherwise.
pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
where
Q: ?Sized + Hash + Eq,
K: Borrow<Q>,
{
self.entries.remove(key)
}
}
#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -113,4 +122,19 @@ mod tests {
let items: Vec<usize> = cache.entries.iter().map(|t| (*t.1)).collect();
assert_eq!(items, [2]);
}

#[test]
fn remove_evicts_entry() {
let mut cache = TestIntCache::new(NonZeroUsize::new(4).unwrap());

cache.insert("test1".to_string(), 1);
cache.insert("test2".to_string(), 2);

assert_eq!(cache.remove("test1"), Some(1));
assert_eq!(cache.len(), 1);
assert_eq!(cache.get("test1"), None);
assert_eq!(cache.get("test2"), Some(&2));

assert_eq!(cache.remove("non_existent"), None);
}
}
43 changes: 43 additions & 0 deletions aws_secretsmanager_caching/src/secret_store/memory_store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,24 @@ impl SecretStore for MemoryStore {

Ok(())
}

fn remove_secret_value(
&mut self,
secret_id: &str,
version_id: Option<&str>,
version_stage: Option<&str>,
) -> Result<(), SecretStoreError> {
let key = Key {
secret_id: secret_id.to_string(),
version_id: version_id.map(String::from),
version_stage: version_stage.map(String::from),
};
if self.gsv_cache.remove(&key).is_none() {
Err(SecretStoreError::ResourceNotFound)
} else {
Ok(())
}
}
}

/// Write the secret value to the store
Expand Down Expand Up @@ -275,4 +293,29 @@ mod tests {
Err(e) => panic!("Unexpected error: {}", e),
}
}

#[test]
fn memory_store_remove_secret_value() {
let mut store = MemoryStore::default();

store_secret(&mut store, None, None, None);

// Verify the secret exists
assert!(store.get_secret_value(NAME, None, None).is_ok());

// Remove the secret
assert!(store.remove_secret_value(NAME, None, None).is_ok());

// Verify the secret no longer exists
match store.get_secret_value(NAME, None, None) {
Err(SecretStoreError::ResourceNotFound) => (),
_ => panic!("Expected ResourceNotFound error"),
}

// Attempt to remove a non-existent secret
match store.remove_secret_value("non_existent", None, None) {
Err(SecretStoreError::ResourceNotFound) => (),
_ => panic!("Expected ResourceNotFound error"),
}
}
}
8 changes: 8 additions & 0 deletions aws_secretsmanager_caching/src/secret_store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ pub trait SecretStore: Debug + Send + Sync {
version_stage: Option<String>,
data: GetSecretValueOutputDef,
) -> Result<(), SecretStoreError>;

/// Remove the secret value from the store
fn remove_secret_value(
&mut self,
secret_id: &str,
version_id: Option<&str>,
version_stage: Option<&str>,
) -> Result<(), SecretStoreError>;
}

/// All possible error types
Expand Down