diff --git a/dfint64_patch/cross_references/cross_references_relative.py b/dfint64_patch/cross_references/cross_references_relative.py index b4cd588..0f5d1ba 100644 --- a/dfint64_patch/cross_references/cross_references_relative.py +++ b/dfint64_patch/cross_references/cross_references_relative.py @@ -9,6 +9,27 @@ REFERENCE_SIZE = 4 +def find_relative_cross_references_loop( + bytes_block: bytes, base_address: Rva, addresses: Iterable[int] +) -> Iterator[tuple[int, int]]: + """ + Analyse a block of bytes and try to find relative cross-references to the given objects' addresses. + Optimized hot loop, don't add extra stuff to the loop (like conversion to Rva etc.) + + :param bytes_block: bytes block to analyse + :param base_address: base address of the given block (preferably this should be of some type with fast "in" check, + like set, dict, range or short tuple) + :param addresses: an iterable of destination addresses + :return: pairs of destinations and source addresses + """ + for i in range(len(bytes_block) - REFERENCE_SIZE + 1): + relative_offset = int.from_bytes(bytes_block[i : i + REFERENCE_SIZE], byteorder="little", signed=True) + destination = base_address + i + REFERENCE_SIZE + relative_offset + + if destination in addresses: + yield destination, base_address + i + + def find_relative_cross_references( bytes_block: bytes, base_address: Rva, @@ -22,19 +43,16 @@ def find_relative_cross_references( (e.g. `range(0x11000, 0x12000)`) or dict object. :return: Mapping[object_rva: Rva, cross_references: List[Rva]] """ - view = memoryview(bytes_block) result = defaultdict(list) if not isinstance(addresses, range | dict): addresses = set(addresses) - for i in tqdm(range(len(bytes_block) - REFERENCE_SIZE + 1), desc="find_relative_cross_references"): - relative_offset = int.from_bytes(bytes(view[i : i + REFERENCE_SIZE]), byteorder="little", signed=True) - - destination = Rva(base_address + i + REFERENCE_SIZE + relative_offset) - - if destination in addresses: - result[destination].append(Rva(base_address + i)) + for destination, source in tqdm( + find_relative_cross_references_loop(bytes_block, base_address, addresses), + desc="find_relative_cross_references", + ): + result[destination].append(source) return result diff --git a/dfint64_patch/extract_strings/from_raw_bytes.py b/dfint64_patch/extract_strings/from_raw_bytes.py index e143ed9..bf81ea9 100644 --- a/dfint64_patch/extract_strings/from_raw_bytes.py +++ b/dfint64_patch/extract_strings/from_raw_bytes.py @@ -10,48 +10,23 @@ from dfint64_patch.type_aliases import RVA0, Rva -forbidden: set[int] = set(b"$^@") -allowed: set[int] = set() +forbidden: str = "$^@" -ASCII_MAX_CODE = 127 +ASCII_MAX_CHAR = chr(127) -def is_allowed(x: int) -> bool: - return x in allowed or (ord(" ") <= x <= ASCII_MAX_CODE and x not in forbidden) +def is_allowed(c: str) -> bool: + return " " <= c <= ASCII_MAX_CHAR and c not in forbidden -def possible_to_decode(c: bytes, encoding: str) -> bool: - try: - c.decode(encoding=encoding) - except UnicodeDecodeError: - return False - else: - return True - - -def check_string(buf: bytes | memoryview, encoding: str) -> tuple[int, int]: +def check_string(buf: str) -> bool: """ - Try to decode bytes as a string in the given encoding + Check that the buffer contain letters and doesn't contain forbidden characters + :param buf: byte buffer - :param encoding: string encoding - :return: (string_length: int, number_of_letters: int) + :return: number_of_letters: int """ - - string_length = 0 - number_of_letters = 0 - for i, c in enumerate(buf): - if c == 0: - string_length = i - break - - current_byte = bytes(buf[i : i + 1]) - if not is_allowed(c) or not possible_to_decode(current_byte, encoding): - break - - if current_byte.isalpha(): - number_of_letters += 1 - - return string_length, number_of_letters + return any(c.isalpha() for c in buf) and all(is_allowed(c) for c in buf) class ExtractedStringInfo(NamedTuple): @@ -73,16 +48,21 @@ def extract_strings_from_raw_bytes( :param encoding: string encoding :return: Iterator[ExtractedStringInfo] """ - view = memoryview(bytes_block) - i = 0 - while i < len(view): - buffer_part = view[i:] - string_len, letters = check_string(buffer_part, encoding) - if string_len and letters: - string = bytes(view[i : i + string_len]).decode(encoding) - yield ExtractedStringInfo(Rva(base_address + i), string) - i += (string_len // alignment + 1) * alignment + while i < len(bytes_block): + if bytes_block[i] == b"\0": + i += alignment continue - i += alignment + end_index = bytes_block.index(b"\0", i) + buffer_part = bytes_block[i:end_index] + + try: + string = bytes(buffer_part).decode(encoding) + if check_string(string): + yield ExtractedStringInfo(Rva(base_address + i), string) + except UnicodeDecodeError: + pass + + string_len = end_index - i + i += (string_len // alignment + 1) * alignment diff --git a/tests/test_extract_strings.py b/tests/test_extract_strings.py index 5633e79..1636dba 100644 --- a/tests/test_extract_strings.py +++ b/tests/test_extract_strings.py @@ -19,14 +19,15 @@ @pytest.mark.parametrize( - ("test_data", "encoding", "expected"), + ("test_data", "expected"), [ - (b"12345\0", "cp437", (5, 0)), - (b"12345\xff\0", "utf-8", (0, 0)), + ("12345", False), + ("12345\xff", False), + ("1234abc5", True), ], ) -def test_check_string(test_data: bytes, encoding: str, expected: tuple[int, int]): - assert check_string(test_data, encoding) == expected +def test_check_string(test_data: str, expected: bool): + assert check_string(test_data) == expected @pytest.mark.parametrize(