Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize extract strings from raw bytes #84

Merged
merged 11 commits into from
Nov 8, 2024
34 changes: 26 additions & 8 deletions dfint64_patch/cross_references/cross_references_relative.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,27 @@
REFERENCE_SIZE = 4


def find_relative_cross_references_loop(
bytes_block: bytes, base_address: Rva, addresses: Iterable[int]
) -> Iterator[tuple[int, int]]:
"""
Analyse a block of bytes and try to find relative cross-references to the given objects' addresses.
Optimized hot loop, don't add extra stuff to the loop (like conversion to Rva etc.)

:param bytes_block: bytes block to analyse
:param base_address: base address of the given block (preferably this should be of some type with fast "in" check,
like set, dict, range or short tuple)
:param addresses: an iterable of destination addresses
:return: pairs of destinations and source addresses
"""
for i in range(len(bytes_block) - REFERENCE_SIZE + 1):
relative_offset = int.from_bytes(bytes_block[i : i + REFERENCE_SIZE], byteorder="little", signed=True)
destination = base_address + i + REFERENCE_SIZE + relative_offset

if destination in addresses:
yield destination, base_address + i


def find_relative_cross_references(
bytes_block: bytes,
base_address: Rva,
Expand All @@ -22,19 +43,16 @@ def find_relative_cross_references(
(e.g. `range(0x11000, 0x12000)`) or dict object.
:return: Mapping[object_rva: Rva, cross_references: List[Rva]]
"""
view = memoryview(bytes_block)
result = defaultdict(list)

if not isinstance(addresses, range | dict):
addresses = set(addresses)

for i in tqdm(range(len(bytes_block) - REFERENCE_SIZE + 1), desc="find_relative_cross_references"):
relative_offset = int.from_bytes(bytes(view[i : i + REFERENCE_SIZE]), byteorder="little", signed=True)

destination = Rva(base_address + i + REFERENCE_SIZE + relative_offset)

if destination in addresses:
result[destination].append(Rva(base_address + i))
for destination, source in tqdm(
find_relative_cross_references_loop(bytes_block, base_address, addresses),
desc="find_relative_cross_references",
):
result[destination].append(source)

return result

Expand Down
68 changes: 24 additions & 44 deletions dfint64_patch/extract_strings/from_raw_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,48 +10,23 @@

from dfint64_patch.type_aliases import RVA0, Rva

forbidden: set[int] = set(b"$^@")
allowed: set[int] = set()
forbidden: str = "$^@"

ASCII_MAX_CODE = 127
ASCII_MAX_CHAR = chr(127)


def is_allowed(x: int) -> bool:
return x in allowed or (ord(" ") <= x <= ASCII_MAX_CODE and x not in forbidden)
def is_allowed(c: str) -> bool:
return " " <= c <= ASCII_MAX_CHAR and c not in forbidden


def possible_to_decode(c: bytes, encoding: str) -> bool:
try:
c.decode(encoding=encoding)
except UnicodeDecodeError:
return False
else:
return True


def check_string(buf: bytes | memoryview, encoding: str) -> tuple[int, int]:
def check_string(buf: str) -> bool:
"""
Try to decode bytes as a string in the given encoding
Check that the buffer contain letters and doesn't contain forbidden characters

:param buf: byte buffer
:param encoding: string encoding
:return: (string_length: int, number_of_letters: int)
:return: number_of_letters: int
"""

string_length = 0
number_of_letters = 0
for i, c in enumerate(buf):
if c == 0:
string_length = i
break

current_byte = bytes(buf[i : i + 1])
if not is_allowed(c) or not possible_to_decode(current_byte, encoding):
break

if current_byte.isalpha():
number_of_letters += 1

return string_length, number_of_letters
return any(c.isalpha() for c in buf) and all(is_allowed(c) for c in buf)


class ExtractedStringInfo(NamedTuple):
Expand All @@ -73,16 +48,21 @@ def extract_strings_from_raw_bytes(
:param encoding: string encoding
:return: Iterator[ExtractedStringInfo]
"""
view = memoryview(bytes_block)

i = 0
while i < len(view):
buffer_part = view[i:]
string_len, letters = check_string(buffer_part, encoding)
if string_len and letters:
string = bytes(view[i : i + string_len]).decode(encoding)
yield ExtractedStringInfo(Rva(base_address + i), string)
i += (string_len // alignment + 1) * alignment
while i < len(bytes_block):
if bytes_block[i] == b"\0":
i += alignment
continue

i += alignment
end_index = bytes_block.index(b"\0", i)
buffer_part = bytes_block[i:end_index]

try:
string = bytes(buffer_part).decode(encoding)
if check_string(string):
yield ExtractedStringInfo(Rva(base_address + i), string)
except UnicodeDecodeError:
pass

string_len = end_index - i
i += (string_len // alignment + 1) * alignment
11 changes: 6 additions & 5 deletions tests/test_extract_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@


@pytest.mark.parametrize(
("test_data", "encoding", "expected"),
("test_data", "expected"),
[
(b"12345\0", "cp437", (5, 0)),
(b"12345\xff\0", "utf-8", (0, 0)),
("12345", False),
("12345\xff", False),
("1234abc5", True),
],
)
def test_check_string(test_data: bytes, encoding: str, expected: tuple[int, int]):
assert check_string(test_data, encoding) == expected
def test_check_string(test_data: str, expected: bool):
assert check_string(test_data) == expected


@pytest.mark.parametrize(
Expand Down