Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proxy: stub the VPC config cache and invalidation code #10073

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 125 additions & 1 deletion proxy/src/cache/project_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ use super::{Cache, Cached};
use crate::auth::IpPattern;
use crate::config::ProjectInfoCacheOptions;
use crate::control_plane::AuthSecret;
use crate::intern::{EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::types::{EndpointId, RoleName};

#[async_trait]
pub(crate) trait ProjectInfoCache {
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>);
fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt);
fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt);
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
async fn decrement_active_listeners(&self);
async fn increment_active_listeners(&self);
Expand Down Expand Up @@ -51,6 +54,8 @@ impl<T> From<T> for Entry<T> {
struct EndpointInfo {
secret: std::collections::HashMap<RoleNameInt, Entry<Option<AuthSecret>>>,
allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
block_public_or_vpc_access: Option<Entry<(bool, bool)>>,
allowed_vpc_endpoint_ids: Option<Entry<Arc<Vec<String>>>>,
}

impl EndpointInfo {
Expand Down Expand Up @@ -92,6 +97,51 @@ impl EndpointInfo {
}
None
}

pub(crate) fn get_allowed_vpc_endpoint_ids(
&self,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(Arc<Vec<String>>, bool)> {
if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids {
if valid_since < allowed_vpc_endpoint_ids.created_at {
return Some((
allowed_vpc_endpoint_ids.value.clone(),
Self::check_ignore_cache(
ignore_cache_since,
allowed_vpc_endpoint_ids.created_at,
),
));
}
}
None
}

pub(crate) fn get_block_public_or_vpc_access(
&self,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<((bool, bool), bool)> {
if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access {
if valid_since < block_public_or_vpc_access.created_at {
return Some((
block_public_or_vpc_access.value.clone(),
Self::check_ignore_cache(
ignore_cache_since,
block_public_or_vpc_access.created_at,
),
));
}
}
None
}

pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) {
self.block_public_or_vpc_access = None;
}
pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) {
self.allowed_vpc_endpoint_ids = None;
}
pub(crate) fn invalidate_allowed_ips(&mut self) {
self.allowed_ips = None;
}
Expand All @@ -111,6 +161,8 @@ pub struct ProjectInfoCacheImpl {
cache: DashMap<EndpointIdInt, EndpointInfo>,

project2ep: DashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
// FIXME(stefan): we need a way to GC the account2ep map.
account2ep: DashMap<AccountIdInt, HashSet<EndpointIdInt>>,
config: ProjectInfoCacheOptions,

start_time: Instant,
Expand All @@ -120,6 +172,59 @@ pub struct ProjectInfoCacheImpl {

#[async_trait]
impl ProjectInfoCache for ProjectInfoCacheImpl {
fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>) {
info!(
"invalidating allowed vpc endpoint ids for projects `{}`",
project_ids.join(", ")
);
for project_id in project_ids {
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
}
}
}
}

fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) {
info!(
"invalidating allowed vpc endpoint ids for org `{}`",
account_id
);
let endpoints = self
.account2ep
.get(&account_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
}
}
}

fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) {
info!(
"invalidating block public or vpc access for project `{}`",
project_id
);
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_block_public_or_vpc_access();
}
}
}

fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) {
info!("invalidating allowed ips for project `{}`", project_id);
let endpoints = self
Expand Down Expand Up @@ -178,6 +283,7 @@ impl ProjectInfoCacheImpl {
Self {
cache: DashMap::new(),
project2ep: DashMap::new(),
account2ep: DashMap::new(),
config,
ttl_disabled_since_us: AtomicU64::new(u64::MAX),
start_time: Instant::now(),
Expand Down Expand Up @@ -226,6 +332,7 @@ impl ProjectInfoCacheImpl {
}
Some(Cached::new_uncached(value))
}

pub(crate) fn insert_role_secret(
&self,
project_id: ProjectIdInt,
Expand Down Expand Up @@ -256,6 +363,16 @@ impl ProjectInfoCacheImpl {
self.insert_project2endpoint(project_id, endpoint_id);
self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into());
}
pub(crate) fn insert_vpc_allowed_endpoint_ids(&self, account_id: AccountIdInt, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, vpc_allowed_endpoint_ids: HashSet<EndpointIdInt>) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
self.insert_account2endpoint(account_id, endpoint_id);
self.insert_project2endpoint(project_id, endpoint_id);
self.cache.entry(endpoint_id).or_default().vpc_allowed_endpoint_ids = Some(vpc_allowed_endpoint_ids);
}
}
fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
endpoints.insert(endpoint_id);
Expand All @@ -264,6 +381,13 @@ impl ProjectInfoCacheImpl {
.insert(project_id, HashSet::from([endpoint_id]));
}
}
fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
endpoints.insert(endpoint_id);
} else {
self.account2ep.insert(account_id, HashSet::from([endpoint_id]));
}
}
fn get_cache_times(&self) -> (Instant, Option<Instant>) {
let mut valid_since = Instant::now() - self.config.ttl;
// Only ignore cache if ttl is disabled.
Expand Down
2 changes: 2 additions & 0 deletions proxy/src/control_plane/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ pub(crate) struct GetEndpointAccessControl {
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
pub(crate) project_id: Option<ProjectIdInt>,
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<EndpointIdInt>>,
pub(crate) block_public_connections: Option<bool>,
pub(crate) block_vpc_connections: Option<bool>,
}

// Manually implement debug to omit sensitive info.
Expand Down
22 changes: 21 additions & 1 deletion proxy/src/intern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::sync::OnceLock;
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
use rustc_hash::FxHasher;

use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
use crate::types::{AccountId, BranchId, EndpointId, ProjectId, RoleName};

pub trait InternId: Sized + 'static {
fn get_interner() -> &'static StringInterner<Self>;
Expand Down Expand Up @@ -206,6 +206,26 @@ impl From<ProjectId> for ProjectIdInt {
}
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct AccountIdTag;
impl InternId for AccountIdTag {
fn get_interner() -> &'static StringInterner<Self> {
static ROLE_NAMES: OnceLock<StringInterner<AccountIdTag>> = OnceLock::new();
ROLE_NAMES.get_or_init(Default::default)
}
}
pub type AccountIdInt = InternedString<AccountIdTag>;
impl From<&AccountId> for AccountIdInt {
fn from(value: &AccountId) -> Self {
AccountIdTag::get_interner().get_or_intern(value)
}
}
impl From<AccountId> for AccountIdInt {
fn from(value: AccountId) -> Self {
AccountIdTag::get_interner().get_or_intern(&value)
}
}

#[cfg(test)]
mod tests {
use std::sync::OnceLock;
Expand Down
3 changes: 3 additions & 0 deletions proxy/src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,9 @@ pub enum RedisEventsCount {
CancelSession,
PasswordUpdate,
AllowedIpsUpdate,
AllowedVpcEndpointIdsUpdateForProjects,
AllowedVpcEndpointIdsUpdateForAllProjectsInOrg,
BlockPublicOrVpcAccessUpdate,
}

pub struct ThreadPoolWorkers(usize);
Expand Down
81 changes: 79 additions & 2 deletions proxy/src/redis/notifications.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::cache::project_info::ProjectInfoCache;
use crate::cancellation::{CancelMap, CancellationHandler};
use crate::intern::{ProjectIdInt, RoleNameInt};
use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
use tracing::Instrument;

Expand Down Expand Up @@ -39,6 +39,27 @@ pub(crate) enum Notification {
AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate,
},
#[serde(
rename = "/allowed_vpc_endpoint_ids_updated_for_projects",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointIdsUpdateForProjects {
allowed_vpc_endpoint_ids_update_for_projects: AllowedVpcEndpointIdsUpdateForProjects,
},
#[serde(
rename = "/allowed_vpc_endpoint_ids_updated_for_org",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointIdsUpdateForAllProjectsInOrg {
allowed_vpc_endpoint_ids_update_for_org: AllowedVpcEndpointIdsUpdateForAllProjectsInOrg,
},
#[serde(
rename = "/block_public_or_vpc_access_updated",
deserialize_with = "deserialize_json_string"
)]
BlockPublicOrVpcAccessUpdate {
block_public_or_vpc_access_update: BlockPublicOrVpcAccessUpdate,
},
#[serde(
rename = "/password_updated",
deserialize_with = "deserialize_json_string"
Expand All @@ -51,6 +72,22 @@ pub(crate) enum Notification {
pub(crate) struct AllowedIpsUpdate {
project_id: ProjectIdInt,
}

#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointIdsUpdateForProjects {
project_ids: Vec<ProjectIdInt>,
}

#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointIdsUpdateForAllProjectsInOrg {
account_id: AccountIdInt,
}

#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct BlockPublicOrVpcAccessUpdate {
project_id: ProjectIdInt,
}

#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
project_id: ProjectIdInt,
Expand Down Expand Up @@ -164,7 +201,11 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
}
}
}
Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => {
Notification::AllowedIpsUpdate { .. }
| Notification::PasswordUpdate { .. }
| Notification::AllowedVpcEndpointIdsUpdateForProjects { .. }
| Notification::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg { .. }
| Notification::BlockPublicOrVpcAccessUpdate { .. } => {
invalidate_cache(self.cache.clone(), msg.clone());
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
Metrics::get()
Expand All @@ -176,6 +217,27 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
.proxy
.redis_events_count
.inc(RedisEventsCount::PasswordUpdate);
} else if matches!(
msg,
Notification::AllowedVpcEndpointIdsUpdateForProjects { .. }
) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects);
} else if matches!(
msg,
Notification::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg { .. }
) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg);
} else if matches!(msg, Notification::BlockPublicOrVpcAccessUpdate { .. }) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::BlockPublicOrVpcAccessUpdate);
}
// It might happen that the invalid entry is on the way to be cached.
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
Expand All @@ -197,6 +259,21 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
Notification::AllowedIpsUpdate { allowed_ips_update } => {
cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
}
Notification::AllowedVpcEndpointIdsUpdateForProjects {
allowed_vpc_endpoint_ids_update_for_projects,
} => cache.invalidate_allowed_vpc_endpoint_ids_for_projects(
allowed_vpc_endpoint_ids_update_for_projects.project_ids,
),
Notification::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg {
allowed_vpc_endpoint_ids_update_for_org,
} => cache.invalidate_allowed_vpc_endpoint_ids_for_org(
allowed_vpc_endpoint_ids_update_for_org.account_id,
),
Notification::BlockPublicOrVpcAccessUpdate {
block_public_or_vpc_access_update,
} => cache.invalidate_block_public_or_vpc_access_for_project(
block_public_or_vpc_access_update.project_id,
),
Notification::PasswordUpdate { password_update } => cache
.invalidate_role_secret_for_project(
password_update.project_id,
Expand Down
2 changes: 2 additions & 0 deletions proxy/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ smol_str_wrapper!(EndpointId);
smol_str_wrapper!(BranchId);
// 90% of project strings are 23 characters or less.
smol_str_wrapper!(ProjectId);
// 90% of account strings are 23 characters or less.
smol_str_wrapper!(AccountId);

// will usually equal endpoint ID
smol_str_wrapper!(EndpointCacheKey);
Expand Down
Loading