Merge pull request #131 from scrapy-plugins/scrapy-typing
Fix typing issues for typed Scrapy.
wRAR authored Nov 19, 2024
2 parents 64ad254 + 380fefd commit 5fbf0bc
Showing 13 changed files with 129 additions and 49 deletions.
24 changes: 12 additions & 12 deletions scrapy_zyte_api/_params.py
@@ -2,7 +2,7 @@
from copy import copy
from logging import getLogger
from os import environ
from typing import Any, Dict, List, Mapping, Optional, Set, Union
from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union
from warnings import warn

from scrapy import Request
@@ -290,7 +290,7 @@ def _iter_headers(
api_params: Dict[str, Any],
request: Request,
header_parameter: str,
):
) -> Iterable[Tuple[bytes, bytes, bytes]]:
headers = api_params.get(header_parameter)
if headers not in (None, True):
logger.warning(
@@ -306,8 +306,8 @@ def _iter_headers(
continue
decoded_k = k.decode()
lowercase_k = k.strip().lower()
v = b",".join(v)
decoded_v = v.decode()
joined_v = b",".join(v)
decoded_v = joined_v.decode()

if lowercase_k.startswith(b"x-crawlera-"):
for spm_header_suffix, zapi_request_param in (
@@ -435,7 +435,7 @@ def _iter_headers(
)
continue

yield k, lowercase_k, v
yield k, lowercase_k, joined_v


def _map_custom_http_request_headers(
@@ -461,7 +461,7 @@ def _map_request_headers(
*,
api_params: Dict[str, Any],
request: Request,
browser_headers: Dict[str, str],
browser_headers: Dict[bytes, str],
browser_ignore_headers: SKIP_HEADER_T,
):
request_headers = {}
@@ -477,7 +477,7 @@ def _map_request_headers(
lowercase_k
] not in (ANY_VALUE, v):
logger.warning(
f"Request {request} defines header {k}, which "
f"Request {request} defines header {k.decode()}, which "
f"cannot be mapped into the Zyte API requestHeaders "
f"parameter. See the ZYTE_API_BROWSER_HEADERS setting."
)
@@ -500,7 +500,7 @@ def _set_request_headers_from_request(
api_params: Dict[str, Any],
request: Request,
skip_headers: SKIP_HEADER_T,
browser_headers: Dict[str, str],
browser_headers: Dict[bytes, str],
browser_ignore_headers: SKIP_HEADER_T,
):
"""Updates *api_params*, in place, based on *request*."""
@@ -727,7 +727,7 @@ def _update_api_params_from_request(
default_params: Dict[str, Any],
meta_params: Dict[str, Any],
skip_headers: SKIP_HEADER_T,
browser_headers: Dict[str, str],
browser_headers: Dict[bytes, str],
browser_ignore_headers: SKIP_HEADER_T,
cookies_enabled: bool,
cookie_jars: Optional[Dict[Any, CookieJar]],
@@ -859,7 +859,7 @@ def _get_automap_params(
default_enabled: bool,
default_params: Dict[str, Any],
skip_headers: SKIP_HEADER_T,
browser_headers: Dict[str, str],
browser_headers: Dict[bytes, str],
browser_ignore_headers: SKIP_HEADER_T,
cookies_enabled: bool,
cookie_jars: Optional[Dict[Any, CookieJar]],
@@ -906,7 +906,7 @@ def _get_api_params(
transparent_mode: bool,
automap_params: Dict[str, Any],
skip_headers: SKIP_HEADER_T,
browser_headers: Dict[str, str],
browser_headers: Dict[bytes, str],
browser_ignore_headers: SKIP_HEADER_T,
job_id: Optional[str],
cookies_enabled: bool,
@@ -1003,7 +1003,7 @@ def _load_mw_skip_headers(crawler):
return mw_skip_headers


def _load_browser_headers(settings):
def _load_browser_headers(settings) -> Dict[bytes, str]:
browser_headers = settings.getdict(
"ZYTE_API_BROWSER_HEADERS",
{"Referer": "referer"},
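The _params.py changes show two recurring fixes for typed Scrapy: the header-iterating helper gains an explicit return annotation, and its loop variable is no longer re-bound to a value of a different type (v = b",".join(v) turns a list of bytes into bytes, which mypy reports as an incompatible assignment). A minimal standalone sketch of both patterns, not the plugin's actual helper:

    from typing import Dict, Iterable, List, Tuple

    def iter_joined_headers(
        headers: Dict[bytes, List[bytes]],
    ) -> Iterable[Tuple[bytes, bytes]]:
        """Yield (lowercased name, comma-joined value) pairs."""
        for k, v in headers.items():
            # Re-using "v" for the joined value would change its type from
            # List[bytes] to bytes; a new name keeps both types stable.
            joined_v = b",".join(v)
            yield k.strip().lower(), joined_v

    # e.g. list(iter_joined_headers({b"Accept": [b"text/html", b"*/*"]}))

The browser_headers annotations change from Dict[str, str] to Dict[bytes, str] for the same kind of reason: the keys are looked up with the (bytes) header names of the request.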
4 changes: 2 additions & 2 deletions scrapy_zyte_api/_request_fingerprinter.py
@@ -1,5 +1,5 @@
from logging import getLogger
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

logger = getLogger(__name__)

@@ -49,7 +49,7 @@ def __init__(self, crawler):
crawler=crawler,
)
if self._has_poet and not isinstance(
self._fallback_request_fingerprinter, RequestFingerprinter
self._fallback_request_fingerprinter, cast(type, RequestFingerprinter)
):
logger.warning(
f"You have scrapy-poet installed, but your custom value "
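In _request_fingerprinter.py, the fallback fingerprinter is checked with isinstance() against a name that only exists when scrapy-poet is installed, so in some code paths mypy does not see it as a plain class; cast(type, ...) satisfies the checker without changing runtime behaviour. A hedged sketch of the pattern, using a made-up optional dependency:

    from typing import cast

    try:
        from some_optional_plugin import PluginFingerprinter  # hypothetical import
    except ImportError:
        PluginFingerprinter = None  # type: ignore[assignment]

    def uses_plugin_fingerprinter(obj: object) -> bool:
        # When the import failed, PluginFingerprinter is None, which mypy
        # rejects as an isinstance() argument; the cast keeps the checker
        # happy and is a no-op at runtime on the import-success path.
        return PluginFingerprinter is not None and isinstance(
            obj, cast(type, PluginFingerprinter)
        )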
18 changes: 12 additions & 6 deletions scrapy_zyte_api/_session.py
@@ -73,7 +73,7 @@ class DummyResponse: # type: ignore[no-redef]
from scrapy.downloadermiddlewares.retry import get_retry_request
except ImportError: # pragma: no cover
# https://github.com/scrapy/scrapy/blob/b1fe97dc6c8509d58b29c61cf7801eeee1b409a9/scrapy/downloadermiddlewares/retry.py#L57-L142
def get_retry_request(
def get_retry_request( # type: ignore[misc]
request,
*,
spider,
@@ -125,7 +125,7 @@ def get_retry_request(
from scrapy.http.request import NO_CALLBACK
except ImportError:

def NO_CALLBACK(response):
def NO_CALLBACK(response): # type: ignore[misc]
pass # pragma: no cover


Expand Down Expand Up @@ -165,7 +165,7 @@ def _get_asyncio_event_loop():
return set_asyncio_event_loop()

# https://github.com/scrapy/scrapy/blob/b1fe97dc6c8509d58b29c61cf7801eeee1b409a9/scrapy/utils/defer.py#L360-L379
def deferred_to_future(d):
def deferred_to_future(d): # type: ignore[misc]
return d.asFuture(_get_asyncio_event_loop())


@@ -177,7 +177,7 @@ def deferred_to_future(d):
def build_from_crawler(
objcls: Type[T], crawler: Crawler, /, *args: Any, **kwargs: Any
) -> T:
return create_instance(objcls, settings=None, crawler=crawler, *args, **kwargs)
return create_instance(objcls, settings=None, crawler=crawler, *args, **kwargs) # type: ignore[misc]


class PoolError(ValueError):
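The # type: ignore[misc] markers above all follow the same compatibility-shim shape: a helper missing from older Scrapy is redefined locally under except ImportError, and mypy flags the redefinition of the imported name. The NO_CALLBACK shim from this very file is the shortest instance of the pattern:

    try:
        from scrapy.http.request import NO_CALLBACK  # Scrapy >= 2.8
    except ImportError:

        def NO_CALLBACK(response):  # type: ignore[misc]
            # Redefining the imported constant as a function is what triggers
            # the [misc] error that the ignore comment silences.
            pass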
@@ -382,7 +382,7 @@ def check(self, response: Response, request: Request) -> bool:
location = self.location(request)
if not location:
return True
for action in response.raw_api_response.get("actions", []):
for action in response.raw_api_response.get("actions", []): # type: ignore[attr-defined]
if action.get("action", None) != "setLocation":
continue
if action.get("error", "").startswith("Action setLocation not supported "):
@@ -647,6 +647,8 @@ def _get_pool(self, request):
return pool

async def _init_session(self, session_id: str, request: Request, pool: str) -> bool:
assert self._crawler.engine
assert self._crawler.stats
session_config = self._get_session_config(request)
if meta_params := request.meta.get("zyte_api_session_params", None):
session_params = meta_params
@@ -685,7 +687,7 @@ async def _init_session(self, session_id: str, request: Request, pool: str) -> bool:
callback=NO_CALLBACK,
)
if _DOWNLOAD_NEEDS_SPIDER:
deferred = self._crawler.engine.download(
deferred = self._crawler.engine.download( # type: ignore[call-arg]
session_init_request, spider=spider
)
else:
@@ -829,6 +831,7 @@ async def check(self, response: Response, request: Request) -> bool:
"""Check the response for signs of session expiration, update the
internal session pool accordingly, and return ``False`` if the session
has expired or ``True`` if the session passed validation."""
assert self._crawler.stats
with self._fatal_error_handler:
if self.is_init_request(request):
return True
@@ -860,6 +863,7 @@ async def check(self, response: Response, request: Request) -> bool:

async def assign(self, request: Request):
"""Assign a working session to *request*."""
assert self._crawler.stats
with self._fatal_error_handler:
if self.is_init_request(request):
return
@@ -895,6 +899,7 @@ def is_enabled(self, request: Request) -> bool:
return session_config.enabled(request)

def handle_error(self, request: Request):
assert self._crawler.stats
with self._fatal_error_handler:
pool = self._get_pool(request)
self._crawler.stats.inc_value(
@@ -908,6 +913,7 @@ def handle_error(self, request: Request):
self._start_request_session_refresh(request, pool)

def handle_expiration(self, request: Request):
assert self._crawler.stats
with self._fatal_error_handler:
pool = self._get_pool(request)
self._crawler.stats.inc_value(
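Most of the remaining _session.py additions are assert self._crawler.stats / assert self._crawler.engine lines. Under typed Scrapy, Crawler.stats and Crawler.engine are Optional (they are None until the crawler is fully set up), so an assert both narrows the type for mypy and fails loudly if a method is somehow called too early. A minimal sketch, with an illustrative stat key:

    from scrapy.crawler import Crawler

    def count_discarded_session(crawler: Crawler, pool: str) -> None:
        # Crawler.stats is Optional[StatsCollector] in typed Scrapy; the assert
        # narrows it so that the inc_value() call below type-checks.
        assert crawler.stats
        crawler.stats.inc_value(f"sessions/pools/{pool}/discarded")  # made-up key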
6 changes: 4 additions & 2 deletions scrapy_zyte_api/addon.py
@@ -1,3 +1,5 @@
from typing import cast

from scrapy.settings import BaseSettings
from scrapy.utils.misc import load_object
from zyte_api import zyte_api_retrying
@@ -74,7 +76,7 @@ def update_settings(self, settings: BaseSettings) -> None:
settings.set(
"REQUEST_FINGERPRINTER_CLASS",
"scrapy_zyte_api.ScrapyZyteAPIRequestFingerprinter",
settings.getpriority("REQUEST_FINGERPRINTER_CLASS"),
cast(int, settings.getpriority("REQUEST_FINGERPRINTER_CLASS")),
)
else:
settings.set(
@@ -124,5 +126,5 @@ def update_settings(self, settings: BaseSettings) -> None:
settings.set(
"ZYTE_API_RETRY_POLICY",
_SESSION_RETRY_POLICIES.get(loaded_retry_policy, retry_policy),
settings.getpriority("ZYTE_API_RETRY_POLICY"),
cast(int, settings.getpriority("ZYTE_API_RETRY_POLICY")),
)
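The cast(int, settings.getpriority(...)) calls deal with another typed-Scrapy detail: BaseSettings.getpriority() returns Optional[int] (None when the setting has never been set), while the priority argument of BaseSettings.set() is not Optional. A sketch of the same move; the cast encodes the assumption that the setting already exists:

    from typing import Any, cast

    from scrapy.settings import BaseSettings

    def override_keeping_priority(settings: BaseSettings, name: str, value: Any) -> None:
        # getpriority() is Optional[int]; the cast mirrors the addon code and
        # assumes the setting is already present, so its priority is reused.
        # "settings.getpriority(name) or 0" would be the defensive variant.
        settings.set(name, value, cast(int, settings.getpriority(name)))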
9 changes: 7 additions & 2 deletions scrapy_zyte_api/handler.py
@@ -88,8 +88,8 @@ def __init__(
# We keep the client in the crawler object to prevent multiple,
# duplicate clients with the same settings to be used.
# https://github.com/scrapy-plugins/scrapy-zyte-api/issues/58
crawler.zyte_api_client = client
self._client: AsyncZyteAPI = crawler.zyte_api_client
crawler.zyte_api_client = client # type: ignore[attr-defined]
self._client: AsyncZyteAPI = crawler.zyte_api_client # type: ignore[attr-defined]
logger.info("Using a Zyte API key starting with %r", self._client.api_key[:7])
verify_installed_reactor(
"twisted.internet.asyncioreactor.AsyncioSelectorReactor"
@@ -104,6 +104,7 @@
)
self._param_parser = _ParamParser(crawler)
self._retry_policy = _load_retry_policy(settings)
assert crawler.stats
self._stats = crawler.stats
self._must_log_request = settings.getbool("ZYTE_API_LOG_REQUESTS", False)
self._truncate_limit = settings.getint("ZYTE_API_LOG_REQUESTS_TRUNCATE", 64)
@@ -129,6 +130,7 @@ async def engine_started(self):
self._session = self._client.session(trust_env=self._trust_env)
if not self._cookies_enabled:
return
assert self._crawler.engine
for middleware in self._crawler.engine.downloader.middleware.middlewares:
if isinstance(middleware, self._cookie_mw_cls):
self._cookie_jars = middleware.jars
@@ -275,6 +277,9 @@ def _process_request_error(self, request, error):
f"type={error.parsed.type!r}, request_id={error.request_id!r}) "
f"while processing URL ({request.url}): {detail}"
)
assert self._crawler
assert self._crawler.engine
assert self._crawler.spider
for status, error_type, close_reason in (
(401, "/auth/key-not-found", "zyte_api_bad_key"),
(403, "/auth/account-suspended", "zyte_api_suspended_account"),
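handler.py mostly adds the same assert-narrowing, plus # type: ignore[attr-defined] on the zyte_api_client attribute: the client is cached on the crawler so that several handlers share one instance (issue #58), but Crawler declares no such attribute, so both the write and the read need an ignore. Roughly, as a sketch:

    from scrapy.crawler import Crawler
    from zyte_api import AsyncZyteAPI

    def get_shared_client(crawler: Crawler, api_key: str) -> AsyncZyteAPI:
        # Crawler has no declared zyte_api_client attribute, so setting and
        # reading it is reported as [attr-defined]; the ignores keep the
        # one-client-per-crawler caching without changing behaviour.
        if not hasattr(crawler, "zyte_api_client"):
            crawler.zyte_api_client = AsyncZyteAPI(api_key=api_key)  # type: ignore[attr-defined]
        return crawler.zyte_api_client  # type: ignore[attr-defined]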
7 changes: 5 additions & 2 deletions scrapy_zyte_api/providers.py
@@ -5,6 +5,7 @@
from scrapy.crawler import Crawler
from scrapy.utils.defer import maybe_deferred_to_future
from scrapy_poet import PageObjectInputProvider
from twisted.internet.defer import Deferred
from web_poet import (
AnyResponse,
BrowserHtml,
@@ -47,7 +48,7 @@
# requires Scrapy >= 2.8
from scrapy.http.request import NO_CALLBACK
except ImportError:
NO_CALLBACK = None
NO_CALLBACK = None # type: ignore[assignment]


_ITEM_KEYWORDS: Dict[type, str] = {
@@ -103,6 +104,7 @@ def is_provided(self, type_: Callable) -> bool:
return super().is_provided(strip_annotated(type_))

def _track_auto_fields(self, crawler: Crawler, request: Request, cls: Type):
assert crawler.stats
if cls not in _ITEM_KEYWORDS:
return
if self._should_track_auto_fields is None:
@@ -256,8 +258,9 @@ async def __call__( # noqa: C901
},
callback=NO_CALLBACK,
)
assert crawler.engine
api_response: ZyteAPITextResponse = await maybe_deferred_to_future(
crawler.engine.download(api_request)
cast("Deferred[ZyteAPITextResponse]", crawler.engine.download(api_request))
)

assert api_response.raw_api_response
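In providers.py the downloaded response is cast to Deferred[ZyteAPITextResponse]: engine.download() is annotated with the generic Response type, while the provider needs the Zyte-API-specific response downstream, and crawler.engine itself is Optional and has to be asserted first. A sketch of the combination (assuming Scrapy's single-argument engine.download()):

    from typing import cast

    from scrapy import Request
    from scrapy.crawler import Crawler
    from scrapy.utils.defer import maybe_deferred_to_future
    from twisted.internet.defer import Deferred

    from scrapy_zyte_api.responses import ZyteAPITextResponse

    async def download_zyte_api_response(
        crawler: Crawler, request: Request
    ) -> ZyteAPITextResponse:
        assert crawler.engine  # Crawler.engine is Optional in typed Scrapy
        # The cast narrows the deferred's result type from Response to the
        # response class this code actually relies on.
        deferred = cast(
            "Deferred[ZyteAPITextResponse]", crawler.engine.download(request)
        )
        return await maybe_deferred_to_future(deferred)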
20 changes: 14 additions & 6 deletions scrapy_zyte_api/responses.py
@@ -1,10 +1,10 @@
from base64 import b64decode
from copy import copy
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from scrapy import Request
from scrapy.http import HtmlResponse, Response, TextResponse
from scrapy.http import Headers, HtmlResponse, Response, TextResponse
from scrapy.http.cookies import CookieJar
from scrapy.responsetypes import responsetypes

@@ -113,7 +113,9 @@ def _prepare_headers(cls, api_response: Dict[str, Any]):

class ZyteAPITextResponse(ZyteAPIMixin, HtmlResponse):
@classmethod
def from_api_response(cls, api_response: Dict, *, request: Request = None):
def from_api_response(
cls, api_response: Dict, *, request: Optional[Request] = None
):
"""Alternative constructor to instantiate the response from the raw
Zyte API response.
"""
@@ -144,7 +146,9 @@ def replace(self, *args, **kwargs):

class ZyteAPIResponse(ZyteAPIMixin, Response):
@classmethod
def from_api_response(cls, api_response: Dict, *, request: Request = None):
def from_api_response(
cls, api_response: Dict, *, request: Optional[Request] = None
):
"""Alternative constructor to instantiate the response from the raw
Zyte API response.
"""
@@ -190,9 +194,13 @@ def _process_response(
return ZyteAPITextResponse.from_api_response(api_response, request=request)

if api_response.get("httpResponseHeaders") and api_response.get("httpResponseBody"):
# a plain dict here doesn't work correctly on Scrapy < 2.1
scrapy_headers = Headers()
for header in cast(List[Dict[str, str]], api_response["httpResponseHeaders"]):
scrapy_headers[header["name"].encode()] = header["value"].encode()
response_cls = responsetypes.from_args(
headers=api_response["httpResponseHeaders"],
url=api_response["url"],
headers=scrapy_headers,
url=cast(str, api_response["url"]),
# FIXME: update this when python-zyte-api supports base64 decoding
body=b64decode(api_response["httpResponseBody"]), # type: ignore
)
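The last responses.py hunk replaces the raw httpResponseHeaders value (a list of {"name": ..., "value": ...} dicts in the Zyte API response) with a real scrapy.http.Headers object before calling responsetypes.from_args(), since a plain dict is not handled correctly on Scrapy < 2.1. The conversion, extracted into a helper for illustration:

    from typing import Any, Dict, List

    from scrapy.http import Headers

    def headers_from_api_response(api_response: Dict[str, Any]) -> Headers:
        # Zyte API reports headers as a list of {"name", "value"} dicts;
        # Scrapy's Headers wants bytes keys and values.
        headers = Headers()
        raw: List[Dict[str, str]] = api_response.get("httpResponseHeaders", [])
        for header in raw:
            headers[header["name"].encode()] = header["value"].encode()
        return headers

    # e.g. headers_from_api_response(
    #     {"httpResponseHeaders": [{"name": "Content-Type", "value": "text/html"}]}
    # )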
11 changes: 8 additions & 3 deletions tests/__init__.py
@@ -1,8 +1,10 @@
from contextlib import asynccontextmanager, contextmanager
from os import environ
from typing import Any, Dict, Optional, Type, Union
from typing import Any, Dict, Optional

from packaging.version import Version
from scrapy import Spider
from scrapy import __version__ as SCRAPY_VERSION
from scrapy.crawler import Crawler
from scrapy.utils.misc import load_object
from scrapy.utils.test import get_crawler as _get_crawler
@@ -14,7 +16,7 @@
_API_KEY = "a"

DEFAULT_CLIENT_CONCURRENCY = AsyncClient(api_key=_API_KEY).n_conn
SETTINGS_T = Dict[Union[Type, str], Any]
SETTINGS_T = Dict[str, Any]
SETTINGS: SETTINGS_T = {
"DOWNLOAD_HANDLERS": {
"http": "scrapy_zyte_api.handler.ScrapyZyteAPIDownloadHandler",
@@ -25,13 +27,16 @@
"scrapy_zyte_api.ScrapyZyteAPISessionDownloaderMiddleware": 667,
},
"REQUEST_FINGERPRINTER_CLASS": "scrapy_zyte_api.ScrapyZyteAPIRequestFingerprinter",
"REQUEST_FINGERPRINTER_IMPLEMENTATION": "2.7", # Silence deprecation warning
"SPIDER_MIDDLEWARES": {
"scrapy_zyte_api.ScrapyZyteAPISpiderMiddleware": 100,
},
"ZYTE_API_KEY": _API_KEY,
"TWISTED_REACTOR": "twisted.internet.asyncioreactor.AsyncioSelectorReactor",
}
if Version(SCRAPY_VERSION) < Version("2.12"):
SETTINGS["REQUEST_FINGERPRINTER_IMPLEMENTATION"] = (
"2.7" # Silence deprecation warning
)
try:
import scrapy_poet # noqa: F401
except ImportError: