From da7e84014f575e3624a7435d44ff5bfdc895421e Mon Sep 17 00:00:00 2001 From: 4shen0ne <4shen.01@gmail.com> Date: Sat, 12 Oct 2024 15:29:49 +0800 Subject: [PATCH] add type annotations --- lib/connection/dns.py | 9 ++++--- lib/connection/requester.py | 42 +++++++++++++++-------------- lib/connection/response.py | 31 ++++++++++++--------- lib/controller/controller.py | 52 +++++++++++++++++++----------------- lib/core/data.py | 9 +++++-- lib/core/decorators.py | 20 +++++++++----- lib/core/dictionary.py | 27 ++++++++++--------- lib/core/fuzzer.py | 52 +++++++++++++++++++----------------- lib/core/installation.py | 8 +++--- lib/core/logger.py | 2 +- lib/core/options.py | 14 ++++++---- lib/core/scanner.py | 27 ++++++++++--------- lib/core/structures.py | 37 ++++++++++++++----------- lib/parse/cmdline.py | 4 +-- lib/parse/config.py | 42 +++++++++++++++++++++++++---- lib/parse/headers.py | 12 +++++---- lib/parse/nmap.py | 4 ++- lib/parse/rawrequest.py | 4 ++- lib/parse/url.py | 4 +-- lib/utils/common.py | 4 +-- lib/utils/diff.py | 2 +- lib/utils/file.py | 6 +++-- lib/utils/mimetype.py | 3 ++- 23 files changed, 251 insertions(+), 164 deletions(-) diff --git a/lib/connection/dns.py b/lib/connection/dns.py index b44229f4c..1b525535d 100755 --- a/lib/connection/dns.py +++ b/lib/connection/dns.py @@ -16,16 +16,19 @@ # # Author: Mauro Soria +from __future__ import annotations + from socket import getaddrinfo +from typing import Any -_dns_cache = {} +_dns_cache: dict[tuple[str, int], list[Any]] = {} -def cache_dns(domain, port, addr): +def cache_dns(domain: str, port: int, addr: str) -> None: _dns_cache[domain, port] = getaddrinfo(addr, port) -def cached_getaddrinfo(*args, **kwargs): +def cached_getaddrinfo(*args: Any, **kwargs: int) -> list[Any]: """ Replacement for socket.getaddrinfo, they are the same but this function does cache the answer to improve the performance diff --git a/lib/connection/requester.py b/lib/connection/requester.py index c1fb386ce..2e4ade6ad 100755 --- a/lib/connection/requester.py +++ b/lib/connection/requester.py @@ -16,6 +16,8 @@ # # Author: Mauro Soria +from __future__ import annotations + import asyncio import http.client import random @@ -24,7 +26,7 @@ from ssl import SSLError import threading import time -from typing import Generator, Optional +from typing import Any, Generator from urllib.parse import urlparse import httpx @@ -59,12 +61,12 @@ class BaseRequester: - def __init__(self): - self._url = None - self._proxy_cred = None + def __init__(self) -> None: + self._url: str = "" + self._proxy_cred: str = "" self._rate = 0 self.headers = CaseInsensitiveDict(options["headers"]) - self.agents = [] + self.agents: list[str] = [] self.session = None self._cert = None @@ -117,27 +119,27 @@ def set_proxy(self, proxy: str) -> None: def set_proxy_auth(self, credential: str) -> None: self._proxy_cred = credential - def is_rate_exceeded(self): + def is_rate_exceeded(self) -> bool: return self._rate >= options["max_rate"] > 0 - def decrease_rate(self): + def decrease_rate(self) -> None: self._rate -= 1 - def increase_rate(self): + def increase_rate(self) -> None: self._rate += 1 threading.Timer(1, self.decrease_rate).start() @property @cached(RATE_UPDATE_DELAY) - def rate(self): + def rate(self) -> int: return self._rate class HTTPBearerAuth(AuthBase): - def __init__(self, token): + def __init__(self, token: str) -> None: self.token = token - def __call__(self, request): + def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: request.headers["Authorization"] = f"Bearer {self.token}" return request @@ -160,7 +162,7 @@ def __init__(self): ), ) - def set_auth(self, type, credential): + def set_auth(self, type: str, credential: str) -> None: if type in ("bearer", "jwt"): self.session.auth = HTTPBearerAuth(credential) else: @@ -178,7 +180,7 @@ def set_auth(self, type, credential): self.session.auth = HttpNtlmAuth(user, password) # :path: is expected not to start with "/" - def request(self, path, proxy=None): + def request(self, path: str, proxy: str | None = None) -> Response: # Pause if the request rate exceeded the maximum while self.is_rate_exceeded(): time.sleep(0.1) @@ -213,13 +215,13 @@ def request(self, path, proxy=None): prepped = self.session.prepare_request(request) prepped.url = url - response = self.session.send( + origin_response = self.session.send( prepped, allow_redirects=options["follow_redirects"], timeout=options["timeout"], stream=True, ) - response = Response(response) + response = Response(origin_response) log_msg = f'"{options["http_method"]} {response.url}" {response.status} - {response.length}B' @@ -270,13 +272,13 @@ class HTTPXBearerAuth(httpx.Auth): def __init__(self, token: str) -> None: self.token = token - def auth_flow(self, request: httpx.Request) -> Generator: + def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, None, None]: request.headers["Authorization"] = f"Bearer {self.token}" yield request class ProxyRoatingTransport(httpx.AsyncBaseTransport): - def __init__(self, proxies, **kwargs) -> None: + def __init__(self, proxies: list[str], **kwargs: Any) -> None: self._transports = [ httpx.AsyncHTTPTransport(proxy=proxy, **kwargs) for proxy in proxies ] @@ -287,7 +289,7 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response: class AsyncRequester(BaseRequester): - def __init__(self): + def __init__(self) -> None: super().__init__() tpargs = { @@ -340,7 +342,7 @@ def set_auth(self, type: str, credential: str) -> None: else: self.session.auth = HttpxNtlmAuth(user, password) - async def replay_request(self, path: str, proxy: str): + async def replay_request(self, path: str, proxy: str) -> AsyncResponse: if self.replay_session is None: transport = httpx.AsyncHTTPTransport( verify=False, @@ -357,7 +359,7 @@ async def replay_request(self, path: str, proxy: str): # :path: is expected not to start with "/" async def request( - self, path: str, session: Optional[httpx.AsyncClient] = None + self, path: str, session: httpx.AsyncClient | None = None ) -> AsyncResponse: while self.is_rate_exceeded(): await asyncio.sleep(0.1) diff --git a/lib/connection/response.py b/lib/connection/response.py index b9a56fbc9..4fce969f6 100755 --- a/lib/connection/response.py +++ b/lib/connection/response.py @@ -16,7 +16,12 @@ # # Author: Mauro Soria +from __future__ import annotations + +from typing import Any + import httpx +import requests from lib.core.settings import ( DEFAULT_ENCODING, @@ -29,7 +34,7 @@ class BaseResponse: - def __init__(self, response): + def __init__(self, response: requests.Response | httpx.Response) -> None: self.url = str(response.url) self.full_path = parse_path(self.url) self.path = clean_path(self.full_path) @@ -41,23 +46,23 @@ def __init__(self, response): self.body = b"" @property - def type(self): - if "content-type" in self.headers: - return self.headers.get("content-type").split(";")[0] + def type(self) -> str: + if ct := self.headers.get("content-type"): + return ct.split(";")[0] return UNKNOWN @property - def length(self): - try: - return int(self.headers.get("content-length")) - except TypeError: - return len(self.body) + def length(self) -> int: + if cl := self.headers.get("content-length"): + return int(cl) + + return len(self.body) - def __hash__(self): + def __hash__(self) -> int: return hash(self.body) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return (self.status, self.body, self.redirect) == ( other.status, other.body, @@ -66,7 +71,7 @@ def __eq__(self, other): class Response(BaseResponse): - def __init__(self, response): + def __init__(self, response: requests.Response) -> None: super().__init__(response) for chunk in response.iter_content(chunk_size=ITER_CHUNK_SIZE): @@ -88,7 +93,7 @@ def __init__(self, response): class AsyncResponse(BaseResponse): @classmethod - async def create(cls, response: httpx.Response) -> "AsyncResponse": + async def create(cls, response: httpx.Response) -> AsyncResponse: self = cls(response) async for chunk in response.aiter_bytes(chunk_size=ITER_CHUNK_SIZE): self.body += chunk diff --git a/lib/controller/controller.py b/lib/controller/controller.py index 8abae1871..49da81523 100755 --- a/lib/controller/controller.py +++ b/lib/controller/controller.py @@ -16,6 +16,8 @@ # # Author: Mauro Soria +from __future__ import annotations + import asyncio import gc import os @@ -28,6 +30,7 @@ from urllib.parse import urlparse from lib.connection.dns import cache_dns +from lib.connection.response import Response from lib.core.data import blacklists, options from lib.core.decorators import locked from lib.core.dictionary import Dictionary, get_blacklists @@ -53,6 +56,7 @@ ) from lib.parse.rawrequest import parse_raw from lib.parse.url import clean_path, parse_path +from lib.reports.base import FileBaseReport, SQLBaseReport from lib.reports.csv_report import CSVReport from lib.reports.html_report import HTMLReport from lib.reports.json_report import JSONReport @@ -84,7 +88,7 @@ class Controller: - def __init__(self): + def __init__(self) -> None: if options["session_file"]: self._import(options["session_file"]) self.old_session = True @@ -94,7 +98,7 @@ def __init__(self): self.run() - def _import(self, session_file): + def _import(self, session_file: str) -> None: try: with open(session_file, "rb") as fd: indict, last_output, opt = unpickle(fd) @@ -108,7 +112,7 @@ def _import(self, session_file): self.__dict__ = {**indict, **vars(self)} print(last_output) - def _export(self, session_file): + def _export(self, session_file: str) -> None: # Save written output last_output = interface.buffer.rstrip() @@ -118,7 +122,7 @@ def _export(self, session_file): with open(session_file, "wb") as fd: pickle((vars(self), last_output, options), fd) - def setup(self): + def setup(self) -> None: blacklists.update(get_blacklists()) if options["raw_file"]: @@ -143,11 +147,11 @@ def setup(self): self.requester = Requester() self.dictionary = Dictionary(files=options["wordlists"]) - self.results = [] + self.results: list[Response] = [] self.start_time = time.time() - self.passed_urls = set() - self.directories = [] - self.report = None + self.passed_urls: set[str] = set() + self.directories: list[str] = [] + self.report: FileBaseReport | SQLBaseReport | None = None self.batch = False self.jobs_processed = 0 self.errors = 0 @@ -212,7 +216,7 @@ def setup(self): if options["log_file"]: interface.log_file(options["log_file"]) - def run(self): + def run(self) -> None: # match_callbacks and not_found_callbacks callback values: # - *args[0]: lib.connection.Response() object # @@ -275,7 +279,7 @@ def run(self): except Exception: interface.error("Failed to delete old session file, remove it to free some space") - def start(self): + def start(self) -> None: while self.directories: try: gc.collect() @@ -385,7 +389,7 @@ def set_target(self, url): self.requester.set_url(self.url) - def setup_batch_reports(self): + def setup_batch_reports(self) -> str: """Create batch report folder""" self.batch = True @@ -401,13 +405,13 @@ def setup_batch_reports(self): return batch_directory_path - def get_output_extension(self): + def get_output_extension(self) -> str: if options["output_format"] in ("plain", "simple"): return "txt" return options["output_format"] - def setup_reports(self): + def setup_reports(self) -> None: """Create report file""" output = options["output"] @@ -474,10 +478,10 @@ def setup_reports(self): interface.output_location(output) - def reset_consecutive_errors(self, response): + def reset_consecutive_errors(self, response: Response) -> None: self.consecutive_errors = 0 - def match_callback(self, response): + def match_callback(self, response: Response) -> None: if response.status in options["skip_on_status"]: raise SkipTargetInterrupt( f"Skipped the target due to {response.status} status code" @@ -515,7 +519,7 @@ def match_callback(self, response): self.results.append(response) self.report.save(self.results) - def update_progress_bar(self, response): + def update_progress_bar(self, response: Response) -> None: jobs_count = ( # Jobs left for unscanned targets len(options["subdirs"]) * (len(options["urls"]) - 1) @@ -534,7 +538,7 @@ def update_progress_bar(self, response): self.errors, ) - def raise_error(self, exception): + def raise_error(self, exception: RequestException) -> None: if options["exit_on_error"]: raise QuitInterrupt("Canceled due to an error") @@ -544,10 +548,10 @@ def raise_error(self, exception): if self.consecutive_errors > MAX_CONSECUTIVE_REQUEST_ERRORS: raise SkipTargetInterrupt("Too many request errors") - def append_error_log(self, exception): + def append_error_log(self, exception: RequestException) -> None: logger.exception(exception) - def handle_pause(self): + def handle_pause(self) -> None: interface.warning( "CTRL+C detected: Pausing threads, please wait...", do_save=False ) @@ -619,10 +623,10 @@ def handle_pause(self): else: raise skipexc - def is_timed_out(self): + def is_timed_out(self) -> bool: return time.time() - self.start_time > options["max_time"] > 0 - def process(self): + def process(self) -> None: while True: try: while not self.fuzzer.is_finished(): @@ -638,7 +642,7 @@ def process(self): time.sleep(0.3) - def add_directory(self, path): + def add_directory(self, path: str) -> None: """Add directory to the recursion queue""" # Pass if path is in exclusive directories @@ -660,7 +664,7 @@ def add_directory(self, path): self.passed_urls.add(url) @locked - def recur(self, path): + def recur(self, path: str) -> list[str]: dirs_count = len(self.directories) path = clean_path(path) @@ -682,7 +686,7 @@ def recur(self, path): # Return newly added directories return self.directories[dirs_count:] - def recur_for_redirect(self, path, redirect_path): + def recur_for_redirect(self, path: str, redirect_path: str) -> list[str]: if redirect_path == path + "/": return self.recur(redirect_path) diff --git a/lib/core/data.py b/lib/core/data.py index 9eff02734..647766cb5 100755 --- a/lib/core/data.py +++ b/lib/core/data.py @@ -16,8 +16,13 @@ # # Author: Mauro Soria -blacklists = {} -options = { +from __future__ import annotations + +from typing import Any + +# we can't import `Dictionary` due to a circular import +blacklists: dict[int, Any] = {} +options: dict[str, Any] = { "urls": [], "urls_file": None, "stdin_urls": None, diff --git a/lib/core/decorators.py b/lib/core/decorators.py index 1f5c31e13..f56de493f 100755 --- a/lib/core/decorators.py +++ b/lib/core/decorators.py @@ -16,20 +16,28 @@ # # Author: Mauro Soria +from __future__ import annotations + import threading from functools import wraps from time import time +from typing import Any, Callable, TypeVar +from typing_extensions import ParamSpec _lock = threading.Lock() -_cache = {} +_cache: dict[int, tuple[float, Any]] = {} _cache_lock = threading.Lock() +# https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators +P = ParamSpec("P") +T = TypeVar("T") + -def cached(timeout=100): - def _cached(func): +def cached(timeout: int | float = 100) -> Callable[..., Any]: + def _cached(func: Callable[P, T]) -> Callable[P, T]: @wraps(func) - def with_caching(*args, **kwargs): + def with_caching(*args: P.args, **kwargs: P.kwargs) -> T: key = id(func) for arg in args: key += id(arg) @@ -51,8 +59,8 @@ def with_caching(*args, **kwargs): return _cached -def locked(func): - def with_locking(*args, **kwargs): +def locked(func: Callable[P, T]) -> Callable[P, T]: + def with_locking(*args: P.args, **kwargs: P.kwargs) -> T: with _lock: return func(*args, **kwargs) diff --git a/lib/core/dictionary.py b/lib/core/dictionary.py index 5d6f47098..1424de298 100755 --- a/lib/core/dictionary.py +++ b/lib/core/dictionary.py @@ -16,7 +16,10 @@ # # Author: Mauro Soria +from __future__ import annotations + import re +from typing import Any, Iterator from lib.core.data import options from lib.core.decorators import locked @@ -34,7 +37,7 @@ # Get ignore paths for status codes. # Reference: https://github.com/maurosoria/dirsearch#Blacklist -def get_blacklists(): +def get_blacklists() -> dict[int, Dictionary]: blacklists = {} for status in [400, 403, 500]: @@ -56,16 +59,16 @@ def get_blacklists(): class Dictionary: - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: self._index = 0 self._items = self.generate(**kwargs) @property - def index(self): + def index(self) -> int: return self._index @locked - def __next__(self): + def __next__(self) -> str: try: path = self._items[self._index] except IndexError: @@ -75,22 +78,22 @@ def __next__(self): return path - def __contains__(self, item): + def __contains__(self, item: str) -> bool: return item in self._items - def __getstate__(self): + def __getstate__(self) -> tuple[list[str], int]: return self._items, self._index - def __setstate__(self, state): + def __setstate__(self, state: tuple[list[str], int]) -> None: self._items, self._index = state - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._items) - def __len__(self): + def __len__(self) -> int: return len(self._items) - def generate(self, files=[], is_blacklist=False): + def generate(self, files: list[str] = [], is_blacklist: bool = False) -> list[str]: """ Dictionary.generate() behaviour @@ -192,7 +195,7 @@ def generate(self, files=[], is_blacklist=False): else: return list(wordlist) - def is_valid(self, path): + def is_valid(self, path: str) -> bool: # Skip comments and empty lines if not path or path.startswith("#"): return False @@ -206,5 +209,5 @@ def is_valid(self, path): return True - def reset(self): + def reset(self) -> None: self._index = 0 diff --git a/lib/core/fuzzer.py b/lib/core/fuzzer.py index 29a1519fe..9a5055006 100755 --- a/lib/core/fuzzer.py +++ b/lib/core/fuzzer.py @@ -16,13 +16,15 @@ # # Author: Mauro Soria +from __future__ import annotations + import asyncio import re import threading import time -from typing import Callable, Generator, Tuple +from typing import Any, Callable, Generator -from lib.connection.requester import BaseRequester +from lib.connection.requester import AsyncRequester, BaseRequester, Requester from lib.connection.response import BaseResponse from lib.core.data import blacklists, options from lib.core.dictionary import Dictionary @@ -45,20 +47,20 @@ def __init__( requester: BaseRequester, dictionary: Dictionary, *, - match_callbacks: Tuple[Callable] = (), - not_found_callbacks: Tuple[Callable] = (), - error_callbacks: Tuple[Callable] = (), + match_callbacks: tuple[Callable[[BaseResponse], Any], ...], + not_found_callbacks: tuple[Callable[[BaseResponse], Any], ...], + error_callbacks: tuple[Callable[[RequestException], Any], ...], ) -> None: self._scanned = set() self._requester = requester self._dictionary = dictionary - self._base_path = None - self.exc = None + self._base_path: str = "" + self.exc: Exception | None = None self.match_callbacks = match_callbacks self.not_found_callbacks = not_found_callbacks self.error_callbacks = error_callbacks - self.scanners = { + self.scanners: dict[str, dict[str, Scanner]] = { "default": {}, "prefixes": {}, "suffixes": {}, @@ -130,12 +132,12 @@ def is_excluded(resp: BaseResponse) -> bool: class Fuzzer(BaseFuzzer): def __init__( self, - requester: BaseRequester, + requester: Requester, dictionary: Dictionary, *, - match_callbacks: Tuple[Callable] = (), - not_found_callbacks: Tuple[Callable] = (), - error_callbacks: Tuple[Callable] = (), + match_callbacks: tuple[Callable[[BaseResponse], Any], ...], + not_found_callbacks: tuple[Callable[[BaseResponse], Any], ...], + error_callbacks: tuple[Callable[[RequestException], Any], ...], ) -> None: super().__init__( requester, @@ -149,7 +151,7 @@ def __init__( self._quit_event = threading.Event() self._pause_semaphore = threading.Semaphore(0) - def setup_scanners(self): + def setup_scanners(self) -> None: # Default scanners (wildcard testers) self.scanners["default"].update( { @@ -190,7 +192,7 @@ def setup_scanners(self): context=f"/{self._base_path}***.{extension}", ) - def setup_threads(self): + def setup_threads(self) -> None: if self._threads: self._threads = [] @@ -199,7 +201,7 @@ def setup_threads(self): new_thread.daemon = True self._threads.append(new_thread) - def start(self): + def start(self) -> None: self.setup_scanners() self.setup_threads() self.play() @@ -207,7 +209,7 @@ def start(self): for thread in self._threads: thread.start() - def is_finished(self): + def is_finished(self) -> bool: if self.exc: raise self.exc @@ -217,21 +219,21 @@ def is_finished(self): return True - def play(self): + def play(self) -> None: self._play_event.set() - def pause(self): + def pause(self) -> None: self._play_event.clear() # Wait for all threads to stop for thread in self._threads: if thread.is_alive(): self._pause_semaphore.acquire() - def quit(self): + def quit(self) -> None: self._quit_event.set() self.play() - def scan(self, path, scanners): + def scan(self, path: str, scanners: Generator[Scanner, None, None]) -> None: # Avoid scanned paths from being re-scanned if path in self._scanned: return @@ -267,7 +269,7 @@ def scan(self, path, scanners): ) self.scan(path_, self.get_scanners_for(path_)) - def thread_proc(self): + def thread_proc(self) -> None: logger.info(f'THREAD-{threading.get_ident()} started"') while True: @@ -301,12 +303,12 @@ def thread_proc(self): class AsyncFuzzer(BaseFuzzer): def __init__( self, - requester: BaseRequester, + requester: AsyncRequester, dictionary: Dictionary, *, - match_callbacks: Tuple[Callable] = (), - not_found_callbacks: Tuple[Callable] = (), - error_callbacks: Tuple[Callable] = (), + match_callbacks: tuple[Callable[[BaseResponse], Any], ...], + not_found_callbacks: tuple[Callable[[BaseResponse], Any], ...], + error_callbacks: tuple[Callable[[RequestException], Any], ...], ) -> None: super().__init__( requester, diff --git a/lib/core/installation.py b/lib/core/installation.py index a1d80be81..a72a57947 100755 --- a/lib/core/installation.py +++ b/lib/core/installation.py @@ -17,6 +17,8 @@ # # Author: Mauro Soria +from __future__ import annotations + import subprocess import sys import pkg_resources @@ -28,7 +30,7 @@ REQUIREMENTS_FILE = f"{SCRIPT_PATH}/requirements.txt" -def get_dependencies(): +def get_dependencies() -> list[str]: try: return FileUtils.get_lines(REQUIREMENTS_FILE) except FileNotFoundError: @@ -37,11 +39,11 @@ def get_dependencies(): # Check if all dependencies are satisfied -def check_dependencies(): +def check_dependencies() -> None: pkg_resources.require(get_dependencies()) -def install_dependencies(): +def install_dependencies() -> None: try: subprocess.check_output( [sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_FILE], diff --git a/lib/core/logger.py b/lib/core/logger.py index cbd2872d3..ce39e5556 100755 --- a/lib/core/logger.py +++ b/lib/core/logger.py @@ -27,7 +27,7 @@ logger.disabled = True -def enable_logging(): +def enable_logging() -> None: logger.disabled = False formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s') handler = RotatingFileHandler(options["log_file"], maxBytes=options["log_file_size"]) diff --git a/lib/core/options.py b/lib/core/options.py index 04a573baa..8882589b6 100755 --- a/lib/core/options.py +++ b/lib/core/options.py @@ -16,6 +16,10 @@ # # Author: Mauro Soria +from __future__ import annotations + +from optparse import Values +from typing import Any from lib.core.settings import ( AUTHENTICATION_TYPES, COMMON_EXTENSIONS, @@ -31,7 +35,7 @@ from lib.parse.nmap import parse_nmap -def parse_options(): +def parse_options() -> dict[str, Any]: opt = parse_config(parse_arguments()) if opt.session_file: @@ -206,11 +210,11 @@ def parse_options(): return vars(opt) -def _parse_status_codes(str_): +def _parse_status_codes(str_: str) -> set[int]: if not str_: return set() - status_codes = set() + status_codes: set[int] = set() for status_code in str_.split(","): try: @@ -226,7 +230,7 @@ def _parse_status_codes(str_): return status_codes -def _access_file(path): +def _access_file(path: str) -> File: with File(path) as fd: if not fd.exists(): print(f"{path} does not exist") @@ -243,7 +247,7 @@ def _access_file(path): return fd -def parse_config(opt): +def parse_config(opt: Values) -> Values: config = ConfigParser() config.read(opt.config) diff --git a/lib/core/scanner.py b/lib/core/scanner.py index ba25c81da..730a215af 100755 --- a/lib/core/scanner.py +++ b/lib/core/scanner.py @@ -16,10 +16,12 @@ # # Author: Mauro Soria +from __future__ import annotations + import asyncio import re import time -from typing import Optional +from typing import Any from urllib.parse import unquote from lib.connection.requester import AsyncRequester, BaseRequester, Requester @@ -41,7 +43,7 @@ def __init__( self, requester: BaseRequester, path: str = "", - tested: dict = {}, + tested: dict[str, Any] = {}, context: str = "all cases", ) -> None: self.path = path @@ -84,7 +86,7 @@ def check(self, path: str, response: BaseResponse) -> bool: return True - def get_duplicate(self, response: BaseResponse) -> Optional["BaseScanner"]: + def get_duplicate(self, response: BaseResponse) -> BaseScanner | None: for category in self.tested: for tester in self.tested[category].values(): if response == tester.response: @@ -92,7 +94,7 @@ def get_duplicate(self, response: BaseResponse) -> Optional["BaseScanner"]: return None - def is_wildcard(self, response): + def is_wildcard(self, response: BaseResponse) -> bool: """Check if response is similar to wildcard response""" # Compare 2 binary responses (Response.content is empty if the body is binary) @@ -102,7 +104,7 @@ def is_wildcard(self, response): return self.content_parser.compare_to(response.content) @staticmethod - def generate_redirect_regex(first_loc, first_path, second_loc, second_path): + def generate_redirect_regex(first_loc: str, first_path: str, second_loc: str, second_path: str) -> str: """ From 2 redirects of wildcard responses, generate a regexp that matches every wildcard redirect. @@ -128,14 +130,15 @@ class Scanner(BaseScanner): def __init__( self, requester: Requester, + *, path: str = "", - tested: dict = {}, + tested: dict[str, dict[str, Scanner]] = {}, context: str = "all cases", ) -> None: super().__init__(requester, path, tested, context) self.setup() - def setup(self): + def setup(self) -> None: """ Generate wildcard response information containers, this will be used to compare with other path responses @@ -149,9 +152,8 @@ def setup(self): self.response = first_response time.sleep(options["delay"]) - duplicate = self.get_duplicate(first_response) # Another test was performed before and has the same response as this - if duplicate: + if duplicate := self.get_duplicate(first_response): self.content_parser = duplicate.content_parser self.wildcard_redirect_regex = duplicate.wildcard_redirect_regex logger.debug(f'Skipped the second test for "{self.context}"') @@ -184,8 +186,9 @@ class AsyncScanner(BaseScanner): def __init__( self, requester: AsyncRequester, + *, path: str = "", - tested: dict = {}, + tested: dict[str, dict[str, AsyncScanner]] = {}, context: str = "all cases", ) -> None: super().__init__(requester, path, tested, context) @@ -196,9 +199,9 @@ async def create( requester: AsyncRequester, *, path: str = "", - tested: dict = {}, + tested: dict[str, dict[str, AsyncScanner]] = {}, context: str = "all cases", - ) -> "Scanner": + ) -> AsyncScanner: self = cls(requester, path=path, tested=tested, context=context) await self.setup() return self diff --git a/lib/core/structures.py b/lib/core/structures.py index afb61d0e9..df2130f2e 100755 --- a/lib/core/structures.py +++ b/lib/core/structures.py @@ -16,63 +16,68 @@ # # Author: Mauro Soria +from __future__ import annotations + +from typing import Any, Iterator + + class CaseInsensitiveDict(dict): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._convert_keys() - def __setitem__(self, key, value): + def __setitem__(self, key: Any, value: Any) -> None: if isinstance(key, str): key = key.lower() super().__setitem__(key.lower(), value) - def __getitem__(self, key): + def __getitem__(self, key: Any) -> Any: if isinstance(key, str): key = key.lower() return super().__getitem__(key.lower()) - def _convert_keys(self): + def _convert_keys(self) -> None: for key in list(self.keys()): value = super().pop(key) self.__setitem__(key, value) class OrderedSet: - def __init__(self, items=[]): - self._data = dict() + def __init__(self, items: list[Any] = []) -> None: + self._data: dict[Any, Any] = dict() for item in items: self._data[item] = None - def __contains__(self, item): + def __contains__(self, item: Any) -> bool: return item in self._data - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self._data.keys() == other._data.keys() - def __iter__(self): + def __iter__(self) -> Iterator[Any]: return iter(list(self._data)) - def __len__(self): + def __len__(self) -> int: return len(self._data) - def add(self, item): + def add(self, item: Any) -> None: self._data[item] = None - def clear(self): + def clear(self) -> None: self._data.clear() - def discard(self, item): + def discard(self, item: Any) -> None: self._data.pop(item, None) - def pop(self): + def pop(self) -> None: self._data.popitem() - def remove(self, item): + def remove(self, item: Any) -> None: del self._data[item] - def update(self, items): + def update(self, items: list[Any]) -> None: for item in items: self.add(item) diff --git a/lib/parse/cmdline.py b/lib/parse/cmdline.py index fb322307b..6cd19403e 100755 --- a/lib/parse/cmdline.py +++ b/lib/parse/cmdline.py @@ -16,7 +16,7 @@ # # Author: Mauro Soria -from optparse import OptionParser, OptionGroup +from optparse import OptionParser, OptionGroup, Values from lib.core.settings import ( @@ -27,7 +27,7 @@ from lib.utils.common import get_config_file -def parse_arguments(): +def parse_arguments() -> Values: usage = "Usage: %prog [-u|--url] target [-e|--extensions] extensions [options]" epilog = "See 'config.ini' for the example configuration file" parser = OptionParser(usage=usage, epilog=epilog, version=f"dirsearch v{VERSION}") diff --git a/lib/parse/config.py b/lib/parse/config.py index c03781354..c16f962a4 100755 --- a/lib/parse/config.py +++ b/lib/parse/config.py @@ -16,12 +16,20 @@ # # Author: Mauro Soria +from __future__ import annotations + import configparser import json class ConfigParser(configparser.ConfigParser): - def safe_get(self, section, option, default=None, allowed=None): + def safe_get( + self, + section: str, + option: str, + default: str | None = None, + allowed: tuple[str, ...] | None = None, + ) -> str | None: try: value = super().get(section, option) @@ -32,7 +40,13 @@ def safe_get(self, section, option, default=None, allowed=None): except (configparser.NoSectionError, configparser.NoOptionError): return default - def safe_getfloat(self, section, option, default=0, allowed=None): + def safe_getfloat( + self, + section: str, + option: str, + default: float = 0.0, + allowed: tuple[float, ...] | None = None, + ) -> float: try: value = super().getfloat(section, option) @@ -43,7 +57,13 @@ def safe_getfloat(self, section, option, default=0, allowed=None): except (configparser.NoSectionError, configparser.NoOptionError): return default - def safe_getboolean(self, section, option, default=False, allowed=None): + def safe_getboolean( + self, + section: str, + option: str, + default: bool = False, + allowed: tuple[bool, ...] | None = None, + ) -> bool: try: value = super().getboolean(section, option) @@ -54,7 +74,13 @@ def safe_getboolean(self, section, option, default=False, allowed=None): except (configparser.NoSectionError, configparser.NoOptionError): return default - def safe_getint(self, section, option, default=0, allowed=None): + def safe_getint( + self, + section: str, + option: str, + default: int = 0, + allowed: tuple[int, ...] | None = None, + ) -> int: try: value = super().getint(section, option) @@ -65,7 +91,13 @@ def safe_getint(self, section, option, default=0, allowed=None): except (configparser.NoSectionError, configparser.NoOptionError): return default - def safe_getlist(self, section, option, default=[], allowed=None): + def safe_getlist( + self, + section: str, + option: str, + default: list[str] = [], + allowed: tuple[str, ...] | None = None, + ) -> list[str]: try: try: value = json.loads(super().get(section, option)) diff --git a/lib/parse/headers.py b/lib/parse/headers.py index bb61c0e42..7a6bdc32c 100755 --- a/lib/parse/headers.py +++ b/lib/parse/headers.py @@ -16,6 +16,8 @@ # # Author: Mauro Soria +from __future__ import annotations + from email.parser import BytesParser from lib.core.settings import NEW_LINE @@ -23,7 +25,7 @@ class HeadersParser: - def __init__(self, headers): + def __init__(self, headers: str | dict[str, str]) -> None: self.str = self.dict = headers if isinstance(headers, str): @@ -34,18 +36,18 @@ def __init__(self, headers): self.headers = CaseInsensitiveDict(self.dict) - def get(self, key): + def get(self, key: str) -> str: return self.headers[key] @staticmethod - def str_to_dict(headers): + def str_to_dict(headers: str) -> dict[str, str]: if not headers: return {} return dict(BytesParser().parsebytes(headers.encode())) @staticmethod - def dict_to_str(headers): + def dict_to_str(headers: dict[str, str]) -> str: if not headers: return @@ -54,5 +56,5 @@ def dict_to_str(headers): def __iter__(self): return iter(self.headers.items()) - def __str__(self): + def __str__(self) -> str: return self.str diff --git a/lib/parse/nmap.py b/lib/parse/nmap.py index b1223fc27..0cc1e9542 100644 --- a/lib/parse/nmap.py +++ b/lib/parse/nmap.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import xml.etree.ElementTree as ET -def parse_nmap(file): +def parse_nmap(file: str) -> list[str]: root = ET.parse(file).getroot() targets = [] for host in root.iter("host"): diff --git a/lib/parse/rawrequest.py b/lib/parse/rawrequest.py index 901dcf1bc..a8c01f2de 100755 --- a/lib/parse/rawrequest.py +++ b/lib/parse/rawrequest.py @@ -16,13 +16,15 @@ # # Author: Mauro Soria +from __future__ import annotations + from lib.core.exceptions import InvalidRawRequest from lib.core.logger import logger from lib.parse.headers import HeadersParser from lib.utils.file import File -def parse_raw(raw_file): +def parse_raw(raw_file: str) -> tuple[list[str], str, dict[str, str], str | None]: with File(raw_file) as fd: raw_content = fd.read() diff --git a/lib/parse/url.py b/lib/parse/url.py index 23689ef14..037857afd 100755 --- a/lib/parse/url.py +++ b/lib/parse/url.py @@ -19,7 +19,7 @@ from lib.utils.common import lstrip_once -def clean_path(path, keep_queries=False, keep_fragment=False): +def clean_path(path: str, keep_queries: bool = False, keep_fragment: bool = False) -> str: if not keep_fragment: path = path.split("#")[0] if not keep_queries: @@ -28,7 +28,7 @@ def clean_path(path, keep_queries=False, keep_fragment=False): return path -def parse_path(value): +def parse_path(value: str) -> str: try: scheme, url = value.split("//", 1) if ( diff --git a/lib/utils/common.py b/lib/utils/common.py index 3648569e2..e57badf02 100644 --- a/lib/utils/common.py +++ b/lib/utils/common.py @@ -39,7 +39,7 @@ def get_config_file(): return os.environ.get("DIRSEARCH_CONFIG") or FileUtils.build_path(SCRIPT_PATH, "config.ini") -def safequote(string_): +def safequote(string_: str) -> str: return quote(string_, safe=URL_SAFE_CHARS) @@ -88,7 +88,7 @@ def human_size(num): return f"{num}TB" -def is_binary(bytes): +def is_binary(bytes) -> bool: return bool(bytes.translate(None, TEXT_CHARS)) diff --git a/lib/utils/diff.py b/lib/utils/diff.py index 4fcb57243..f78f6d7be 100755 --- a/lib/utils/diff.py +++ b/lib/utils/diff.py @@ -62,7 +62,7 @@ def get_static_patterns(patterns): return [pattern for pattern in patterns if pattern.startswith(" ")] -def generate_matching_regex(string1, string2): +def generate_matching_regex(string1: str, string2: str) -> str: start = "^" end = "$" diff --git a/lib/utils/file.py b/lib/utils/file.py index f4eec90ce..63575b294 100755 --- a/lib/utils/file.py +++ b/lib/utils/file.py @@ -16,6 +16,8 @@ # # Author: Mauro Soria +from __future__ import annotations + import os import os.path @@ -59,7 +61,7 @@ def __exit__(self, type, value, tb): class FileUtils: @staticmethod - def build_path(*path_components): + def build_path(*path_components: str) -> str: if path_components: path = os.path.join(*path_components) else: @@ -110,7 +112,7 @@ def get_files(cls, directory): return files @staticmethod - def get_lines(file_name): + def get_lines(file_name: str) -> list[str]: with open(file_name, "r", errors="replace") as fd: return fd.read().splitlines() diff --git a/lib/utils/mimetype.py b/lib/utils/mimetype.py index 6a4c31e17..c9c052c17 100755 --- a/lib/utils/mimetype.py +++ b/lib/utils/mimetype.py @@ -18,6 +18,7 @@ import re import json +from typing_extensions import LiteralString from defusedxml import ElementTree @@ -51,7 +52,7 @@ def is_query_string(content): return False -def guess_mimetype(content): +def guess_mimetype(content) -> LiteralString: if MimeTypeUtils.is_json(content): return "application/json" elif MimeTypeUtils.is_xml(content):