diff --git a/panther_analysis_tool/backend/public_api_client.py b/panther_analysis_tool/backend/public_api_client.py
index 7fd4c96c..f4d41dd9 100644
--- a/panther_analysis_tool/backend/public_api_client.py
+++ b/panther_analysis_tool/backend/public_api_client.py
@@ -24,7 +24,7 @@ import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Generator, List, Optional, Sequence
 from urllib.parse import urlparse

 from gql import Client as GraphQLClient
@@ -166,6 +166,10 @@ class PublicAPIClient(Client):  # pylint: disable=too-many-public-methods
     _requests: PublicAPIRequests
     _gql_client: GraphQLClient

+    # backend's delete function can only handle 100 IDs at a time, due to DynamoDB restrictions
+    # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/ServiceQuotas.html#limits-expression-parameters
+    _DELETE_BATCH_SIZE = 100
+
     def __init__(self, opts: PublicAPIClientOptions):
         self._user_id = opts.user_id
         self._requests = PublicAPIRequests()
@@ -329,23 +333,29 @@ def transpile_filters(
     def delete_saved_queries(
         self, params: DeleteSavedQueriesParams
     ) -> BackendResponse[DeleteSavedQueriesResponse]:
-        query = self._requests.delete_saved_queries()
-        delete_params = {
-            "input": {
-                "dryRun": params.dry_run,
-                "includeDetections": params.include_detections,
-                "names": params.names,
+        data: Dict = {"names": [], "detectionIDs": []}
+        for name_batch in _batched(params.names, self._DELETE_BATCH_SIZE):
+            gql_params = {
+                "input": {
+                    "dryRun": params.dry_run,
+                    "includeDetections": params.include_detections,
+                    "names": name_batch,
+                }
             }
-        }
-        res = self._execute(query, variable_values=delete_params)
+            res = self._execute(self._requests.delete_saved_queries(), variable_values=gql_params)

-        if res.errors:
-            raise BackendError(res.errors)
+            if res.errors:
+                for err in res.errors:
+                    logging.error(err.message)

-        if res.data is None:
-            raise BackendError("empty data")
+                raise BackendError(res.errors)
+
+            if res.data is None:
+                raise BackendError("empty data")

-        data = res.data.get("deleteSavedQueriesByName", {})
+            query_data = res.data.get("deleteSavedQueriesByName", {})
+            for field in ("names", "detectionIDs"):
+                data[field] += query_data.get(field) or []

         return BackendResponse(
             status_code=200,
@@ -358,24 +368,29 @@ def delete_saved_queries(
     def delete_detections(
         self, params: DeleteDetectionsParams
     ) -> BackendResponse[DeleteDetectionsResponse]:
-        gql_params = {
-            "input": {
-                "dryRun": params.dry_run,
-                "includeSavedQueries": params.include_saved_queries,
-                "ids": params.ids,
+        data: Dict = {"ids": [], "savedQueryNames": []}
+        for id_batch in _batched(params.ids, self._DELETE_BATCH_SIZE):
+            gql_params = {
+                "input": {
+                    "dryRun": params.dry_run,
+                    "includeSavedQueries": params.include_saved_queries,
+                    "ids": id_batch,
+                }
             }
-        }
-        res = self._execute(self._requests.delete_detections_query(), gql_params)

-        if res.errors:
-            for err in res.errors:
-                logging.error(err.message)
+            res = self._execute(self._requests.delete_detections_query(), gql_params)

-            raise BackendError(res.errors)
+            if res.errors:
+                for err in res.errors:
+                    logging.error(err.message)

-        if res.data is None:
-            raise BackendError("empty data")
+                raise BackendError(res.errors)

-        data = res.data.get("deleteDetections", {})
+            if res.data is None:
+                raise BackendError("empty data")
+
+            query_data = res.data.get("deleteDetections", {})
+            for field in ("ids", "savedQueryNames"):
+                data[field] += query_data.get(field) or []

         return BackendResponse(
             status_code=200,
@@ -693,3 +708,19 @@ def _build_api_url(host: str) -> str:
 def _get_graphql_content_filepath(name: str) -> str:
     work_dir = os.path.dirname(__file__)
     return os.path.join(work_dir, "graphql", f"{name}.graphql")
+
+
+def _batched(iterable: Sequence, size: int = 1) -> Generator[Sequence, None, None]:
+    """Batch data from 'iterable' into chunks of length 'size'. The last batch may be shorter than 'size'.
+    Inspired by itertools.batched, available in Python 3.12+.
+
+    Args:
+        iterable (Sequence): the sequence to be batched; must support len() and slicing
+        size (int, optional): the maximum size of each batch. Defaults to 1.
+
+    Yields:
+        out (Sequence): a batch of at most 'size' items
+    """
+    length = len(iterable)
+    for idx in range(0, length, size):
+        yield iterable[idx : min(idx + size, length)]
diff --git a/tests/unit/panther_analysis_tool/test_util.py b/tests/unit/panther_analysis_tool/test_util.py
index 93fad87a..3edaec20 100644
--- a/tests/unit/panther_analysis_tool/test_util.py
+++ b/tests/unit/panther_analysis_tool/test_util.py
@@ -5,6 +5,7 @@ import panther_analysis_tool.constants
 from panther_analysis_tool import util as pat_utils
+from panther_analysis_tool.backend.public_api_client import _batched
 from panther_analysis_tool.util import convert_unicode
@@ -200,3 +201,50 @@ def test_is_policy(self):
         for case in test_cases:
             res = pat_utils.is_policy(case["analysis_type"])
             self.assertEqual(case["expected"], res)
+
+
+class TestBatched(unittest.TestCase):
+    def test_batched_with_remainder(self):
+        iterable = [1] * 12
+        n = 5
+        expected_batches = 3
+        modulo = 2  # Size of last batch
+
+        batches = list(_batched(iterable, n))
+        # Ensure we received the expected number of batches
+        self.assertEqual(len(batches), expected_batches)
+        # Confirm all but the last batch have the same size
+        for batch in batches[:-1]:
+            self.assertEqual(len(list(batch)), n)
+        # Confirm the last batch has the expected number of entries
+        self.assertEqual(len(list(batches[-1])), modulo)
+
+    def test_batched_with_no_remainder(self):
+        iterable = [1] * 100
+        n = 10
+        expected_batches = 10
+        modulo = 10  # Size of last batch
+
+        batches = list(_batched(iterable, n))
+        # Ensure we received the expected number of batches
+        self.assertEqual(len(batches), expected_batches)
+        # Confirm all but the last batch have the same size
+        for batch in batches[:-1]:
+            self.assertEqual(len(list(batch)), n)
+        # Confirm the last batch has the expected number of entries
+        self.assertEqual(len(list(batches[-1])), modulo)
+
+    def test_batched_with_no_full_batches(self):
+        iterable = [1] * 3
+        n = 5
+        expected_batches = 1
+        modulo = 3  # Size of last batch
+
+        batches = list(_batched(iterable, n))
+        # Ensure we received the expected number of batches
+        self.assertEqual(len(batches), expected_batches)
+        # Confirm all but the last batch have the same size
+        for batch in batches[:-1]:
+            self.assertEqual(len(list(batch)), n)
+        # Confirm the last batch has the expected number of entries
+        self.assertEqual(len(list(batches[-1])), modulo)
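
Usage note: a minimal sketch of how _batched feeds the batched delete calls above. The ID list and its length are hypothetical, purely for illustration; only the import path and the 100-item batch size come from the patch itself.

    from panther_analysis_tool.backend.public_api_client import _batched

    detection_ids = [f"detection-{i}" for i in range(250)]  # hypothetical IDs
    for batch in _batched(detection_ids, 100):  # 100 mirrors _DELETE_BATCH_SIZE
        # Each batch is a slice of at most 100 IDs, matching the DynamoDB
        # expression-parameter limit; here the slices hold 100, 100, and 50 items.
        print(len(batch))

Because _batched slices the sequence rather than consuming an iterator, each yielded batch is itself a list that can be dropped straight into the GraphQL input dict, as delete_detections does with "ids": id_batch.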