diff --git a/capa/rules/__init__.py b/capa/rules/__init__.py index 7feac79e8..313011bb0 100644 --- a/capa/rules/__init__.py +++ b/capa/rules/__init__.py @@ -16,7 +16,6 @@ import binascii import collections from enum import Enum -from typing import Literal from pathlib import Path from capa.helpers import assert_never @@ -197,17 +196,22 @@ def __repr__(self): return str(self) +class ComType(Enum): + CLASS = "class" + INTERFACE = "interface" + + # COM data source https://github.com/stevemk14ebr/COM-Code-Helper/tree/master VALID_COM_TYPES = { - "class": {"db_path": "assets/classes.json.gz", "prefix": "CLSID_"}, - "interface": {"db_path": "assets/interfaces.json.gz", "prefix": "IID_"}, + ComType.CLASS: {"db_path": "assets/classes.json.gz", "prefix": "CLSID_"}, + ComType.INTERFACE: {"db_path": "assets/interfaces.json.gz", "prefix": "IID_"}, } -com_db_cache: Dict[str, Dict[str, List[str]]] = {} +com_db_cache: Dict[ComType, Dict[str, List[str]]] = {} -def load_com_database(com_type: Literal["class", "interface"]) -> Dict[str, List[str]]: - com_db_path = capa.main.get_default_root() / VALID_COM_TYPES[com_type]["db_path"] +def load_com_database(com_type: ComType) -> Dict[str, List[str]]: + com_db_path: Path = capa.main.get_default_root() / VALID_COM_TYPES[com_type]["db_path"] if com_type in com_db_cache: # If the com database is already in the cache, return it @@ -225,7 +229,7 @@ def load_com_database(com_type: Literal["class", "interface"]) -> Dict[str, List raise IOError(f"Error loading COM database from '{com_db_path}'") from e -def translate_com_feature(com_name: str, com_type: Literal["class", "interface"]) -> ceng.Or: +def translate_com_feature(com_name: str, com_type: ComType) -> ceng.Or: if com_type not in VALID_COM_TYPES: raise InvalidRule(f"Invalid COM type present {com_type}") @@ -662,11 +666,11 @@ def build_statements(d, scope: str): return feature elif key.startswith("com/"): - com_type = key[len("com/") :] - if com_type not in VALID_COM_TYPES: + com_type = str(key[len("com/") :]).upper() + if com_type not in [item.name for item in ComType]: raise InvalidRule(f"unexpected COM type: {com_type}") value, description = parse_description(d[key], key, d.get("description")) - return translate_com_feature(value, com_type) + return translate_com_feature(value, ComType[com_type]) else: Feature = parse_feature(key)