Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type annotations #1404

Merged
merged 2 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions lib/connection/dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@
#
# Author: Mauro Soria

from __future__ import annotations

from socket import getaddrinfo
from typing import Any

_dns_cache = {}
_dns_cache: dict[tuple[str, int], list[Any]] = {}


def cache_dns(domain, port, addr):
def cache_dns(domain: str, port: int, addr: str) -> None:
_dns_cache[domain, port] = getaddrinfo(addr, port)


def cached_getaddrinfo(*args, **kwargs):
def cached_getaddrinfo(*args: Any, **kwargs: int) -> list[Any]:
"""
Replacement for socket.getaddrinfo, they are the same but this function
does cache the answer to improve the performance
Expand Down
42 changes: 22 additions & 20 deletions lib/connection/requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#
# Author: Mauro Soria

from __future__ import annotations

import asyncio
import http.client
import random
Expand All @@ -24,7 +26,7 @@
from ssl import SSLError
import threading
import time
from typing import Generator, Optional
from typing import Any, Generator
from urllib.parse import urlparse

import httpx
Expand Down Expand Up @@ -59,12 +61,12 @@


class BaseRequester:
def __init__(self):
self._url = None
self._proxy_cred = None
def __init__(self) -> None:
self._url: str = ""
self._proxy_cred: str = ""
self._rate = 0
self.headers = CaseInsensitiveDict(options["headers"])
self.agents = []
self.agents: list[str] = []
self.session = None

self._cert = None
Expand Down Expand Up @@ -117,27 +119,27 @@ def set_proxy(self, proxy: str) -> None:
def set_proxy_auth(self, credential: str) -> None:
self._proxy_cred = credential

def is_rate_exceeded(self):
def is_rate_exceeded(self) -> bool:
return self._rate >= options["max_rate"] > 0

def decrease_rate(self):
def decrease_rate(self) -> None:
self._rate -= 1

def increase_rate(self):
def increase_rate(self) -> None:
self._rate += 1
threading.Timer(1, self.decrease_rate).start()

@property
@cached(RATE_UPDATE_DELAY)
def rate(self):
def rate(self) -> int:
return self._rate


class HTTPBearerAuth(AuthBase):
def __init__(self, token):
def __init__(self, token: str) -> None:
self.token = token

def __call__(self, request):
def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
request.headers["Authorization"] = f"Bearer {self.token}"
return request

Expand All @@ -160,7 +162,7 @@ def __init__(self):
),
)

def set_auth(self, type, credential):
def set_auth(self, type: str, credential: str) -> None:
if type in ("bearer", "jwt"):
self.session.auth = HTTPBearerAuth(credential)
else:
Expand All @@ -178,7 +180,7 @@ def set_auth(self, type, credential):
self.session.auth = HttpNtlmAuth(user, password)

# :path: is expected not to start with "/"
def request(self, path, proxy=None):
def request(self, path: str, proxy: str | None = None) -> Response:
# Pause if the request rate exceeded the maximum
while self.is_rate_exceeded():
time.sleep(0.1)
Expand Down Expand Up @@ -213,13 +215,13 @@ def request(self, path, proxy=None):
prepped = self.session.prepare_request(request)
prepped.url = url

response = self.session.send(
origin_response = self.session.send(
prepped,
allow_redirects=options["follow_redirects"],
timeout=options["timeout"],
stream=True,
)
response = Response(response)
response = Response(origin_response)

log_msg = f'"{options["http_method"]} {response.url}" {response.status} - {response.length}B'

Expand Down Expand Up @@ -270,13 +272,13 @@ class HTTPXBearerAuth(httpx.Auth):
def __init__(self, token: str) -> None:
self.token = token

def auth_flow(self, request: httpx.Request) -> Generator:
def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, None, None]:
request.headers["Authorization"] = f"Bearer {self.token}"
yield request


class ProxyRoatingTransport(httpx.AsyncBaseTransport):
def __init__(self, proxies, **kwargs) -> None:
def __init__(self, proxies: list[str], **kwargs: Any) -> None:
self._transports = [
httpx.AsyncHTTPTransport(proxy=proxy, **kwargs) for proxy in proxies
]
Expand All @@ -287,7 +289,7 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response:


class AsyncRequester(BaseRequester):
def __init__(self):
def __init__(self) -> None:
super().__init__()

tpargs = {
Expand Down Expand Up @@ -340,7 +342,7 @@ def set_auth(self, type: str, credential: str) -> None:
else:
self.session.auth = HttpxNtlmAuth(user, password)

async def replay_request(self, path: str, proxy: str):
async def replay_request(self, path: str, proxy: str) -> AsyncResponse:
if self.replay_session is None:
transport = httpx.AsyncHTTPTransport(
verify=False,
Expand All @@ -357,7 +359,7 @@ async def replay_request(self, path: str, proxy: str):

# :path: is expected not to start with "/"
async def request(
self, path: str, session: Optional[httpx.AsyncClient] = None
self, path: str, session: httpx.AsyncClient | None = None
) -> AsyncResponse:
while self.is_rate_exceeded():
await asyncio.sleep(0.1)
Expand Down
33 changes: 19 additions & 14 deletions lib/connection/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@
#
# Author: Mauro Soria

from __future__ import annotations

from typing import Any

import time
import httpx
import requests

from lib.core.settings import (
DEFAULT_ENCODING,
Expand All @@ -30,7 +35,7 @@


class BaseResponse:
def __init__(self, response):
def __init__(self, response: requests.Response | httpx.Response) -> None:
self.datetime = time.strftime("%Y-%m-%d %H:%M:%S")
self.url = str(response.url)
self.full_path = parse_path(self.url)
Expand All @@ -43,27 +48,27 @@ def __init__(self, response):
self.body = b""

@property
def type(self):
if "content-type" in self.headers:
return self.headers.get("content-type").split(";")[0]
def type(self) -> str:
if ct := self.headers.get("content-type"):
return ct.split(";")[0]

return UNKNOWN

@property
def length(self):
try:
return int(self.headers.get("content-length"))
except TypeError:
return len(self.body)
def length(self) -> int:
if cl := self.headers.get("content-length"):
return int(cl)

return len(self.body)

@property
def size(self):
def size(self) -> str:
return get_readable_size(self.length)

def __hash__(self):
def __hash__(self) -> int:
return hash(self.body)

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
return (self.status, self.body, self.redirect) == (
other.status,
other.body,
Expand All @@ -72,7 +77,7 @@ def __eq__(self, other):


class Response(BaseResponse):
def __init__(self, response):
def __init__(self, response: requests.Response) -> None:
super().__init__(response)

for chunk in response.iter_content(chunk_size=ITER_CHUNK_SIZE):
Expand All @@ -94,7 +99,7 @@ def __init__(self, response):

class AsyncResponse(BaseResponse):
@classmethod
async def create(cls, response: httpx.Response) -> "AsyncResponse":
async def create(cls, response: httpx.Response) -> AsyncResponse:
self = cls(response)
async for chunk in response.aiter_bytes(chunk_size=ITER_CHUNK_SIZE):
self.body += chunk
Expand Down
Loading
Loading