Enable query caching #756

Draft
wants to merge 11 commits into base: main
4 changes: 3 additions & 1 deletion conda/conda-reqs-pip.txt
@@ -1,8 +1,10 @@
 azure-mgmt-resourcegraph>=8.0.0
 azure-monitor-query>=1.0.0, <=2.0.0
 dataclasses-json >= 0.5.7
 # KqlmagicCustom[jupyter-basic,auth_code_clipboard]>=0.1.114.post22
 mo-sql-parsing>=8, <9.0.0
+nbformat>=5.9.2
 nest_asyncio>=1.4.0
 passivetotal>=2.5.3
-sumologic-sdk>=0.1.11
+splunk-sdk>=1.6.0
+sumologic-sdk>=0.1.11
86 changes: 86 additions & 0 deletions msticpy/common/cache/__init__.py
@@ -0,0 +1,86 @@
"""Common methods to handle caching."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

from ...datamodel.result import QueryResult
from ..utility.ipython import is_ipython
from . import cell
from . import file as cache_file
from .codec import compute_digest

if TYPE_CHECKING:
import pandas as pd

LOGGER: logging.Logger = logging.getLogger(__name__)


def write_cache( # noqa: PLR0913
data: pd.DataFrame,
search_params: dict[str, Any],
query: str,
name: str,
cache_path: str | None = None,
*,
display: bool = False,
) -> None:
"""Cache query result in a cell or a parquet file."""
cache_digest: str = compute_digest(search_params)
cache: QueryResult = QueryResult(
name=name,
query=query,
raw_results=data,
arguments=search_params,
)
if is_ipython() and display:
cell.write_cache(
cache,
name,
cache_digest,
)
if cache_path:
LOGGER.info("Writing cache to %s", cache_path)
cache_file.write_cache(
data=cache,
file_name=f"{name}_{cache_digest}",
export_folder=cache_path,
)


def read_cache(
search_params: dict[str, Any],
cache_path: str | None,
name: str | None = None,
) -> QueryResult:
"""Retrieve result from cache in a cell or a archive file."""
if not cache_path:
error_msg: str = "Cache not provided."
raise ValueError(error_msg)
cache_digest: str = compute_digest(search_params)
if is_ipython():
try:
return cell.read_cache(
name or cache_digest,
cache_digest,
cache_path,
)
except ValueError:
pass
try:
cache: QueryResult = cache_file.read_cache(
f"{name}_{cache_digest}",
cache_path,
)
except FileNotFoundError as exc:
error_msg = "Could not read from cache."
raise ValueError(error_msg) from exc
if is_ipython():
        # Write the result back to the cell cache, since it was not found there.
cell.write_cache(
cache,
name or cache_digest,
cache_digest,
)
return cache
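Taken together, these two functions form a simple round-trip API: results are keyed by a digest of the search parameters, written to a notebook cell and/or a folder of cache files, and read back by recomputing the same digest. A minimal sketch of that round trip outside IPython (the parameter values, sample frame, and "./artifacts" path are illustrative, not part of the PR):

import pandas as pd

from msticpy.common.cache import read_cache, write_cache

params = {"host_name": "myhost", "start": "2024-01-01", "end": "2024-01-02"}
results = pd.DataFrame({"Account": ["alice"], "LogonCount": [3]})

# First run: persists the result as <name>_<digest> under ./artifacts.
write_cache(
    data=results,
    search_params=params,
    query="SecurityEvent | where ...",
    name="list_host_logons",
    cache_path="./artifacts",
)

# Later run: identical parameters produce the identical digest, so the
# cached QueryResult is returned instead of re-running the query.
cached = read_cache(params, "./artifacts", name="list_host_logons")
assert cached.raw_results.equals(results)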
108 changes: 108 additions & 0 deletions msticpy/common/cache/cell.py
@@ -0,0 +1,108 @@
"""Handle caching in Notebook cell."""
from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

import nbformat
from IPython.display import display

from ...datamodel.result import QueryResult
from .codec import decode_base64_as_pickle, encode_as_base64_pickle

LOGGER: logging.Logger = logging.getLogger(__name__)


def write_cache(
data: QueryResult,
name: str,
digest: str,
) -> None:
"""Cache content in cell."""
cache: str = encode_as_base64_pickle(data)
metadata: dict[str, Any] = {
"data": cache,
"hash": digest,
}
if isinstance(data, QueryResult):
metadata.update(
{
"name": name,
"query": data.query,
"arguments": data.arguments,
"timestamp": data.timestamp,
},
)
LOGGER.debug("Data %s written to Notebook cache", name)
display(
data.raw_results,
metadata=metadata,
exclude=["text/plain"],
)


def get_cache_item(path: Path, name: str, digest: str) -> dict[str, Any]:
"""
Get named object from cache.

Parameters
----------
path : Path
Path to notebook
name : str
        Name of the cached object to search for
    digest : str
        Digest of the cached object to search for

Returns
-------
dict[str, Any]
Cached object.
"""
if not path.exists():
error_msg: str = "Notebook not found"
raise FileNotFoundError(error_msg)

notebook: nbformat.NotebookNode = nbformat.reads(
path.read_text(encoding="utf-8"),
as_version=nbformat.current_nbformat,
)

    cache: dict[str, Any] = next(
        (
            (output.get("metadata", {}) or {})
            for cell in (notebook.cells or [])
            for output in (cell.get("outputs", []) or [])
            if output.get("metadata", {}).get("hash") == digest
            and output.get("metadata", {}).get("name") == name
        ),
        {},
    )
    if not cache:
        LOGGER.debug("%s not found in %s cache...", digest, path)

return cache


def read_cache(name: str, digest: str, nb_path: str) -> QueryResult:
"""Read cache content from file."""
if not nb_path:
error_msg: str = "Argument nb_path must be defined."
raise ValueError(error_msg)

notebook_fp: Path = Path(nb_path).absolute()

if not notebook_fp.exists():
error_msg = "Notebook not found"
raise FileNotFoundError(error_msg)

cache: dict[str, Any] = get_cache_item(path=notebook_fp, name=name, digest=digest)
if cache and (data := cache.get("data")):
LOGGER.debug("Digest %s found in cache...", digest)
return decode_base64_as_pickle(data)
error_msg = f"Cache {digest} not found"
raise ValueError(error_msg)
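The cell cache rides on Jupyter's rich-output metadata: display() attaches the Base64 pickle and its lookup keys to the cell output, and get_cache_item later scans the saved .ipynb for an output whose hash and name match. A sketch of what a matching output looks like in the notebook JSON (all values are illustrative; only the metadata keys mirror what cell.write_cache emits):

# Illustrative shape of one cached output inside a saved .ipynb.
cached_output = {
    "output_type": "display_data",
    "data": {"text/html": "<table>...</table>"},  # the rendered DataFrame
    "metadata": {
        "data": "/Td6WFoAAAT...",  # Base64 LZMA pickle of the QueryResult
        "hash": "9f2c...",  # compute_digest(search_params)
        "name": "list_host_logons",
        "query": "SecurityEvent | where ...",
        "arguments": {"host_name": "myhost"},
        "timestamp": "2024-01-01 00:00:00",
    },
}

Note that this cache only survives if the notebook is saved after the cell runs, since get_cache_item reads the on-disk .ipynb rather than the live kernel state.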
40 changes: 40 additions & 0 deletions msticpy/common/cache/codec.py
@@ -0,0 +1,40 @@
"""Functions to encode/decode cached objects."""

import base64
import json
import logging
from collections.abc import MutableMapping
from hashlib import sha256
from io import BytesIO

import compress_pickle # type: ignore[import-untyped]

from ..._version import VERSION
from ...datamodel.result import QueryResult

__version__ = VERSION
__author__ = "Florian Bracq"

LOGGER: logging.Logger = logging.getLogger(__name__)


def encode_as_base64_pickle(data: QueryResult) -> str:
"""Encode data as Base64 pickle to be written to cache."""
with BytesIO() as bytes_io:
compress_pickle.dump(data, bytes_io, compression="lzma")
return base64.b64encode(bytes_io.getvalue()).decode()


def decode_base64_as_pickle(b64_string: str) -> QueryResult:
"""Decode Base64 pickle from cache to Results."""
return compress_pickle.loads(base64.b64decode(b64_string), compression="lzma")


def compute_digest(obj: MutableMapping) -> str:
"""Compute the digest from the parameters."""
str_params: str = json.dumps(obj, sort_keys=True, default=str)
LOGGER.debug("Received: %s", str_params)
digest: str = sha256(bytes(str_params, "utf-8")).hexdigest()
LOGGER.debug("Generated digest: %s", digest)
return digest
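Because json.dumps is called with sort_keys=True (and default=str for non-JSON values such as datetimes), the digest is insensitive to parameter ordering: logically identical searches always map to the same cache key. A quick sketch:

from msticpy.common.cache.codec import compute_digest

a = compute_digest({"start": "2024-01-01", "host_name": "myhost"})
b = compute_digest({"host_name": "myhost", "start": "2024-01-01"})
assert a == b  # same parameters, same digest, regardless of key order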
48 changes: 48 additions & 0 deletions msticpy/common/cache/file.py
@@ -0,0 +1,48 @@
"""Handle caching in files."""
from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING

from .codec import decode_base64_as_pickle, encode_as_base64_pickle

if TYPE_CHECKING:
from ...datamodel.result import QueryResult


LOGGER: logging.Logger = logging.getLogger(__name__)
CACHE_FOLDER_NAME = "artifacts"


def write_cache(
data: QueryResult,
file_name: str,
export_folder: str = CACHE_FOLDER_NAME,
) -> None:
"""Cache content in file."""
export_path: Path = Path(export_folder)
if export_path.is_file():
export_path = export_path.parent / CACHE_FOLDER_NAME
    export_path.mkdir(exist_ok=True, parents=True)
export_file: Path = export_path / file_name
encoded_text: str = encode_as_base64_pickle(data)
export_file.write_text(encoded_text)
LOGGER.debug("Data written to file %s", export_folder)


def read_cache(
file_name: str,
export_folder: str = CACHE_FOLDER_NAME,
) -> QueryResult:
"""Read cache content from file."""
export_path: Path = Path(export_folder)
if export_path.is_file():
export_path = export_path.parent / CACHE_FOLDER_NAME
export_file: Path = export_path / file_name
if export_file.exists():
LOGGER.debug("Found data in cache %s", export_file)
encoded_text: str = export_file.read_text()
return decode_base64_as_pickle(encoded_text)
    raise FileNotFoundError(f"Cache file {export_file} not found")
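Both functions resolve the target folder the same way: if export_folder points at an existing file (for example, the notebook itself is passed as cache_path), the cache is redirected to an "artifacts" folder next to it. A sketch of the fallback, with illustrative paths:

from pathlib import Path

CACHE_FOLDER_NAME = "artifacts"

export_path = Path("investigation/triage.ipynb")  # an existing *file*
if export_path.is_file():
    # Redirect to a sibling folder: investigation/artifacts/
    export_path = export_path.parent / CACHE_FOLDER_NAME
# Cache entries then land at investigation/artifacts/<name>_<digest>.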
35 changes: 34 additions & 1 deletion msticpy/data/core/data_providers.py
@@ -12,8 +12,10 @@
import pandas as pd

from ..._version import VERSION
from ...common.cache import read_cache, write_cache
from ...common.pkg_config import get_config
from ...common.utility import export, valid_pyname
from ...datamodel.result import QueryResult
from ...nbwidgets.query_time import QueryTime
from .. import drivers
from ..drivers.driver_base import DriverBase, DriverProps
@@ -267,6 +269,7 @@ def _execute_query(self, *args, **kwargs) -> Union[pd.DataFrame, Any]:
)
query_name = kwargs.pop("query_name")
family = kwargs.pop("query_path")
cache_path: Optional[str] = kwargs.pop("cache_path", None)

query_source = self.query_store.get_query(
query_path=family, query_name=query_name
@@ -299,6 +302,7 @@ def _execute_query(self, *args, **kwargs) -> Union[pd.DataFrame, Any]:
if split_result is not None:
return split_result
# if split queries could not be created, fall back to default

query_str = query_source.create_query(
formatters=self._query_provider.formatters, **params
)
@@ -311,7 +315,36 @@ def _execute_query(self, *args, **kwargs) -> Union[pd.DataFrame, Any]:
logger.info(
"Running query '%s...' with params: %s", query_str[:40], query_options
)
return self.exec_query(query_str, query_source=query_source, **query_options)
if cache_path:
try:
result: QueryResult = read_cache(
query_options,
cache_path,
query_source.name,
)
except (ValueError, FileNotFoundError):
logger.info("Data not found in cache.")
else:
logger.info(
"Data found in cache, returning result from past execution %s.",
result.timestamp.isoformat(sep=" ", timespec="seconds"),
)
if result.raw_results is not None:
return result.raw_results

query_result: pd.DataFrame = self.exec_query(
query_str, query_source=query_source, **query_options
)

write_cache(
data=query_result,
query=query_str,
search_params=query_options,
cache_path=cache_path,
name=query_source.name,
display=kwargs.pop("display", True),
)
return query_result

def _check_for_time_params(self, params, missing) -> bool:
"""Fall back on builtin query time if no time parameters were supplied."""
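With the plumbing above, caching becomes a per-call option on any query: _execute_query pops cache_path from the query keyword arguments, tries read_cache first, and writes the result back on a miss. A sketch of the end-to-end flow this enables (the provider, query name, and paths are illustrative; connection and auth details are elided):

from msticpy.data.core.data_providers import QueryProvider

qry_prov = QueryProvider("MSSentinel")
qry_prov.connect()  # connection/auth details elided

# First call executes the query and writes the cache.
df = qry_prov.SecurityAlert.list_alerts(
    start="2024-01-01", end="2024-01-02", cache_path="./artifacts"
)

# An identical call now hits the cache: the parameters hash to the same
# digest, so the stored result is returned without re-running the query.
df = qry_prov.SecurityAlert.list_alerts(
    start="2024-01-01", end="2024-01-02", cache_path="./artifacts"
)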