Add Pagination to Delete Commands #543

Merged (10 commits) on Sep 25, 2024
87 changes: 59 additions & 28 deletions panther_analysis_tool/backend/public_api_client.py
@@ -24,7 +24,7 @@
import time
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Generator, List, Optional, Sequence
from urllib.parse import urlparse

from gql import Client as GraphQLClient
@@ -166,6 +166,10 @@ class PublicAPIClient(Client):  # pylint: disable=too-many-public-methods
    _requests: PublicAPIRequests
    _gql_client: GraphQLClient

    # backend's delete function can only handle 100 IDs at a time, due to DynamoDB restrictions
    # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/ServiceQuotas.html#limits-expression-parameters
    _DELETE_BATCH_SIZE = 100

    def __init__(self, opts: PublicAPIClientOptions):
        self._user_id = opts.user_id
        self._requests = PublicAPIRequests()
@@ -329,23 +333,29 @@ def transpile_filters(
    def delete_saved_queries(
        self, params: DeleteSavedQueriesParams
    ) -> BackendResponse[DeleteSavedQueriesResponse]:
-        query = self._requests.delete_saved_queries()
-        delete_params = {
-            "input": {
-                "dryRun": params.dry_run,
-                "includeDetections": params.include_detections,
-                "names": params.names,
-            }
-        }
-        res = self._execute(query, variable_values=delete_params)
-
-        if res.errors:
-            raise BackendError(res.errors)
-
-        if res.data is None:
-            raise BackendError("empty data")
-
-        data = res.data.get("deleteSavedQueriesByName", {})
+        data: Dict = {"names": [], "detectionIDs": []}
+        for name_batch in _batched(params.names, self._DELETE_BATCH_SIZE):
+            gql_params = {
+                "input": {
+                    "dryRun": params.dry_run,
+                    "includeDetections": params.include_detections,
+                    "names": name_batch,
+                }
+            }
+            res = self._execute(self._requests.delete_saved_queries(), variable_values=gql_params)
+
+            if res.errors:
+                for err in res.errors:
+                    logging.error(err.message)
+
+                raise BackendError(res.errors)
+
+            if res.data is None:
+                raise BackendError("empty data")
+
+            query_data = res.data.get("deleteSavedQueriesByName", {})
+            for field in ("names", "detectionIDs"):
+                data[field] += query_data.get(field) or []

        return BackendResponse(
            status_code=200,
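The loop above is the whole pagination change: the backend's delete mutations accept at most 100 IDs per request (a DynamoDB expression-parameter limit, per the comment on _DELETE_BATCH_SIZE), so the client slices the input and merges each response's fields into one combined result. A minimal standalone sketch of that accumulate-across-batches pattern, where fake_delete_api, delete_all, and BATCH_SIZE are illustrative stand-ins rather than anything in this PR:

from typing import Dict, List, Sequence

BATCH_SIZE = 100  # mirrors _DELETE_BATCH_SIZE

def fake_delete_api(ids: Sequence[str]) -> Dict[str, List[str]]:
    # Stand-in for one GraphQL delete call; the real backend rejects >100 IDs
    assert len(ids) <= BATCH_SIZE
    return {"ids": list(ids)}

def delete_all(ids: Sequence[str]) -> Dict[str, List[str]]:
    data: Dict[str, List[str]] = {"ids": []}
    for start in range(0, len(ids), BATCH_SIZE):
        batch = ids[start : start + BATCH_SIZE]
        # Merge each batch's response into the combined result
        data["ids"] += fake_delete_api(batch).get("ids") or []
    return data

assert len(delete_all([str(i) for i in range(250)])["ids"]) == 250  # 100 + 100 + 50

One consequence visible in the diff: a failure still raises mid-loop, so earlier batches have already been deleted by the time a later batch errors, and their accumulated results are discarded with the raised BackendError.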
@@ -358,24 +368,29 @@ def delete_detections(
    def delete_detections(
        self, params: DeleteDetectionsParams
    ) -> BackendResponse[DeleteDetectionsResponse]:
-        gql_params = {
-            "input": {
-                "dryRun": params.dry_run,
-                "includeSavedQueries": params.include_saved_queries,
-                "ids": params.ids,
-            }
-        }
-        res = self._execute(self._requests.delete_detections_query(), gql_params)
-        if res.errors:
-            for err in res.errors:
-                logging.error(err.message)
-
-            raise BackendError(res.errors)
-
-        if res.data is None:
-            raise BackendError("empty data")
-
-        data = res.data.get("deleteDetections", {})
+        data: Dict = {"ids": [], "savedQueryNames": []}
+        for id_batch in _batched(params.ids, self._DELETE_BATCH_SIZE):
+            gql_params = {
+                "input": {
+                    "dryRun": params.dry_run,
+                    "includeSavedQueries": params.include_saved_queries,
+                    "ids": id_batch,
+                }
+            }
+            res = self._execute(self._requests.delete_detections_query(), gql_params)
+
+            if res.errors:
+                for err in res.errors:
+                    logging.error(err.message)
+
+                raise BackendError(res.errors)
+
+            if res.data is None:
+                raise BackendError("empty data")
+
+            query_data = res.data.get("deleteDetections", {})
+            for field in ("ids", "savedQueryNames"):
+                data[field] += query_data.get(field) or []

        return BackendResponse(
            status_code=200,
@@ -693,3 +708,19 @@ def _build_api_url(host: str) -> str:
def _get_graphql_content_filepath(name: str) -> str:
    work_dir = os.path.dirname(__file__)
    return os.path.join(work_dir, "graphql", f"{name}.graphql")


def _batched(iterable: Sequence, size: int = 1) -> Generator[Sequence, None, None]:
    """Batch data from 'iterable' into chunks of length 'size'. The last batch may be shorter than 'size'.
    Inspired by itertools.batched in Python 3.12+.

    Args:
        iterable (Sequence): the sequence to batch; must support len() and slicing
        size (int, optional): the maximum size of each batch. default=1

    Yields:
        out (Sequence): a batch of at most 'size' items
    """
    length = len(iterable)
    for idx in range(0, length, size):
        yield iterable[idx : min(idx + size, length)]
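For reference, a quick usage sketch of the helper above (illustrative only, not part of the diff). Because _batched yields slices, batching a list yields lists:

from panther_analysis_tool.backend.public_api_client import _batched

names = [f"query-{i}" for i in range(7)]
batches = list(_batched(names, 3))
# batches == [['query-0', 'query-1', 'query-2'],
#             ['query-3', 'query-4', 'query-5'],
#             ['query-6']]
assert [len(b) for b in batches] == [3, 3, 1]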
48 changes: 48 additions & 0 deletions tests/unit/panther_analysis_tool/test_util.py
@@ -5,6 +5,7 @@

import panther_analysis_tool.constants
from panther_analysis_tool import util as pat_utils
from panther_analysis_tool.backend.public_api_client import _batched
from panther_analysis_tool.util import convert_unicode


@@ -200,3 +201,50 @@ def test_is_policy(self):
        for case in test_cases:
            res = pat_utils.is_policy(case["analysis_type"])
            self.assertEqual(case["expected"], res)


class TestBatched(unittest.TestCase):
    def test_batched_with_remainder(self):
        iterable = [1] * 12
        n = 5
        expected_batches = 3
        modulo = 2  # Size of last batch

        batches = list(_batched(iterable, n))
        # Ensure we received the expected number of batches
        self.assertEqual(len(batches), expected_batches)
        # Confirm all but the last batch have the same size
        for batch in batches[:-1]:
            self.assertEqual(len(list(batch)), n)
        # Confirm the last batch has the expected number of entries
        self.assertEqual(len(list(batches[-1])), modulo)

    def test_batched_with_no_remainder(self):
        iterable = [1] * 100
        n = 10
        expected_batches = 10
        modulo = 10  # Size of last batch

        batches = list(_batched(iterable, n))
        # Ensure we received the expected number of batches
        self.assertEqual(len(batches), expected_batches)
        # Confirm all but the last batch have the same size
        for batch in batches[:-1]:
            self.assertEqual(len(list(batch)), n)
        # Confirm the last batch has the expected number of entries
        self.assertEqual(len(list(batches[-1])), modulo)

    def test_batched_with_no_full_batches(self):
        iterable = [1] * 3
        n = 5
        expected_batches = 1
        modulo = 3  # Size of last batch

        batches = list(_batched(iterable, n))
        # Ensure we received the expected number of batches
        self.assertEqual(len(batches), expected_batches)
        # Confirm all but the last batch have the same size
        for batch in batches[:-1]:
            self.assertEqual(len(list(batch)), n)
        # Confirm the last batch has the expected number of entries
        self.assertEqual(len(list(batches[-1])), modulo)
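One case the suite doesn't pin down: an empty sequence yields no batches at all (range(0, 0, size) is empty), so the delete commands make zero API calls when given an empty list. A small extra test in the same style as those above — a sketch, not part of this PR — would look like:

    def test_batched_with_empty_iterable(self):
        # range(0, 0, size) is empty, so nothing is yielded
        batches = list(_batched([], 5))
        self.assertEqual(len(batches), 0)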