Skip to content

Commit

Permalink
RDS: Refactor Tagging Endpoints
Browse files Browse the repository at this point in the history
Replace the large/duplicated if statements with polymorphism.

Returning an empty array when a resource cannot be found is not how AWS
handles things, but leaving for now as there are tests that rely on this
(incorrect) behavior.  Will address in a later commit.
  • Loading branch information
bpandola committed Aug 3, 2024
1 parent 64066c8 commit b37593f
Showing 1 changed file with 36 additions and 114 deletions.
150 changes: 36 additions & 114 deletions moto/rds/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,9 +1649,21 @@ def __init__(self, region_name: str, account_id: str):
self.subnet_groups: Dict[str, DBSubnetGroup] = {}
self._db_cluster_options: Optional[List[Dict[str, Any]]] = None
self.db_proxies: Dict[str, DBProxy] = OrderedDict()

def reset(self) -> None:
super().reset()
self.resource_map = {
DBCluster: self.clusters,
DBClusterParameterGroup: self.db_cluster_parameter_groups,
DBClusterSnapshot: self.cluster_snapshots,
DBInstance: self.databases,
DBParameterGroup: self.db_parameter_groups,
DBProxy: self.db_proxies,
DBSecurityGroup: self.security_groups,
DBSnapshot: self.database_snapshots,
DBSubnetGroup: self.subnet_groups,
EventSubscription: self.event_subscriptions,
ExportTask: self.export_tasks,
GlobalCluster: self.global_clusters,
OptionGroup: self.option_groups,
}

@lru_cache()
def db_cluster_options(self, engine) -> List[Dict[str, Any]]: # type: ignore
Expand Down Expand Up @@ -2533,92 +2545,33 @@ def describe_event_subscriptions(
raise SubscriptionNotFoundError(subscription_name)
return self.event_subscriptions.values()

def _find_resource(self, resource_type: str, resource_name: str) -> Any:
for resource_class, resources in self.resource_map.items():
if resource_type == getattr(resource_class, "resource_type", ""):
if resource_name in resources: # type: ignore
return resources[resource_name] # type: ignore

def list_tags_for_resource(self, arn: str) -> List[Dict[str, str]]:
if self.arn_regex.match(arn):
arn_breakdown = arn.split(":")
resource_type = arn_breakdown[len(arn_breakdown) - 2]
resource_name = arn_breakdown[len(arn_breakdown) - 1]
if resource_type == "db": # Database
if resource_name in self.databases:
return self.databases[resource_name].get_tags()
elif resource_type == "cluster": # Cluster
if resource_name in self.clusters:
return self.clusters[resource_name].get_tags()
elif resource_type == "es": # Event Subscription
if resource_name in self.event_subscriptions:
return self.event_subscriptions[resource_name].get_tags()
elif resource_type == "og": # Option Group
if resource_name in self.option_groups:
return self.option_groups[resource_name].get_tags()
elif resource_type == "pg": # Parameter Group
if resource_name in self.db_parameter_groups:
return self.db_parameter_groups[resource_name].get_tags()
elif resource_type == "ri": # Reserved DB instance
# TODO: Complete call to tags on resource type Reserved DB
# instance
return []
elif resource_type == "secgrp": # DB security group
if resource_name in self.security_groups:
return self.security_groups[resource_name].get_tags()
elif resource_type == "snapshot": # DB Snapshot
if resource_name in self.database_snapshots:
return self.database_snapshots[resource_name].get_tags()
elif resource_type == "cluster-snapshot": # DB Cluster Snapshot
if resource_name in self.cluster_snapshots:
return self.cluster_snapshots[resource_name].get_tags()
elif resource_type == "subgrp": # DB subnet group
if resource_name in self.subnet_groups:
return self.subnet_groups[resource_name].get_tags()
elif resource_type == "db-proxy": # DB Proxy
if resource_name in self.db_proxies:
return self.db_proxies[resource_name].get_tags()
else:
raise RDSClientError(
"InvalidParameterValue", f"Invalid resource name: {arn}"
)
return []
resource = self._find_resource(resource_type, resource_name)
if resource:
return resource.get_tags()
return []
raise RDSClientError("InvalidParameterValue", f"Invalid resource name: {arn}")

def remove_tags_from_resource(self, arn: str, tag_keys: List[str]) -> None:
if self.arn_regex.match(arn):
arn_breakdown = arn.split(":")
resource_type = arn_breakdown[len(arn_breakdown) - 2]
resource_name = arn_breakdown[len(arn_breakdown) - 1]
if resource_type == "db": # Database
if resource_name in self.databases:
self.databases[resource_name].remove_tags(tag_keys)
elif resource_type == "es": # Event Subscription
if resource_name in self.event_subscriptions:
self.event_subscriptions[resource_name].remove_tags(tag_keys)
elif resource_type == "og": # Option Group
if resource_name in self.option_groups:
self.option_groups[resource_name].remove_tags(tag_keys)
elif resource_type == "pg": # Parameter Group
if resource_name in self.db_parameter_groups:
self.db_parameter_groups[resource_name].remove_tags(tag_keys)
elif resource_type == "ri": # Reserved DB instance
return None
elif resource_type == "secgrp": # DB security group
if resource_name in self.security_groups:
self.security_groups[resource_name].remove_tags(tag_keys)
elif resource_type == "snapshot": # DB Snapshot
if resource_name in self.database_snapshots:
self.database_snapshots[resource_name].remove_tags(tag_keys)
elif resource_type == "cluster":
if resource_name in self.clusters:
self.clusters[resource_name].remove_tags(tag_keys)
elif resource_type == "cluster-snapshot": # DB Cluster Snapshot
if resource_name in self.cluster_snapshots:
self.cluster_snapshots[resource_name].remove_tags(tag_keys)
elif resource_type == "subgrp": # DB subnet group
if resource_name in self.subnet_groups:
self.subnet_groups[resource_name].remove_tags(tag_keys)
elif resource_type == "db-proxy": # DB Proxy
if resource_name in self.db_proxies:
self.db_proxies[resource_name].remove_tags(tag_keys)
else:
raise RDSClientError(
"InvalidParameterValue", f"Invalid resource name: {arn}"
)
resource = self._find_resource(resource_type, resource_name)
if resource:
resource.remove_tags(tag_keys)
return
raise RDSClientError("InvalidParameterValue", f"Invalid resource name: {arn}")

def add_tags_to_resource( # type: ignore[return]
self, arn: str, tags: List[Dict[str, str]]
Expand All @@ -2627,42 +2580,11 @@ def add_tags_to_resource( # type: ignore[return]
arn_breakdown = arn.split(":")
resource_type = arn_breakdown[-2]
resource_name = arn_breakdown[-1]
if resource_type == "db": # Database
if resource_name in self.databases:
return self.databases[resource_name].add_tags(tags)
elif resource_type == "es": # Event Subscription
if resource_name in self.event_subscriptions:
return self.event_subscriptions[resource_name].add_tags(tags)
elif resource_type == "og": # Option Group
if resource_name in self.option_groups:
return self.option_groups[resource_name].add_tags(tags)
elif resource_type == "pg": # Parameter Group
if resource_name in self.db_parameter_groups:
return self.db_parameter_groups[resource_name].add_tags(tags)
elif resource_type == "ri": # Reserved DB instance
return []
elif resource_type == "secgrp": # DB security group
if resource_name in self.security_groups:
return self.security_groups[resource_name].add_tags(tags)
elif resource_type == "snapshot": # DB Snapshot
if resource_name in self.database_snapshots:
return self.database_snapshots[resource_name].add_tags(tags)
elif resource_type == "cluster":
if resource_name in self.clusters:
return self.clusters[resource_name].add_tags(tags)
elif resource_type == "cluster-snapshot": # DB Cluster Snapshot
if resource_name in self.cluster_snapshots:
return self.cluster_snapshots[resource_name].add_tags(tags)
elif resource_type == "subgrp": # DB subnet group
if resource_name in self.subnet_groups:
return self.subnet_groups[resource_name].add_tags(tags)
elif resource_type == "db-proxy": # DB Proxy
if resource_name in self.db_proxies:
return self.db_proxies[resource_name].add_tags(tags)
else:
raise RDSClientError(
"InvalidParameterValue", f"Invalid resource name: {arn}"
)
resource = self._find_resource(resource_type, resource_name)
if resource:
return resource.add_tags(tags)
return []
raise RDSClientError("InvalidParameterValue", f"Invalid resource name: {arn}")

@staticmethod
def _filter_resources(resources: Any, filters: Any, resource_class: Any) -> Any: # type: ignore[misc]
Expand Down

0 comments on commit b37593f

Please sign in to comment.