Skip to content

Commit

Permalink
feat: auto-generate ruleset cache on source change
Browse files Browse the repository at this point in the history
  • Loading branch information
fariss committed Jun 7, 2024
1 parent 76a4a58 commit cc3d208
Showing 1 changed file with 89 additions and 0 deletions.
89 changes: 89 additions & 0 deletions capa/rules/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import sys
import zlib
import pickle
import shutil
import hashlib
import logging
import subprocess
from typing import List, Optional
from pathlib import Path
from dataclasses import dataclass
Expand All @@ -26,6 +28,58 @@
CacheIdentifier = str


def is_dev_environment() -> bool:
    """Heuristically decide whether capa is running from a development checkout.

    Returns False when running as a frozen (PyInstaller) binary, when loaded
    from an installed site-packages tree, or when the ``git`` executable
    cannot be found on PATH; otherwise assumes a development environment.
    """
    running_frozen = bool(getattr(sys, "frozen", False))
    installed_copy = "site-packages" in __file__
    # git must be reachable on PATH for the dev-cache logic to inspect the working tree
    git_available = shutil.which("git") is not None

    return not running_frozen and not installed_copy and git_available


def get_modified_files() -> List[str]:
    """Return tracked Python source files that have uncommitted modifications.

    Runs ``git status --porcelain`` and keeps only ``.py`` entries whose
    two-character status code indicates a modification:
      'M ': staged (index) modification only
      ' M': unstaged (working tree) modification only
      'MM': both staged and unstaged modifications

    Returns an empty list when git is unavailable or the command fails.
    """
    try:
        result = subprocess.run(
            ["git", "--no-pager", "status", "--porcelain", "--untracked-files=no"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # not a git checkout, or git is not installed
        return []

    # porcelain v1 format: two status characters, one space, then the path.
    # NOTE(review): paths are relative to the repository root, not necessarily
    # the current working directory — TODO confirm callers account for this.
    return [
        line[3:]
        for line in result.stdout.splitlines()
        if line.startswith(("M ", "MM", " M")) and line.endswith(".py")
    ]


def get_git_commit_hash() -> Optional[str]:
    """Return the commit hash of the current HEAD, or None when it cannot be determined."""
    try:
        proc = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        # not inside a git repository, or git is not installed
        return None

    head = proc.stdout.strip()
    logger.debug("git commit hash %s", head)
    return head


def compute_cache_identifier(rule_content: List[bytes]) -> CacheIdentifier:
hash = hashlib.sha256()

Expand Down Expand Up @@ -107,6 +161,41 @@ def get_ruleset_content(ruleset: capa.rules.RuleSet) -> List[bytes]:

def compute_ruleset_cache_identifier(ruleset: capa.rules.RuleSet) -> CacheIdentifier:
    """Compute the cache identifier for the given ruleset.

    In a development environment (see is_dev_environment), the identifier also
    folds in the content of modified tracked .py source files and the current
    git commit hash, so the ruleset cache is regenerated whenever capa's own
    source changes — not only when the rules change.

    Outside a dev environment (or when no git information is available),
    falls back to the plain rule-content identifier.
    """
    rule_contents = get_ruleset_content(ruleset)

    if is_dev_environment():
        modified_files = get_modified_files()
        commit_hash = get_git_commit_hash()

        if modified_files or commit_hash:
            # use a distinct name so we don't shadow the builtin `hash`
            digest = hashlib.sha256()
            digest.update(capa.version.__version__.encode("utf-8"))
            digest.update(b"\x00")

            for file in modified_files:
                # NOTE(review): git reports paths relative to the repository
                # root; if the CWD differs these reads may fail — TODO confirm.
                try:
                    file_content = Path(file).read_bytes()
                    logger.debug("found modified source py %s", file)
                    digest.update(file_content)
                    digest.update(b"\x00")
                except FileNotFoundError as e:
                    logger.error("modified file not found: %s (%s)", file, e)

            if commit_hash:
                digest.update(commit_hash.encode("ascii"))
                digest.update(b"\x00")

            # include the hash of the rule contents; sorting makes the
            # identifier independent of rule enumeration order
            rule_hashes = sorted(hashlib.sha256(buf).hexdigest() for buf in rule_contents)
            for rule_hash in rule_hashes:
                digest.update(rule_hash.encode("ascii"))
                digest.update(b"\x00")

            logger.debug(
                "developer environment detected, ruleset cache will be auto-generated upon each source modification"
            )
            return digest.hexdigest()

    return compute_cache_identifier(rule_contents)


Expand Down

0 comments on commit cc3d208

Please sign in to comment.