diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1938bb646..34f02b546 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,3 +21,9 @@ repos: args: [ --fix ] # Run the formatter. - id: ruff-format +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.11.0 + hooks: + - id: mypy + args: [--check-untyped-defs] + exclude: 'tests/|examples/|docs/' diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9792f7961..a978fcd64 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,33 @@ +1.1.0 (2024-09-20) +==================== + +**Added** +- Support for Post-Quantum KX Kyber768 (NIST Round 3) with X25519. +- Backport "QUIC Version 2". + "Rework packet encoding to support different protocol versions" https://github.com/aiortc/aioquic/commit/bd3497cce9aa906c47d5b7216752f55beed3d9d3 + "Add encryption for QUIC v2" https://github.com/aiortc/aioquic/commit/abf51897bb67f459921e4c26c8b3ea445aa79832 + "Refactor retry / version negotiation handling" https://github.com/aiortc/aioquic/commit/70dd040893d7d8af5a2a92361c1e844ebf867abb + "Add support for version_information transport parameter" https://github.com/aiortc/aioquic/commit/a59d9ad0b1df423376bf8b30ebb7642861fef54e + "Check Chosen Version matches the version in use by the connection" https://github.com/aiortc/aioquic/commit/a59d9ad0b1df423376bf8b30ebb7642861fef54e + +**Changed** +- Insert GREASE in KX, TLS Version and Ciphers. +- Backport "Only buffer up to 512 KiB of pending CRYPTO frames" https://github.com/aiortc/aioquic/commit/174a2ebbe928686ef9663acc663b3ac06c2d56f2 +- Backport "Improved path challenge handling" https://github.com/aiortc/aioquic/commit/b507364ea51f3e654decd143cc99f7001b5b7923 +- Backport "Limit the number of pending connection IDs marked for retirement." https://github.com/aiortc/aioquic/commit/4f73f18a23c22f48ef43cb3629b0686757f096af +- Backport "During address validation, count the entire received datagram" https://github.com/aiortc/aioquic/commit/afe5525822f71e277e534b08f198ec8724a7ad59 +- Update aws-lc-rs v1.8.1 to v1.9.0 +- Default supported signature algorithms to: ``ECDSA_SECP256R1_SHA256, RSA_PSS_RSAE_SHA256, RSA_PKCS1_SHA256, ECDSA_SECP384R1_SHA384, RSA_PSS_RSAE_SHA384, RSA_PKCS1_SHA384, RSA_PSS_RSAE_SHA512, RSA_PKCS1_SHA512, ED25519``. + +**Fixed** +- Certificate fingerprint matching. +- Backport upstream urllib3/urllib3#3434: util/ssl: make code (certificate fingerprint matching) resilient to missing hash functions. + In certain environments such as in a FIPS enabled system, certain algorithms such as md5 may be unavailable. + +**Misc** +- Backport "Use is for type comparisons" https://github.com/aiortc/aioquic/commit/5c55e0c75d414ab171a09a732c2d8aaf6f178c05 +- Postpone annotations parsing with ``from __future__ import annotations`` everywhere in order to simplify type annotations. + 1.0.9 (2024-08-17) ==================== diff --git a/Cargo.lock b/Cargo.lock index 930cca3ab..32755b7e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -93,9 +93,9 @@ dependencies = [ [[package]] name = "aws-lc-rs" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ae74d9bd0a7530e8afd1770739ad34b36838829d6ad61818f9230f683f5ad77" +checksum = "2f95446d919226d587817a7d21379e6eb099b97b45110a7f272a444ca5c54070" dependencies = [ "aws-lc-fips-sys", "aws-lc-sys", @@ -106,9 +106,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.20.1" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f0e249228c6ad2d240c2dc94b714d711629d52bad946075d8e9b2f5391f0703" +checksum = "b3ddc4a5b231dd6958b140ff3151b6412b3f4321fab354f399eec8f14b06df62" dependencies = [ "bindgen 0.69.4", "cc", @@ -1013,7 +1013,7 @@ dependencies = [ [[package]] name = "qh3" -version = "1.0.9" +version = "1.1.0" dependencies = [ "aws-lc-rs", "chacha20poly1305", diff --git a/Cargo.toml b/Cargo.toml index 4a5fa683d..4cc3b476b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "qh3" -version = "1.0.9" +version = "1.1.0" edition = "2021" rust-version = "1.75" license = "BSD-3" @@ -25,7 +25,7 @@ chacha20poly1305 = "0.10.1" pkcs8 = { version = "0.10.2", features = ["encryption", "pem"] } pkcs1 = { version = "0.7.5", features = ["pem"] } rustls-pemfile = "2.1.2" -aws-lc-rs = { version = "1.8.1", features=["bindgen"], default-features = false } +aws-lc-rs = { version = "1.9.0", features=["bindgen", "unstable"], default-features = false } x509-ocsp = { version = "0.2.1", features = ["builder"] } x509-cert = "0.2.5" der = "0.7.9" diff --git a/README.rst b/README.rst index 4cb7bcbe1..f5a3cc273 100644 --- a/README.rst +++ b/README.rst @@ -1,33 +1,34 @@ qh3 === -|pypi-pyversions| - -.. |pypi-v| image:: https://img.shields.io/pypi/v/qh3.svg - :target: https://pypi.python.org/pypi/qh3 +|pypi-pyversions| |pypi-stats| .. |pypi-pyversions| image:: https://img.shields.io/pypi/pyversions/qh3.svg :target: https://pypi.python.org/pypi/qh3 + :alt: Supported Interpreters +.. |pypi-stats| image:: https://img.shields.io/pypi/dm/qh3 + :target: https://pypistats.org/packages/qh3 + :alt: PyPI - Downloads What is ``qh3``? ---------------- ``qh3`` is a maintained fork of the ``aioquic`` library. -It is lighter, and a bit faster, and more adapted to a broader audience as this package has no external dependency +It is lighter, faster, and more adapted to a broader audience as this package has no external dependency and does not rely on mainstream OpenSSL. While it is a compatible fork, it is not a drop-in replacement since the first major. See the CHANGELOG for details. -Regularly improved and expect a better time to initial response in issues and PRs. - ``qh3`` is a library for the QUIC network protocol in Python. It features a minimal TLS 1.3 implementation, a QUIC stack, and an HTTP/3 stack. QUIC was standardized in `RFC 9000`_ and HTTP/3 in `RFC 9114`_. ``qh3`` follow the standardized version of QUIC and HTTP/3. +QUIC stack conforming with `RFC 9000`_ (QUIC v1) and `RFC 9369`_ (QUIC v2) + To learn more about ``qh3`` please `read the documentation`_. ``qh3`` stands for **Q** UIC . **H** TTP/ **3**. @@ -65,6 +66,7 @@ Features - logging TLS traffic secrets - logging QUIC events in QLOG format - HTTP/3 server push support +- Post-Quantum (KEM) Key-Exchange (Kyber R3 NIST) Requirements ------------ @@ -84,11 +86,10 @@ License ``qh3`` is released under the `BSD license`_. .. _read the documentation: https://qh3.readthedocs.io/en/latest/ -.. _QUIC implementations: https://github.com/quicwg/base-drafts/wiki/Implementations -.. _cryptography: https://cryptography.io/ .. _BSD license: https://qh3.readthedocs.io/en/latest/license.html .. _RFC 8446: https://datatracker.ietf.org/doc/html/rfc8446 .. _RFC 9000: https://datatracker.ietf.org/doc/html/rfc9000 .. _RFC 9114: https://datatracker.ietf.org/doc/html/rfc9114 +.. _RFC 9369: https://datatracker.ietf.org/doc/html/rfc9369 .. _niquests: https://github.com/jawah/niquests .. _urllib3.future: https://github.com/jawah/urllib3.future diff --git a/examples/demo.py b/examples/demo.py index d94867760..e7a35236f 100644 --- a/examples/demo.py +++ b/examples/demo.py @@ -1,6 +1,7 @@ # # demo application for http3_server.py # +from __future__ import annotations import datetime import os diff --git a/examples/doq_client.py b/examples/doq_client.py index 7ebbca777..bf25adb83 100644 --- a/examples/doq_client.py +++ b/examples/doq_client.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import argparse import asyncio import logging import pickle import ssl import struct -from typing import Optional, cast +from typing import cast from dnslib.dns import QTYPE, DNSHeader, DNSQuestion, DNSRecord @@ -20,7 +22,7 @@ class DnsClientProtocol(QuicConnectionProtocol): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._ack_waiter: Optional[asyncio.Future[DNSRecord]] = None + self._ack_waiter: asyncio.Future[DNSRecord] | None = None async def query(self, query_name: str, query_type: str) -> None: # serialize query diff --git a/examples/doq_server.py b/examples/doq_server.py index 8a7bd0704..5b7acab2d 100644 --- a/examples/doq_server.py +++ b/examples/doq_server.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import argparse import asyncio import logging import struct -from typing import Dict, Optional from dnslib.dns import DNSRecord @@ -34,12 +35,12 @@ class SessionTicketStore: """ def __init__(self) -> None: - self.tickets: Dict[bytes, SessionTicket] = {} + self.tickets: dict[bytes, SessionTicket] = {} def add(self, ticket: SessionTicket) -> None: self.tickets[ticket.ticket] = ticket - def pop(self, label: bytes) -> Optional[SessionTicket]: + def pop(self, label: bytes) -> SessionTicket | None: return self.tickets.pop(label, None) diff --git a/examples/http3_client.py b/examples/http3_client.py index 843b99edd..6d38662ea 100644 --- a/examples/http3_client.py +++ b/examples/http3_client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import asyncio import logging @@ -6,7 +8,7 @@ import ssl import time from collections import deque -from typing import BinaryIO, Callable, Deque, Dict, List, Optional, cast +from typing import BinaryIO, Callable, Deque, cast from urllib.parse import urlparse import wsproto @@ -44,7 +46,7 @@ def __init__( method: str, url: URL, content: bytes = b"", - headers: Optional[Dict] = None, + headers: dict | None = None, ) -> None: if headers is None: headers = {} @@ -62,7 +64,7 @@ def __init__( self.http = http self.queue: asyncio.Queue[str] = asyncio.Queue() self.stream_id = stream_id - self.subprotocol: Optional[str] = None + self.subprotocol: str | None = None self.transmit = transmit self.websocket = wsproto.Connection(wsproto.ConnectionType.CLIENT) @@ -112,14 +114,14 @@ class HttpClient(QuicConnectionProtocol): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.pushes: Dict[int, Deque[H3Event]] = {} - self._http: Optional[H3Connection] = None - self._request_events: Dict[int, Deque[H3Event]] = {} - self._request_waiter: Dict[int, asyncio.Future[Deque[H3Event]]] = {} - self._websockets: Dict[int, WebSocket] = {} + self.pushes: dict[int, Deque[H3Event]] = {} + self._http: H3Connection | None = None + self._request_events: dict[int, Deque[H3Event]] = {} + self._request_waiter: dict[int, asyncio.Future[Deque[H3Event]]] = {} + self._websockets: dict[int, WebSocket] = {} self._http = H3Connection(self._quic) - async def get(self, url: str, headers: Optional[Dict] = None) -> Deque[H3Event]: + async def get(self, url: str, headers: dict | None = None) -> Deque[H3Event]: """ Perform a GET request. """ @@ -128,7 +130,7 @@ async def get(self, url: str, headers: Optional[Dict] = None) -> Deque[H3Event]: ) async def post( - self, url: str, data: bytes, headers: Optional[Dict] = None + self, url: str, data: bytes, headers: dict | None = None ) -> Deque[H3Event]: """ Perform a POST request. @@ -138,7 +140,7 @@ async def post( ) async def websocket( - self, url: str, subprotocols: Optional[List[str]] = None + self, url: str, subprotocols: list[str] | None = None ) -> WebSocket: """ Open a WebSocket. @@ -229,9 +231,9 @@ async def _request(self, request: HttpRequest) -> Deque[H3Event]: async def perform_http_request( client: HttpClient, url: str, - data: Optional[str], + data: str | None, include: bool, - output_dir: Optional[str], + output_dir: str | None, ) -> None: # perform request start = time.time() @@ -278,7 +280,7 @@ async def perform_http_request( def process_http_pushes( client: HttpClient, include: bool, - output_dir: Optional[str], + output_dir: str | None, ) -> None: for _, http_events in client.pushes.items(): method = "" @@ -333,10 +335,10 @@ def save_session_ticket(ticket: SessionTicket) -> None: async def main( configuration: QuicConfiguration, - urls: List[str], - data: Optional[str], + urls: list[str], + data: str | None, include: bool, - output_dir: Optional[str], + output_dir: str | None, local_port: int, zero_rtt: bool, ) -> None: diff --git a/examples/http3_server.py b/examples/http3_server.py index 2025d2488..4b67ca16b 100644 --- a/examples/http3_server.py +++ b/examples/http3_server.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import asyncio import importlib @@ -5,7 +7,7 @@ import time from collections import deque from email.utils import formatdate -from typing import Callable, Deque, Dict, List, Optional, Union, cast +from typing import Callable, Deque, Union, cast import wsproto import wsproto.events @@ -43,7 +45,7 @@ def __init__( authority: bytes, connection: H3Connection, protocol: QuicConnectionProtocol, - scope: Dict, + scope: dict, stream_ended: bool, stream_id: int, transmit: Callable[[], None], @@ -51,7 +53,7 @@ def __init__( self.authority = authority self.connection = connection self.protocol = protocol - self.queue: asyncio.Queue[Dict] = asyncio.Queue() + self.queue: asyncio.Queue[dict] = asyncio.Queue() self.scope = scope self.stream_id = stream_id self.transmit = transmit @@ -76,10 +78,10 @@ def http_event_received(self, event: H3Event) -> None: async def run_asgi(self, app: AsgiApplication) -> None: await app(self.scope, self.receive, self.send) - async def receive(self) -> Dict: + async def receive(self) -> dict: return await self.queue.get() - async def send(self, message: Dict) -> None: + async def send(self, message: dict) -> None: if message["type"] == "http.response.start": self.connection.send_headers( stream_id=self.stream_id, @@ -128,18 +130,18 @@ def __init__( self, *, connection: H3Connection, - scope: Dict, + scope: dict, stream_id: int, transmit: Callable[[], None], ) -> None: self.closed = False self.connection = connection self.http_event_queue: Deque[DataReceived] = deque() - self.queue: asyncio.Queue[Dict] = asyncio.Queue() + self.queue: asyncio.Queue[dict] = asyncio.Queue() self.scope = scope self.stream_id = stream_id self.transmit = transmit - self.websocket: Optional[wsproto.Connection] = None + self.websocket: wsproto.Connection | None = None def http_event_received(self, event: H3Event) -> None: if isinstance(event, DataReceived) and not self.closed: @@ -170,10 +172,10 @@ async def run_asgi(self, app: AsgiApplication) -> None: if not self.closed: await self.send({"type": "websocket.close", "code": 1000}) - async def receive(self) -> Dict: + async def receive(self) -> dict: return await self.queue.get() - async def send(self, message: Dict) -> None: + async def send(self, message: dict) -> None: data = b"" end_stream = False if message["type"] == "websocket.accept": @@ -228,7 +230,7 @@ def __init__( self, *, connection: H3Connection, - scope: Dict, + scope: dict, stream_id: int, transmit: Callable[[], None], ) -> None: @@ -236,7 +238,7 @@ def __init__( self.closed = False self.connection = connection self.http_event_queue: Deque[DataReceived] = deque() - self.queue: asyncio.Queue[Dict] = asyncio.Queue() + self.queue: asyncio.Queue[dict] = asyncio.Queue() self.scope = scope self.stream_id = stream_id self.transmit = transmit @@ -273,10 +275,10 @@ async def run_asgi(self, app: AsgiApplication) -> None: if not self.closed: await self.send({"type": "webtransport.close"}) - async def receive(self) -> Dict: + async def receive(self) -> dict: return await self.queue.get() - async def send(self, message: Dict) -> None: + async def send(self, message: dict) -> None: data = b"" end_stream = False @@ -322,8 +324,8 @@ async def send(self, message: Dict) -> None: class HttpServerProtocol(QuicConnectionProtocol): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self._handlers: Dict[int, Handler] = {} - self._http: Optional[H3Connection] = None + self._handlers: dict[int, Handler] = {} + self._http: H3Connection | None = None def http_event_received(self, event: H3Event) -> None: if isinstance(event, HeadersReceived) and event.stream_id not in self._handlers: @@ -358,9 +360,9 @@ def http_event_received(self, event: H3Event) -> None: client = (client_addr[0], client_addr[1]) handler: Handler - scope: Dict + scope: dict if method == "CONNECT" and protocol == "websocket": - subprotocols: List[str] = [] + subprotocols: list[str] = [] for header, value in event.headers: if header == b"sec-websocket-protocol": subprotocols = [x.strip() for x in value.decode().split(",")] @@ -403,7 +405,7 @@ def http_event_received(self, event: H3Event) -> None: transmit=self.transmit, ) else: - extensions: Dict[str, Dict] = {} + extensions: dict[str, dict] = {} if isinstance(self._http, H3Connection): extensions["http.response.push"] = {} scope = { @@ -463,12 +465,12 @@ class SessionTicketStore: """ def __init__(self) -> None: - self.tickets: Dict[bytes, SessionTicket] = {} + self.tickets: dict[bytes, SessionTicket] = {} def add(self, ticket: SessionTicket) -> None: self.tickets[ticket.ticket] = ticket - def pop(self, label: bytes) -> Optional[SessionTicket]: + def pop(self, label: bytes) -> SessionTicket | None: return self.tickets.pop(label, None) diff --git a/examples/siduck_client.py b/examples/siduck_client.py index 943393e36..a0ea37b10 100644 --- a/examples/siduck_client.py +++ b/examples/siduck_client.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import argparse import asyncio import logging import ssl -from typing import Optional, cast +from typing import cast from qh3.asyncio.client import connect from qh3.asyncio.protocol import QuicConnectionProtocol @@ -16,7 +18,7 @@ class SiduckClient(QuicConnectionProtocol): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._ack_waiter: Optional[asyncio.Future[None]] = None + self._ack_waiter: asyncio.Future[None] | None = None async def quack(self) -> None: assert self._ack_waiter is None, "Only one quack at a time." diff --git a/pyproject.toml b/pyproject.toml index 15e581be9..4b8eba1b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,3 +63,6 @@ select = [ "W", # pycodestyle "I", # isort ] + +[tool.ruff.isort] +required-imports = ["from __future__ import annotations"] diff --git a/qh3/__init__.py b/qh3/__init__.py index dc0474f81..f7af67461 100644 --- a/qh3/__init__.py +++ b/qh3/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from .asyncio import QuicConnectionProtocol, connect, serve @@ -11,7 +13,7 @@ from .quic.packet import QuicProtocolVersion from .tls import CipherSuite, SessionTicket -__version__ = "1.0.9" +__version__ = "1.1.0" __all__ = ( "connect", diff --git a/qh3/_hazmat.pyi b/qh3/_hazmat.pyi index 1786b7796..624ab37a6 100644 --- a/qh3/_hazmat.pyi +++ b/qh3/_hazmat.pyi @@ -122,6 +122,11 @@ def verify_with_public_key( public_key_raw: bytes, algorithm: int, message: bytes, signature: bytes ) -> None: ... +class X25519Kyber768Draft00KeyExchange: + def __init__(self) -> None: ... + def public_key(self) -> bytes: ... + def exchange(self, peer_public_key: bytes) -> bytes: ... + class X25519KeyExchange: def __init__(self) -> None: ... def public_key(self) -> bytes: ... diff --git a/qh3/asyncio/client.py b/qh3/asyncio/client.py index 4acb4f9df..953394acd 100644 --- a/qh3/asyncio/client.py +++ b/qh3/asyncio/client.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import asyncio import ipaddress import socket from contextlib import asynccontextmanager -from typing import AsyncGenerator, Callable, Optional, cast +from typing import AsyncGenerator, Callable, cast from ..quic.configuration import QuicConfiguration from ..quic.connection import QuicConnection @@ -21,10 +23,10 @@ async def connect( host: str, port: int, *, - configuration: Optional[QuicConfiguration] = None, - create_protocol: Optional[Callable] = QuicConnectionProtocol, - session_ticket_handler: Optional[SessionTicketHandler] = None, - stream_handler: Optional[QuicStreamHandler] = None, + configuration: QuicConfiguration | None = None, + create_protocol: Callable | None = QuicConnectionProtocol, + session_ticket_handler: SessionTicketHandler | None = None, + stream_handler: QuicStreamHandler | None = None, wait_connected: bool = True, local_port: int = 0, ) -> AsyncGenerator[QuicConnectionProtocol, None]: diff --git a/qh3/asyncio/protocol.py b/qh3/asyncio/protocol.py index 0abf02ccd..b633410e0 100644 --- a/qh3/asyncio/protocol.py +++ b/qh3/asyncio/protocol.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import asyncio -from typing import Any, Callable, Dict, Optional, Tuple, Union, cast +from typing import Any, Callable, cast from ..quic import events from ..quic.connection import NetworkAddress, QuicConnection @@ -10,21 +12,21 @@ class QuicConnectionProtocol(asyncio.DatagramProtocol): def __init__( - self, quic: QuicConnection, stream_handler: Optional[QuicStreamHandler] = None + self, quic: QuicConnection, stream_handler: QuicStreamHandler | None = None ): loop = asyncio.get_event_loop() self._closed = asyncio.Event() self._connected = False - self._connected_waiter: Optional[asyncio.Future[None]] = None + self._connected_waiter: asyncio.Future[None] | None = None self._loop = loop - self._ping_waiters: Dict[int, asyncio.Future[None]] = {} + self._ping_waiters: dict[int, asyncio.Future[None]] = {} self._quic = quic - self._stream_readers: Dict[int, asyncio.StreamReader] = {} - self._timer: Optional[asyncio.TimerHandle] = None - self._timer_at: Optional[float] = None - self._transmit_task: Optional[asyncio.Handle] = None - self._transport: Optional[asyncio.DatagramTransport] = None + self._stream_readers: dict[int, asyncio.StreamReader] = {} + self._timer: asyncio.TimerHandle | None = None + self._timer_at: float | None = None + self._transmit_task: asyncio.Handle | None = None + self._transport: asyncio.DatagramTransport | None = None # callbacks self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None @@ -62,7 +64,7 @@ def connect(self, addr: NetworkAddress) -> None: async def create_stream( self, is_unidirectional: bool = False - ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: + ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: """ Create a QUIC stream and return a pair of (reader, writer) objects. @@ -131,7 +133,7 @@ async def wait_connected(self) -> None: def connection_made(self, transport: asyncio.BaseTransport) -> None: self._transport = cast(asyncio.DatagramTransport, transport) - def datagram_received(self, data: Union[bytes, str], addr: NetworkAddress) -> None: + def datagram_received(self, data: bytes | str, addr: NetworkAddress) -> None: self._quic.receive_datagram(cast(bytes, data), addr, now=self._loop.time()) self._process_events() self.transmit() @@ -161,7 +163,7 @@ def quic_event_received(self, event: events.QuicEvent) -> None: def _create_stream( self, stream_id: int - ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: + ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: adapter = QuicStreamAdapter(self, stream_id) reader = asyncio.StreamReader() writer = asyncio.StreamWriter(adapter, None, reader, self._loop) diff --git a/qh3/asyncio/server.py b/qh3/asyncio/server.py index 6d71cd888..aaaf38295 100644 --- a/qh3/asyncio/server.py +++ b/qh3/asyncio/server.py @@ -1,13 +1,15 @@ +from __future__ import annotations + import asyncio import os from functools import partial -from typing import Callable, Dict, Optional, Union, cast +from typing import Callable, cast from ..buffer import Buffer from ..quic.configuration import QuicConfiguration from ..quic.connection import NetworkAddress, QuicConnection from ..quic.packet import ( - PACKET_TYPE_INITIAL, + QuicPacketType, encode_quic_retry, encode_quic_version_negotiation, pull_quic_header, @@ -25,18 +27,18 @@ def __init__( *, configuration: QuicConfiguration, create_protocol: Callable = QuicConnectionProtocol, - session_ticket_fetcher: Optional[SessionTicketFetcher] = None, - session_ticket_handler: Optional[SessionTicketHandler] = None, + session_ticket_fetcher: SessionTicketFetcher | None = None, + session_ticket_handler: SessionTicketHandler | None = None, retry: bool = False, - stream_handler: Optional[QuicStreamHandler] = None, + stream_handler: QuicStreamHandler | None = None, ) -> None: self._configuration = configuration self._create_protocol = create_protocol self._loop = asyncio.get_event_loop() - self._protocols: Dict[bytes, QuicConnectionProtocol] = {} + self._protocols: dict[bytes, QuicConnectionProtocol] = {} self._session_ticket_fetcher = session_ticket_fetcher self._session_ticket_handler = session_ticket_handler - self._transport: Optional[asyncio.DatagramTransport] = None + self._transport: asyncio.DatagramTransport | None = None self._stream_handler = stream_handler @@ -54,7 +56,7 @@ def close(self): def connection_made(self, transport: asyncio.BaseTransport) -> None: self._transport = cast(asyncio.DatagramTransport, transport) - def datagram_received(self, data: Union[bytes, str], addr: NetworkAddress) -> None: + def datagram_received(self, data: bytes | str, addr: NetworkAddress) -> None: data = cast(bytes, data) buf = Buffer(data=data) @@ -81,12 +83,12 @@ def datagram_received(self, data: Union[bytes, str], addr: NetworkAddress) -> No return protocol = self._protocols.get(header.destination_cid, None) - original_destination_connection_id: Optional[bytes] = None - retry_source_connection_id: Optional[bytes] = None + original_destination_connection_id: bytes | None = None + retry_source_connection_id: bytes | None = None if ( protocol is None and len(data) >= 1200 - and header.packet_type == PACKET_TYPE_INITIAL + and header.packet_type == QuicPacketType.INITIAL ): # retry if self._retry is not None: @@ -169,8 +171,8 @@ async def serve( *, configuration: QuicConfiguration, create_protocol: Callable = QuicConnectionProtocol, - session_ticket_fetcher: Optional[SessionTicketFetcher] = None, - session_ticket_handler: Optional[SessionTicketHandler] = None, + session_ticket_fetcher: SessionTicketFetcher | None = None, + session_ticket_handler: SessionTicketHandler | None = None, retry: bool = False, stream_handler: QuicStreamHandler = None, ) -> QuicServer: diff --git a/qh3/h3/connection.py b/qh3/h3/connection.py index fa0464a83..c128131d1 100644 --- a/qh3/h3/connection.py +++ b/qh3/h3/connection.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import logging import re from enum import Enum, IntEnum -from typing import Dict, FrozenSet, List, Optional, Set from .._hazmat import ( DecoderStreamError, @@ -159,7 +160,7 @@ def encode_frame(frame_type: int, frame_data: bytes) -> bytes: return buf.data -def encode_settings(settings: Dict[int, int]) -> bytes: +def encode_settings(settings: dict[int, int]) -> bytes: buf = Buffer(capacity=1024) for setting, value in settings.items(): buf.push_uint_var(setting) @@ -174,9 +175,9 @@ def parse_max_push_id(data: bytes) -> int: return max_push_id -def parse_settings(data: bytes) -> Dict[int, int]: +def parse_settings(data: bytes) -> dict[int, int]: buf = Buffer(data=data) - settings: Dict[int, int] = {} + settings: dict[int, int] = {} while not buf.eof(): setting = buf.pull_uint_var() value = buf.pull_uint_var() @@ -190,14 +191,14 @@ def parse_settings(data: bytes) -> Dict[int, int]: def validate_headers( headers: Headers, - allowed_pseudo_headers: FrozenSet[bytes], - required_pseudo_headers: FrozenSet[bytes], + allowed_pseudo_headers: frozenset[bytes], + required_pseudo_headers: frozenset[bytes], ) -> None: after_pseudo_headers = False - authority: Optional[bytes] = None - path: Optional[bytes] = None - scheme: Optional[bytes] = None - seen_pseudo_headers: Set[bytes] = set() + authority: bytes | None = None + path: bytes | None = None + scheme: bytes | None = None + seen_pseudo_headers: set[bytes] = set() for key, value in headers: if UPPERCASE.search(key): raise MessageError("Header %r contains uppercase letters" % key) @@ -280,17 +281,17 @@ def validate_trailers(headers: Headers) -> None: class H3Stream: def __init__(self, stream_id: int) -> None: self.blocked = False - self.blocked_frame_size: Optional[int] = None + self.blocked_frame_size: int | None = None self.buffer = b"" self.ended = False - self.frame_size: Optional[int] = None - self.frame_type: Optional[int] = None + self.frame_size: int | None = None + self.frame_type: int | None = None self.headers_recv_state: HeadersState = HeadersState.INITIAL self.headers_send_state: HeadersState = HeadersState.INITIAL - self.push_id: Optional[int] = None - self.session_id: Optional[int] = None + self.push_id: int | None = None + self.session_id: int | None = None self.stream_id = stream_id - self.stream_type: Optional[int] = None + self.stream_type: int | None = None class H3Connection: @@ -309,7 +310,7 @@ def __init__(self, quic: QuicConnection, enable_webtransport: bool = False) -> N self._is_client = quic.configuration.is_client self._is_done = False self._quic = quic - self._quic_logger: Optional[QuicLoggerTrace] = quic._quic_logger + self._quic_logger: QuicLoggerTrace | None = quic._quic_logger self._decoder = QpackDecoder(self._max_table_capacity, self._blocked_streams) self._decoder_bytes_received = 0 self._decoder_bytes_sent = 0 @@ -317,22 +318,22 @@ def __init__(self, quic: QuicConnection, enable_webtransport: bool = False) -> N self._encoder_bytes_received = 0 self._encoder_bytes_sent = 0 self._settings_received = False - self._stream: Dict[int, H3Stream] = {} + self._stream: dict[int, H3Stream] = {} - self._max_push_id: Optional[int] = 8 if self._is_client else None + self._max_push_id: int | None = 8 if self._is_client else None self._next_push_id: int = 0 - self._local_control_stream_id: Optional[int] = None - self._local_decoder_stream_id: Optional[int] = None - self._local_encoder_stream_id: Optional[int] = None + self._local_control_stream_id: int | None = None + self._local_decoder_stream_id: int | None = None + self._local_encoder_stream_id: int | None = None - self._peer_control_stream_id: Optional[int] = None - self._peer_decoder_stream_id: Optional[int] = None - self._peer_encoder_stream_id: Optional[int] = None - self._received_settings: Optional[Dict[int, int]] = None - self._sent_settings: Optional[Dict[int, int]] = None + self._peer_control_stream_id: int | None = None + self._peer_decoder_stream_id: int | None = None + self._peer_encoder_stream_id: int | None = None + self._received_settings: dict[int, int] | None = None + self._sent_settings: dict[int, int] | None = None - self._blocked_stream_map: Dict[int, H3Stream] = {} + self._blocked_stream_map: dict[int, H3Stream] = {} self._init_connection() @@ -360,7 +361,7 @@ def create_webtransport_stream( ) return stream_id - def handle_event(self, event: QuicEvent) -> List[H3Event]: + def handle_event(self, event: QuicEvent) -> list[H3Event]: """ Handle a QUIC event and return a list of HTTP events. @@ -501,22 +502,20 @@ def send_headers( ) @property - def received_settings(self) -> Optional[Dict[int, int]]: + def received_settings(self) -> dict[int, int] | None: """ Return the received SETTINGS frame, or None. """ return self._received_settings @property - def sent_settings(self) -> Optional[Dict[int, int]]: + def sent_settings(self) -> dict[int, int] | None: """ Return the sent SETTINGS frame, or None. """ return self._sent_settings - def _create_uni_stream( - self, stream_type: int, push_id: Optional[int] = None - ) -> int: + def _create_uni_stream(self, stream_type: int, push_id: int | None = None) -> int: """ Create an unidirectional stream of the given type. """ @@ -527,7 +526,7 @@ def _create_uni_stream( self._quic.send_stream_data(stream_id, encode_uint_var(stream_type)) return stream_id - def _decode_headers(self, stream_id: int, frame_data: Optional[bytes]) -> Headers: + def _decode_headers(self, stream_id: int, frame_data: bytes | None) -> Headers: """ Decode a HEADERS block and send decoder updates on the decoder stream. @@ -568,11 +567,11 @@ def _get_or_create_stream(self, stream_id: int) -> H3Stream: self._stream[stream_id] = H3Stream(stream_id) return self._stream[stream_id] - def _get_local_settings(self) -> Dict[int, int]: + def _get_local_settings(self) -> dict[int, int]: """ Return the local HTTP/3 settings. """ - settings: Dict[int, int] = { + settings: dict[int, int] = { Setting.QPACK_MAX_TABLE_CAPACITY: self._max_table_capacity, Setting.QPACK_BLOCKED_STREAMS: self._blocked_streams, Setting.ENABLE_CONNECT_PROTOCOL: 1, @@ -618,14 +617,14 @@ def _handle_control_frame(self, frame_type: int, frame_data: bytes) -> None: def _handle_request_or_push_frame( self, frame_type: int, - frame_data: Optional[bytes], + frame_data: bytes | None, stream: H3Stream, stream_ended: bool, - ) -> List[H3Event]: + ) -> list[H3Event]: """ Handle a frame received on a request or push stream. """ - http_events: List[H3Event] = [] + http_events: list[H3Event] = [] if frame_type == FrameType.DATA: # check DATA frame is allowed @@ -758,7 +757,7 @@ def _init_connection(self) -> None: ) def _log_stream_type( - self, stream_id: int, stream_type: int, push_id: Optional[int] = None + self, stream_id: int, stream_type: int, push_id: int | None = None ) -> None: if self._quic_logger is not None: type_name = { @@ -779,7 +778,7 @@ def _log_stream_type( data=data, ) - def _receive_datagram(self, data: bytes) -> List[H3Event]: + def _receive_datagram(self, data: bytes) -> list[H3Event]: """ Handle a datagram. """ @@ -792,11 +791,11 @@ def _receive_datagram(self, data: bytes) -> List[H3Event]: def _receive_request_or_push_data( self, stream: H3Stream, data: bytes, stream_ended: bool - ) -> List[H3Event]: + ) -> list[H3Event]: """ Handle data received on a request or push stream. """ - http_events: List[H3Event] = [] + http_events: list[H3Event] = [] stream.buffer += data if stream_ended: @@ -937,8 +936,8 @@ def _receive_request_or_push_data( def _receive_stream_data_uni( self, stream: H3Stream, data: bytes, stream_ended: bool - ) -> List[H3Event]: - http_events: List[H3Event] = [] + ) -> list[H3Event]: + http_events: list[H3Event] = [] stream.buffer += data if stream_ended: @@ -946,7 +945,7 @@ def _receive_stream_data_uni( buf = Buffer(data=stream.buffer) consumed = 0 - unblocked_streams: Set[int] = set() + unblocked_streams: set[int] = set() while ( stream.stream_type @@ -1061,7 +1060,7 @@ def _receive_stream_data_uni( for blocked_id, blocked_stream in self._blocked_stream_map.items(): try: stream_data, headers = self._decoder.resume_header(blocked_id) - blocked_stream._pending = ( + blocked_stream._pending = ( # type: ignore[attr-defined] stream_data, headers, ) @@ -1103,7 +1102,7 @@ def _receive_stream_data_uni( return http_events - def _validate_settings(self, settings: Dict[int, int]) -> None: + def _validate_settings(self, settings: dict[int, int]) -> None: for setting in [ Setting.ENABLE_CONNECT_PROTOCOL, Setting.ENABLE_WEBTRANSPORT, diff --git a/qh3/h3/events.py b/qh3/h3/events.py index 1f5a35dbb..451a40f6d 100644 --- a/qh3/h3/events.py +++ b/qh3/h3/events.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import List, Tuple Headers = List[Tuple[bytes, bytes]] @@ -26,7 +28,7 @@ class DataReceived(H3Event): stream_ended: bool "Whether the STREAM frame had the FIN bit set." - push_id: Optional[int] = None + push_id: int | None = None "The Push ID or `None` if this is not a push." @@ -59,7 +61,7 @@ class HeadersReceived(H3Event): stream_ended: bool "Whether the STREAM frame had the FIN bit set." - push_id: Optional[int] = None + push_id: int | None = None "The Push ID or `None` if this is not a push." diff --git a/qh3/h3/exceptions.py b/qh3/h3/exceptions.py index d5200c1d6..c14691929 100644 --- a/qh3/h3/exceptions.py +++ b/qh3/h3/exceptions.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + class H3Error(Exception): """ Base class for HTTP/3 exceptions. diff --git a/qh3/quic/configuration.py b/qh3/quic/configuration.py index 2e8613e5a..fee7d03fb 100644 --- a/qh3/quic/configuration.py +++ b/qh3/quic/configuration.py @@ -96,6 +96,7 @@ class QuicConfiguration: initial_rtt: float = 0.1 max_datagram_frame_size: int | None = None + original_version: int | None = None private_key: ( EcPrivateKey | Ed25519PrivateKey | DsaPrivateKey | RsaPrivateKey | None @@ -105,6 +106,7 @@ class QuicConfiguration: supported_versions: list[int] = field( default_factory=lambda: [ QuicProtocolVersion.VERSION_1, + QuicProtocolVersion.VERSION_2, ] ) verify_mode: int | None = None @@ -120,10 +122,15 @@ def load_cert_chain( """ if isinstance(certfile, str): - certfile = certfile.encode("ascii") + certfile = certfile.encode() + elif isinstance(certfile, PathLike): + certfile = str(certfile).encode() - if keyfile is not None and isinstance(keyfile, str): - keyfile = keyfile.encode("ascii") + if keyfile is not None: + if isinstance(keyfile, str): + keyfile = keyfile.encode() + elif isinstance(keyfile, PathLike): + keyfile = str(keyfile).encode() # we either have the certificate or a file path in certfile/keyfile. if b"-----BEGIN" not in certfile: diff --git a/qh3/quic/connection.py b/qh3/quic/connection.py index 496e4935d..90a58cca0 100644 --- a/qh3/quic/connection.py +++ b/qh3/quic/connection.py @@ -23,26 +23,24 @@ size_uint_var, ) from . import events -from .crypto import CryptoError, CryptoPair, KeyUnavailableError +from .crypto import CryptoError, CryptoPair, KeyUnavailableError, NoCallback from .packet import ( CONNECTION_ID_MAX_SIZE, NON_ACK_ELICITING_FRAME_TYPES, - PACKET_TYPE_HANDSHAKE, - PACKET_TYPE_INITIAL, - PACKET_TYPE_ONE_RTT, - PACKET_TYPE_RETRY, - PACKET_TYPE_ZERO_RTT, PROBING_FRAME_TYPES, RETRY_INTEGRITY_TAG_SIZE, STATELESS_RESET_TOKEN_SIZE, QuicErrorCode, QuicFrameType, + QuicHeader, + QuicPacketType, QuicProtocolVersion, QuicStreamFrame, QuicTransportParameters, + QuicVersionInformation, get_retry_integrity_tag, get_spin_bit, - is_long_header, + pretty_protocol_version, pull_ack_frame, pull_quic_header, pull_quic_transport_parameters, @@ -68,6 +66,9 @@ "1": tls.Epoch.ONE_RTT, } MAX_EARLY_DATA = 0xFFFFFFFF +MAX_REMOTE_CHALLENGES = 5 +MAX_LOCAL_CHALLENGES = 5 +MAX_PENDING_RETIRES = 100 SECRETS_LABELS = [ [ None, @@ -85,6 +86,7 @@ STREAM_FLAGS = 0x07 STREAM_COUNT_MAX = 0x1000000000000000 UDP_HEADER_SIZE = 8 +MAX_PENDING_CRYPTO = 524288 # in bytes NetworkAddress = Any @@ -111,16 +113,29 @@ def EPOCHS(shortcut: str) -> frozenset[tls.Epoch]: return frozenset(EPOCH_SHORTCUTS[i] for i in shortcut) +def is_version_compatible(from_version: int, to_version: int) -> bool: + """ + Return whether it is possible to perform compatible version negotiation + from `from_version` to `to_version`. + """ + # Version 1 is compatible with version 2 and vice versa. These are the + # only compatible versions so far. + return {from_version, to_version} == { + QuicProtocolVersion.VERSION_1, + QuicProtocolVersion.VERSION_2, + } + + def dump_cid(cid: bytes) -> str: return binascii.hexlify(cid).decode("ascii") -def get_epoch(packet_type: int) -> tls.Epoch: - if packet_type == PACKET_TYPE_INITIAL: +def get_epoch(packet_type: QuicPacketType) -> tls.Epoch: + if packet_type == QuicPacketType.INITIAL: return tls.Epoch.INITIAL - elif packet_type == PACKET_TYPE_ZERO_RTT: + elif packet_type == QuicPacketType.ZERO_RTT: return tls.Epoch.ZERO_RTT - elif packet_type == PACKET_TYPE_HANDSHAKE: + elif packet_type == QuicPacketType.HANDSHAKE: return tls.Epoch.HANDSHAKE else: return tls.Epoch.ONE_RTT @@ -183,14 +198,14 @@ class QuicConnectionState(Enum): TERMINATED = 4 -@dataclass class QuicNetworkPath: - addr: NetworkAddress - bytes_received: int = 0 - bytes_sent: int = 0 - is_validated: bool = False - local_challenge: bytes | None = None - remote_challenge: bytes | None = None + def __init__(self, addr: NetworkAddress, is_validated: bool = False): + self.addr: NetworkAddress = addr + self.bytes_received: int = 0 + self.bytes_sent: int = 0 + self.is_validated: bool = is_validated + self.local_challenge_sent: bool = False + self.remote_challenges: Deque[bytes] = deque() def can_send(self, size: int) -> bool: return self.is_validated or (self.bytes_sent + size) <= 3 * self.bytes_received @@ -203,6 +218,7 @@ class QuicReceiveContext: network_path: QuicNetworkPath quic_logger_frames: list[Any] | None time: float + version: int | None END_STATES = frozenset( @@ -265,7 +281,10 @@ def __init__( self._close_event: events.ConnectionTerminated | None = None self._connect_called = False self._cryptos: dict[tls.Epoch, CryptoPair] = {} + self._cryptos_initial: dict[int, CryptoPair] = {} self._crypto_buffers: dict[tls.Epoch, Buffer] = {} + self._crypto_frame_type: int | None = None + self._crypto_packet_version: int | None = None self._crypto_retransmitted = False self._crypto_streams: dict[tls.Epoch, QuicStream] = {} self._events: Deque[events.QuicEvent] = deque() @@ -283,6 +302,7 @@ def __init__( self._host_cid_seq = 1 self._local_ack_delay_exponent = 3 self._local_active_connection_id_limit = 8 + self._local_challenges: dict[bytes, QuicNetworkPath] = {} self._local_initial_source_connection_id = self._host_cids[0].cid self._local_max_data = Limit( frame_type=QuicFrameType.MAX_DATA, @@ -306,12 +326,12 @@ def __init__( self._network_paths: list[QuicNetworkPath] = [] self._pacing_at: float | None = None self._packet_number = 0 - self._parameters_received = False self._peer_cid = QuicConnectionId( cid=os.urandom(configuration.connection_id_length), sequence_number=None ) self._peer_cid_available: list[QuicConnectionId] = [] self._peer_cid_sequence_numbers: set[int] = {0} + self._peer_retire_prior_to = 0 self._peer_token = b"" self._quic_logger: QuicLoggerTrace | None = None self._remote_ack_delay_exponent = 3 @@ -326,6 +346,7 @@ def __init__( self._remote_max_stream_data_uni = 0 self._remote_max_streams_bidi = 0 self._remote_max_streams_uni = 0 + self._remote_version_information: QuicVersionInformation | None = None self._retry_count = 0 self._retry_source_connection_id = retry_source_connection_id self._spaces: dict[tls.Epoch, QuicPacketSpace] = {} @@ -338,7 +359,8 @@ def __init__( self._streams_blocked_uni: list[QuicStream] = [] self._streams_finished: set[int] = set() self._version: int | None = None - self._version_negotiation_count = 0 + self._version_negotiated_compatible = False + self._version_negotiated_incompatible = False if self._is_client: self._original_destination_connection_id = self._peer_cid.cid @@ -500,7 +522,10 @@ def connect(self, addr: NetworkAddress, now: float) -> None: self._connect_called = True self._network_paths = [QuicNetworkPath(addr, is_validated=True)] - self._version = self._configuration.supported_versions[0] + if self._configuration.original_version is not None: + self._version = self._configuration.original_version + else: + self._version = self._configuration.supported_versions[0] self._connect(now=now) def datagrams_to_send(self, now: float) -> list[tuple[bytes, NetworkAddress]]: @@ -533,10 +558,10 @@ def datagrams_to_send(self, now: float) -> list[tuple[bytes, NetworkAddress]]: epoch_packet_types = [] if not self._handshake_confirmed: epoch_packet_types += [ - (tls.Epoch.INITIAL, PACKET_TYPE_INITIAL), - (tls.Epoch.HANDSHAKE, PACKET_TYPE_HANDSHAKE), + (tls.Epoch.INITIAL, QuicPacketType.INITIAL), + (tls.Epoch.HANDSHAKE, QuicPacketType.HANDSHAKE), ] - epoch_packet_types.append((tls.Epoch.ONE_RTT, PACKET_TYPE_ONE_RTT)) + epoch_packet_types.append((tls.Epoch.ONE_RTT, QuicPacketType.ONE_RTT)) for epoch, packet_type in epoch_packet_types: crypto = self._cryptos[epoch] if crypto.send.is_valid(): @@ -605,9 +630,9 @@ def datagrams_to_send(self, now: float) -> list[tuple[bytes, NetworkAddress]]: packet.packet_type ), "scid": ( - dump_cid(self.host_cid) - if is_long_header(packet.packet_type) - else "" + "" + if packet.packet_type == QuicPacketType.ONE_RTT + else dump_cid(self.host_cid) ), "dcid": dump_cid(self._peer_cid.cid), }, @@ -724,9 +749,10 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non if self._state in END_STATES: return + payload_length = len(data) + # log datagram if self._quic_logger is not None: - payload_length = len(data) self._quic_logger.log_event( category="transport", event="datagrams_received", @@ -741,6 +767,21 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non }, ) + # For anti-amplification purposes, servers need to keep track of the + # amount of data received on unvalidated network paths. We must count the + # entire datagram size regardless of whether packets are processed or + # dropped. + # + # This is particularly important when talking to clients who pad + # datagrams containing INITIAL packets by appending bytes after the + # long-header packets, which is legitimate behaviour. + # + # https://datatracker.ietf.org/doc/html/rfc9000#section-8.1 + network_path = self._find_network_path(addr) + + if not network_path.is_validated: + network_path.bytes_received += payload_length + # for servers, arm the idle timeout on the first datagram if self._close_at is None: self._close_at = now + self._configuration.idle_timeout @@ -779,129 +820,56 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non ) return - # check protocol version - if ( - self._is_client - and self._state == QuicConnectionState.FIRSTFLIGHT - and header.version == QuicProtocolVersion.NEGOTIATION - and not self._version_negotiation_count - ): - # version negotiation - versions = [] - while not buf.eof(): - versions.append(buf.pull_uint32()) - if self._quic_logger is not None: - self._quic_logger.log_event( - category="transport", - event="packet_received", - data={ - "frames": [], - "header": { - "packet_type": "version_negotiation", - "scid": dump_cid(header.source_cid), - "dcid": dump_cid(header.destination_cid), - }, - "raw": {"length": buf.tell() - start_off}, - }, - ) - if self._version in versions: - self._logger.warning( - "Version negotiation packet contains %s" % self._version - ) - return - common = [ - x for x in self._configuration.supported_versions if x in versions - ] - if not common: - self._logger.error("Could not find a common protocol version") - self._close_event = events.ConnectionTerminated( - error_code=QuicErrorCode.INTERNAL_ERROR, - frame_type=QuicFrameType.PADDING, - reason_phrase="Could not find a common protocol version", - ) - self._close_end() - return - self._packet_number = 0 - self._version = QuicProtocolVersion(common[0]) - self._version_negotiation_count += 1 - self._logger.debug("Retrying with %s", self._version) - self._connect(now=now) + # Handle version negotiation packet. + if header.packet_type == QuicPacketType.VERSION_NEGOTIATION: + self._receive_version_negotiation_packet(header=header, now=now) return - elif ( + + # Check long header packet protocol version. + if ( header.version is not None and header.version not in self._configuration.supported_versions ): - # unsupported version if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", - data={"trigger": "unsupported_version"}, + data={ + "trigger": "unsupported_version", + "raw": {"length": header.packet_length}, + }, ) return # handle retry packet - if header.packet_type == PACKET_TYPE_RETRY: - if ( - self._is_client - and not self._retry_count - and header.destination_cid == self.host_cid - and header.integrity_tag - == get_retry_integrity_tag( - buf.data_slice( - start_off, buf.tell() - RETRY_INTEGRITY_TAG_SIZE - ), - self._peer_cid.cid, - version=header.version, - ) - ): - if self._quic_logger is not None: - self._quic_logger.log_event( - category="transport", - event="packet_received", - data={ - "frames": [], - "header": { - "packet_type": "retry", - "scid": dump_cid(header.source_cid), - "dcid": dump_cid(header.destination_cid), - }, - "raw": {"length": buf.tell() - start_off}, - }, - ) - - self._peer_cid.cid = header.source_cid - self._peer_token = header.token - self._retry_count += 1 - self._retry_source_connection_id = header.source_cid - self._logger.debug( - "Retrying with token (%d bytes)" % len(header.token) - ) - self._connect(now=now) - else: - # unexpected or invalid retry packet - if self._quic_logger is not None: - self._quic_logger.log_event( - category="transport", - event="packet_dropped", - data={"trigger": "unexpected_packet"}, - ) + if header.packet_type == QuicPacketType.RETRY: + self._receive_retry_packet( + header=header, + packet_without_tag=buf.data_slice( + start_off, buf.tell() - RETRY_INTEGRITY_TAG_SIZE + ), + now=now, + ) return - network_path = self._find_network_path(addr) + crypto_frame_required = False # server initialization if not self._is_client and self._state == QuicConnectionState.FIRSTFLIGHT: assert ( - header.packet_type == PACKET_TYPE_INITIAL + header.packet_type == QuicPacketType.INITIAL ), "first packet must be INITIAL" + crypto_frame_required = True self._network_paths = [network_path] - self._version = QuicProtocolVersion(header.version) + self._version = header.version self._initialize(header.destination_cid) - # determine crypto and packet space + # Determine crypto and packet space. epoch = get_epoch(header.packet_type) - crypto = self._cryptos[epoch] + if epoch == tls.Epoch.INITIAL: + crypto = self._cryptos_initial[header.version] + else: + crypto = self._cryptos[epoch] if epoch == tls.Epoch.ZERO_RTT: space = self._spaces[tls.Epoch.ONE_RTT] else: @@ -909,7 +877,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non # decrypt packet encrypted_off = buf.tell() - start_off - end_off = buf.tell() + header.rest_length + end_off = start_off + header.packet_length buf.seek(end_off) try: @@ -922,7 +890,10 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non self._quic_logger.log_event( category="transport", event="packet_dropped", - data={"trigger": "key_unavailable"}, + data={ + "trigger": "key_unavailable", + "raw": {"length": header.packet_length}, + }, ) # If a client receives HANDSHAKE or 1-RTT packets before it has @@ -941,15 +912,18 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non self._quic_logger.log_event( category="transport", event="packet_dropped", - data={"trigger": "payload_decrypt_error"}, + data={ + "trigger": "payload_decrypt_error", + "raw": {"length": header.packet_length}, + }, ) continue # check reserved bits - if header.is_long_header: - reserved_mask = 0x0C - else: + if header.packet_type == QuicPacketType.ONE_RTT: reserved_mask = 0x18 + else: + reserved_mask = 0x0C if plain_header[0] & reserved_mask: self.close( error_code=QuicErrorCode.PROTOCOL_VIOLATION, @@ -975,7 +949,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non "dcid": dump_cid(header.destination_cid), "scid": dump_cid(header.source_cid), }, - "raw": {"length": end_off - start_off}, + "raw": {"length": header.packet_length}, }, ) @@ -997,7 +971,10 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non self._set_state(QuicConnectionState.CONNECTED) # update spin bit - if not header.is_long_header and packet_number > self._spin_highest_pn: + if ( + header.packet_type == QuicPacketType.ONE_RTT + and packet_number > self._spin_highest_pn + ): spin_bit = get_spin_bit(plain_header[0]) if self._is_client: self._spin_bit = not spin_bit @@ -1019,10 +996,11 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non network_path=network_path, quic_logger_frames=quic_logger_frames, time=now, + version=header.version, ) try: is_ack_eliciting, is_probing = self._payload_received( - context, plain_payload + context, plain_payload, crypto_frame_required ) except QuicConnectionError as exc: self._logger.warning(exc) @@ -1057,7 +1035,6 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non "Network path %s validated by handshake", network_path.addr ) network_path.is_validated = True - network_path.bytes_received += end_off - start_off if network_path not in self._network_paths: self._network_paths.append(network_path) idx = self._network_paths.index(network_path) @@ -1143,8 +1120,58 @@ def stop_stream(self, stream_id: int, error_code: int) -> None: def _alpn_handler(self, alpn_protocol: str) -> None: """ - Callback which is invoked by the TLS engine when ALPN negotiation completes. + Callback which is invoked by the TLS engine at most once, when the + ALPN negotiation completes. + + At this point, TLS extensions have been received so we can parse the + transport parameters. """ + # Parse the remote transport parameters. + for ext_type, ext_data in self.tls.received_extensions: + if ext_type == tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS: + self._parse_transport_parameters(ext_data) + break + else: + raise QuicConnectionError( + error_code=QuicErrorCode.CRYPTO_ERROR + + tls.AlertDescription.missing_extension, + frame_type=self._crypto_frame_type, + reason_phrase="No QUIC transport parameters received", + ) + + # For servers, determine the Negotiated Version. + if not self._is_client and not self._version_negotiated_compatible: + if self._remote_version_information is not None: + # Pick the first version we support in the client's available versions, + # which is compatible with the current version. + for version in self._remote_version_information.available_versions: + if version == self._version: + # Stay with the current version. + break + elif ( + version in self._configuration.supported_versions + and is_version_compatible(self._version, version) + ): + # Change version. + self._version = version + self._cryptos[tls.Epoch.INITIAL] = self._cryptos_initial[ + version + ] + + # Update our transport parameters to reflect the chosen version. + self.tls.handshake_extensions = [ + ( + tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS, + self._serialize_transport_parameters(), + ) + ] + break + self._version_negotiated_compatible = True + self._logger.info( + "Negotiated protocol version %s", pretty_protocol_version(self._version) + ) + + # Notify the application. self._events.append(events.ProtocolNegotiated(alpn_protocol=alpn_protocol)) def _assert_stream_can_receive(self, frame_type: int, stream_id: int) -> None: @@ -1223,6 +1250,13 @@ def _discard_epoch(self, epoch: tls.Epoch) -> None: if not self._spaces[epoch].discarded: self._logger.debug("Discarding epoch %s", epoch) self._cryptos[epoch].teardown() + if epoch == tls.Epoch.INITIAL: + # Tear the crypto pairs, but do not log the event, + # to avoid duplicate log entries. + for crypto in self._cryptos_initial.values(): + crypto.recv._teardown_cb = NoCallback + crypto.send._teardown_cb = NoCallback + crypto.teardown() self._loss.discard_space(self._spaces[epoch]) self._spaces[epoch].discarded = True @@ -1447,6 +1481,15 @@ def create_crypto_pair(epoch: tls.Epoch) -> CryptoPair: send_teardown_cb=partial(self._log_key_retired, send_secret_name), ) + # To enable version negotiation, setup encryption keys for all + # our supported versions. + self._cryptos_initial = {} + + for version in self._configuration.supported_versions: + pair = CryptoPair() + pair.setup_initial(cid=peer_cid, is_client=self._is_client, version=version) + self._cryptos_initial[version] = pair + self._cryptos = { epoch: create_crypto_pair(epoch) for epoch in ( @@ -1456,6 +1499,9 @@ def create_crypto_pair(epoch: tls.Epoch) -> CryptoPair: tls.Epoch.ONE_RTT, ) } + + self._cryptos[tls.Epoch.INITIAL] = self._cryptos_initial[self._version] + self._crypto_buffers = { tls.Epoch.INITIAL: Buffer(capacity=CRYPTO_BUFFER_SIZE), tls.Epoch.HANDSHAKE: Buffer(capacity=CRYPTO_BUFFER_SIZE), @@ -1472,10 +1518,6 @@ def create_crypto_pair(epoch: tls.Epoch) -> CryptoPair: tls.Epoch.ONE_RTT: QuicPacketSpace(), } - self._cryptos[tls.Epoch.INITIAL].setup_initial( - cid=peer_cid, is_client=self._is_client, version=self._version - ) - self._loss.spaces = list(self._spaces.values()) def _handle_ack_frame( @@ -1574,9 +1616,23 @@ def _handle_crypto_frame( ) stream = self._crypto_streams[context.epoch] + + pending = offset + length - stream.receiver.starting_offset() + + if pending > MAX_PENDING_CRYPTO: + raise QuicConnectionError( + error_code=QuicErrorCode.CRYPTO_BUFFER_EXCEEDED, + frame_type=frame_type, + reason_phrase="too much crypto buffering", + ) + event = stream.receiver.handle_frame(frame) if event is not None: - # pass data to TLS layer + # Pass data to TLS layer, which may cause calls to: + # - _alpn_handler + # - _update_traffic_key + self._crypto_frame_type = frame_type + self._crypto_packet_version = context.version try: self.tls.handle_message(event.data, self._crypto_buffers) self._push_crypto_data() @@ -1587,24 +1643,6 @@ def _handle_crypto_frame( reason_phrase=str(exc), ) - # parse transport parameters - if ( - not self._parameters_received - and self.tls.received_extensions is not None - ): - for ext_type, ext_data in self.tls.received_extensions: - if ext_type == tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS: - self._parse_transport_parameters(ext_data) - self._parameters_received = True - break - if not self._parameters_received: - raise QuicConnectionError( - error_code=QuicErrorCode.CRYPTO_ERROR - + tls.AlertDescription.missing_extension, - frame_type=frame_type, - reason_phrase="No QUIC transport parameters received", - ) - # update current epoch if not self._handshake_complete and self.tls.state in [ tls.State.CLIENT_POST_HANDSHAKE, @@ -1865,24 +1903,30 @@ def _handle_new_connection_id_frame( reason_phrase="Retire Prior To is greater than Sequence Number", ) + # only accept retire_prior_to if it is bigger than the one we know + self._peer_retire_prior_to = max(retire_prior_to, self._peer_retire_prior_to) + # determine which CIDs to retire change_cid = False - retire = list( - filter( - lambda c: c.sequence_number < retire_prior_to, self._peer_cid_available - ) - ) + retire = [ + cid + for cid in self._peer_cid_available + if cid.sequence_number < self._peer_retire_prior_to + ] if self._peer_cid.sequence_number < retire_prior_to: change_cid = True retire.insert(0, self._peer_cid) # update available CIDs - self._peer_cid_available = list( - filter( - lambda c: c.sequence_number >= retire_prior_to, self._peer_cid_available - ) - ) - if sequence_number not in self._peer_cid_sequence_numbers: + self._peer_cid_available = [ + cid + for cid in self._peer_cid_available + if cid.sequence_number >= self._peer_retire_prior_to + ] + if ( + sequence_number >= self._peer_retire_prior_to + and sequence_number not in self._peer_cid_sequence_numbers + ): self._peer_cid_available.append( QuicConnectionId( cid=connection_id, @@ -1908,6 +1952,21 @@ def _handle_new_connection_id_frame( reason_phrase="Too many active connection IDs", ) + # Check the number of retired connection IDs pending, though with a safer limit + # than the 2x recommended in section 5.1.2 of the RFC. Note that we are doing + # the check here and not in _retire_peer_cid() because we know the frame type to + # use here, and because it is the new connection id path that is potentially + # dangerous. We may transiently go a bit over the limit due to unacked frames + # getting added back to the list, but that's ok as it is bounded. + if len(self._retire_connection_ids) > min( + self._local_active_connection_id_limit * 4, MAX_PENDING_RETIRES + ): + raise QuicConnectionError( + error_code=QuicErrorCode.CONNECTION_ID_LIMIT_ERROR, + frame_type=frame_type, + reason_phrase="Too many pending retired connection IDs", + ) + def _handle_new_token_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: @@ -1962,7 +2021,7 @@ def _handle_path_challenge_frame( self._quic_logger.encode_path_challenge_frame(data=data) ) - context.network_path.remote_challenge = data + context.network_path.remote_challenges.append(data) def _handle_path_response_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer @@ -1978,16 +2037,16 @@ def _handle_path_response_frame( self._quic_logger.encode_path_response_frame(data=data) ) - if data != context.network_path.local_challenge: + try: + network_path = self._local_challenges.pop(data) + except KeyError: raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=frame_type, reason_phrase="Response does not match challenge", ) - self._logger.debug( - "Network path %s validated by challenge", context.network_path.addr - ) - context.network_path.is_validated = True + self._logger.debug("Network path %s validated by challenge", network_path.addr) + network_path.is_validated = True def _handle_ping_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer @@ -2325,30 +2384,42 @@ def _on_retire_connection_id_delivery( self._retire_connection_ids.append(sequence_number) def _payload_received( - self, context: QuicReceiveContext, plain: bytes + self, + context: QuicReceiveContext, + plain: bytes, + crypto_frame_required: bool = False, ) -> tuple[bool, bool]: """ Handle a QUIC packet payload. """ buf = Buffer(data=plain) + crypto_frame_found = False frame_found = False is_ack_eliciting = False is_probing = None while not buf.eof(): - frame_type = buf.pull_uint_var() + # get frame type + try: + frame_type = buf.pull_uint_var() + except BufferReadError: + raise QuicConnectionError( + error_code=QuicErrorCode.FRAME_ENCODING_ERROR, + frame_type=None, + reason_phrase="Malformed frame type", + ) # check frame type is known try: frame_handler, frame_epochs = self.__frame_handlers[frame_type] except KeyError: raise QuicConnectionError( - error_code=QuicErrorCode.PROTOCOL_VIOLATION, + error_code=QuicErrorCode.FRAME_ENCODING_ERROR, frame_type=frame_type, reason_phrase="Unknown frame type", ) - # check frame is allowed for the epoch + # check frame type is allowed for the epoch if context.epoch not in frame_epochs: raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, @@ -2372,6 +2443,9 @@ def _payload_received( # update ACK only / probing flags frame_found = True + if frame_type == QuicFrameType.CRYPTO: + crypto_frame_found = True + if frame_type not in NON_ACK_ELICITING_FRAME_TYPES: is_ack_eliciting = True @@ -2387,8 +2461,162 @@ def _payload_received( reason_phrase="Packet contains no frames", ) + # RFC 9000 - 17.2.2. Initial Packet + # The first packet sent by a client always includes a CRYPTO frame. + if crypto_frame_required and not crypto_frame_found: + raise QuicConnectionError( + error_code=QuicErrorCode.PROTOCOL_VIOLATION, + frame_type=QuicFrameType.PADDING, + reason_phrase="Packet contains no CRYPTO frame", + ) + return is_ack_eliciting, bool(is_probing) + def _receive_retry_packet( + self, header: QuicHeader, packet_without_tag: bytes, now: float + ) -> None: + """ + Handle a retry packet. + """ + if ( + self._is_client + and not self._retry_count + and header.destination_cid == self.host_cid + and header.integrity_tag + == get_retry_integrity_tag( + packet_without_tag, + self._peer_cid.cid, + version=header.version, + ) + ): + if self._quic_logger is not None: + self._quic_logger.log_event( + category="transport", + event="packet_received", + data={ + "frames": [], + "header": { + "packet_type": "retry", + "scid": dump_cid(header.source_cid), + "dcid": dump_cid(header.destination_cid), + }, + "raw": {"length": header.packet_length}, + }, + ) + + self._peer_cid.cid = header.source_cid + self._peer_token = header.token + self._retry_count += 1 + self._retry_source_connection_id = header.source_cid + self._logger.info("Retrying with token (%d bytes)" % len(header.token)) + self._connect(now=now) + else: + # Unexpected or invalid retry packet. + if self._quic_logger is not None: + self._quic_logger.log_event( + category="transport", + event="packet_dropped", + data={ + "trigger": "unexpected_packet", + "raw": {"length": header.packet_length}, + }, + ) + + def _receive_version_negotiation_packet( + self, header: QuicHeader, now: float + ) -> None: + """ + Handle a version negotiation packet. + + This is used in "Incompatible Version Negotiation", see: + https://datatracker.ietf.org/doc/html/rfc9368#section-2.2 + """ + # Only clients process Version Negotiation, and once a Version + # Negotiation packet has been acted upon, any further + # such packets must be ignored. + # + # https://datatracker.ietf.org/doc/html/rfc9368#section-4 + if ( + self._is_client + and self._state == QuicConnectionState.FIRSTFLIGHT + and not self._version_negotiated_incompatible + ): + if self._quic_logger is not None: + self._quic_logger.log_event( + category="transport", + event="packet_received", + data={ + "frames": [], + "header": { + "packet_type": self._quic_logger.packet_type( + header.packet_type + ), + "scid": dump_cid(header.source_cid), + "dcid": dump_cid(header.destination_cid), + }, + "raw": {"length": header.packet_length}, + }, + ) + + # Ignore any Version Negotiation packets that contain the + # original version. + # + # https://datatracker.ietf.org/doc/html/rfc9368#section-4 + if self._version in header.supported_versions: + self._logger.warning( + "Version negotiation packet contains protocol version %s", + pretty_protocol_version(self._version), + ) + return + + # Look for a common protocol version. + common = [ + x + for x in self._configuration.supported_versions + if x in header.supported_versions + ] + + # Look for a common protocol version. + chosen_version = common[0] if common else None + if self._quic_logger is not None: + self._quic_logger.log_event( + category="transport", + event="version_information", + data={ + "server_versions": header.supported_versions, + "client_versions": self._configuration.supported_versions, + "chosen_version": chosen_version, + }, + ) + if chosen_version is None: + self._logger.error("Could not find a common protocol version") + self._close_event = events.ConnectionTerminated( + error_code=QuicErrorCode.INTERNAL_ERROR, + frame_type=QuicFrameType.PADDING, + reason_phrase="Could not find a common protocol version", + ) + self._close_end() + return + self._packet_number = 0 + self._version = chosen_version + self._version_negotiated_incompatible = True + self._logger.info( + "Retrying with protocol version %s", + pretty_protocol_version(self._version), + ) + self._connect(now=now) + else: + # Unexpected version negotiation packet. + if self._quic_logger is not None: + self._quic_logger.log_event( + category="transport", + event="packet_dropped", + data={ + "trigger": "unexpected_packet", + "raw": {"length": header.packet_length}, + }, + ) + def _replenish_connection_ids(self) -> None: """ Generate new connection IDs. @@ -2408,9 +2636,10 @@ def _retire_peer_cid(self, connection_id: QuicConnectionId) -> None: Retire a destination connection ID. """ self._logger.debug( - "Retiring CID %s (%d)", + "Retiring CID %s (%d) [%d]", dump_cid(connection_id.cid), connection_id.sequence_number, + len(self._retire_connection_ids) + 1, ) self._retire_connection_ids.append(connection_id.sequence_number) @@ -2533,6 +2762,41 @@ def _parse_transport_parameters( reason_phrase="max_udp_payload_size must be >= 1200", ) + # Validate Version Information extension. + # + # https://datatracker.ietf.org/doc/html/rfc9368#section-4 + if quic_transport_parameters.version_information is not None: + version_information = quic_transport_parameters.version_information + + # If a server receives Version Information where the Chosen Version + # is not included in Available Versions, it MUST treat is as a + # parsing failure. + if ( + not self._is_client + and version_information.chosen_version + not in version_information.available_versions + ): + raise QuicConnectionError( + error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, + frame_type=QuicFrameType.CRYPTO, + reason_phrase=( + "version_information's chosen_version is not included " + "in available_versions" + ), + ) + + # Validate that the Chosen Version matches the version in use for the + # connection. + if version_information.chosen_version != self._crypto_packet_version: + raise QuicConnectionError( + error_code=QuicErrorCode.VERSION_NEGOTIATION_ERROR, + frame_type=QuicFrameType.CRYPTO, + reason_phrase=( + "version_information's chosen_version does not match " + "the version in use" + ), + ) + # store remote parameters if not from_session_ticket: if quic_transport_parameters.ack_delay_exponent is not None: @@ -2549,6 +2813,9 @@ def _parse_transport_parameters( self._peer_cid.stateless_reset_token = ( quic_transport_parameters.stateless_reset_token ) + self._remote_version_information = ( + quic_transport_parameters.version_information + ) if quic_transport_parameters.active_connection_id_limit is not None: self._remote_active_connection_id_limit = ( @@ -2591,6 +2858,10 @@ def _serialize_transport_parameters(self) -> bytes: b"Q" * 1200 if self._configuration.quantum_readiness_test else None ), stateless_reset_token=self._host_cids[0].stateless_reset_token, + version_information=QuicVersionInformation( + chosen_version=self._version, + available_versions=self._configuration.supported_versions, + ), ) if not self._is_client: quic_transport_parameters.original_destination_connection_id = ( @@ -2657,7 +2928,20 @@ def _update_traffic_key( Callback which is invoked by the TLS engine when new traffic keys are available. """ + # For clients, determine the negotiated protocol version. + if ( + self._is_client + and self._crypto_packet_version is not None + and not self._version_negotiated_compatible + ): + self._version = self._crypto_packet_version + self._version_negotiated_compatible = True + self._logger.info( + "Negotiated protocol version %s", pretty_protocol_version(self._version) + ) + secrets_log_file = self._configuration.secrets_log_file + if secrets_log_file is not None: label_row = self._is_client == (direction == tls.Direction.DECRYPT) label = SECRETS_LABELS[label_row][epoch.value] @@ -2676,6 +2960,14 @@ def _update_traffic_key( cipher_suite=cipher_suite, secret=secret, version=self._version ) + def _add_local_challenge(self, challenge: bytes, network_path: QuicNetworkPath): + self._local_challenges[challenge] = network_path + while len(self._local_challenges) > MAX_LOCAL_CHALLENGES: + # Dictionaries are ordered, so pop the first key until we are below the + # limit. + key = next(iter(self._local_challenges.keys())) + del self._local_challenges[key] + def _write_application( self, builder: QuicPacketBuilder, network_path: QuicNetworkPath, now: float ) -> None: @@ -2683,10 +2975,10 @@ def _write_application( if self._cryptos[tls.Epoch.ONE_RTT].send.is_valid(): crypto = self._cryptos[tls.Epoch.ONE_RTT] crypto_stream = self._crypto_streams[tls.Epoch.ONE_RTT] - packet_type = PACKET_TYPE_ONE_RTT + packet_type = QuicPacketType.ONE_RTT elif self._cryptos[tls.Epoch.ZERO_RTT].send.is_valid(): crypto = self._cryptos[tls.Epoch.ZERO_RTT] - packet_type = PACKET_TYPE_ZERO_RTT + packet_type = QuicPacketType.ZERO_RTT else: return space = self._spaces[tls.Epoch.ONE_RTT] @@ -2710,22 +3002,22 @@ def _write_application( self._handshake_done_pending = False # PATH CHALLENGE - if ( - not network_path.is_validated - and network_path.local_challenge is None - ): + if not (network_path.is_validated or network_path.local_challenge_sent): challenge = os.urandom(8) + self._add_local_challenge( + challenge=challenge, network_path=network_path + ) self._write_path_challenge_frame( builder=builder, challenge=challenge ) - network_path.local_challenge = challenge + network_path.local_challenge_sent = True # PATH RESPONSE - if network_path.remote_challenge is not None: + while len(network_path.remote_challenges) > 0: + challenge = network_path.remote_challenges.popleft() self._write_path_response_frame( - builder=builder, challenge=network_path.remote_challenge + builder=builder, challenge=challenge ) - network_path.remote_challenge = None # NEW_CONNECTION_ID for connection_id in self._host_cids: @@ -2792,8 +3084,8 @@ def _write_application( except QuicPacketBuilderStop: break - sent: set[QuicStream] = set() - discarded: set[QuicStream] = set() + to_reshelve: list[QuicStream] = [] + sent: list[QuicStream] = [] try: for stream in self._streams_queue: @@ -2802,7 +3094,6 @@ def _write_application( self._logger.debug("Stream %d discarded", stream.stream_id) self._streams.pop(stream.stream_id) self._streams_finished.add(stream.stream_id) - discarded.add(stream) continue if stream.receiver.stop_pending: @@ -2827,20 +3118,17 @@ def _write_application( ) self._remote_max_data_used += used if used > 0: - sent.add(stream) + sent.append(stream) + continue + to_reshelve.append(stream) finally: # Make a new stream service order, putting served ones at the end. # # This method of updating the streams queue ensures that discarded # streams are removed and ones which sent are moved to the end even # if an exception occurs in the loop. - self._streams_queue = [ - stream - for stream in self._streams_queue - if not (stream in discarded or stream in sent) - ] - self._streams_queue.extend(sent) + self._streams_queue = to_reshelve + sent if builder.packet_is_empty: break @@ -2859,9 +3147,10 @@ def _write_handshake( while True: if epoch == tls.Epoch.INITIAL: - packet_type = PACKET_TYPE_INITIAL + packet_type = QuicPacketType.INITIAL else: - packet_type = PACKET_TYPE_HANDSHAKE + packet_type = QuicPacketType.HANDSHAKE + builder.start_packet(packet_type, crypto) # ACK diff --git a/qh3/quic/crypto.py b/qh3/quic/crypto.py index cfb64832a..3b9dc4e4e 100644 --- a/qh3/quic/crypto.py +++ b/qh3/quic/crypto.py @@ -5,7 +5,11 @@ from .._crypto import AEAD, CryptoError, HeaderProtection from ..tls import CipherSuite, cipher_suite_hash, hkdf_expand_label, hkdf_extract -from .packet import decode_packet_number, is_long_header +from .packet import ( + QuicProtocolVersion, + decode_packet_number, + is_long_header, +) CIPHER_SUITES = { CipherSuite.AES_128_GCM_SHA256: (b"aes-128-ecb", b"aes-128-gcm"), @@ -14,6 +18,7 @@ } INITIAL_CIPHER_SUITE = CipherSuite.AES_128_GCM_SHA256 INITIAL_SALT_VERSION_1 = binascii.unhexlify("38762cf7f55934b34d179ae6a4c80cadccbb7f0a") +INITIAL_SALT_VERSION_2 = binascii.unhexlify("0dede3def700a6db819381be6e269dcbf9bd2ed9") SAMPLE_SIZE = 16 @@ -29,9 +34,10 @@ class KeyUnavailableError(CryptoError): def derive_key_iv_hp( - cipher_suite: CipherSuite, secret: bytes + *, cipher_suite: CipherSuite, secret: bytes, version: int ) -> tuple[bytes, bytes, bytes]: algorithm = cipher_suite_hash(cipher_suite) + if cipher_suite in [ CipherSuite.AES_256_GCM_SHA384, CipherSuite.CHACHA20_POLY1305_SHA256, @@ -39,11 +45,19 @@ def derive_key_iv_hp( key_size = 32 else: key_size = 16 - return ( - hkdf_expand_label(algorithm, secret, b"quic key", b"", key_size), - hkdf_expand_label(algorithm, secret, b"quic iv", b"", 12), - hkdf_expand_label(algorithm, secret, b"quic hp", b"", key_size), - ) + + if version == QuicProtocolVersion.VERSION_2: + return ( + hkdf_expand_label(algorithm, secret, b"quicv2 key", b"", key_size), + hkdf_expand_label(algorithm, secret, b"quicv2 iv", b"", 12), + hkdf_expand_label(algorithm, secret, b"quicv2 hp", b"", key_size), + ) + else: + return ( + hkdf_expand_label(algorithm, secret, b"quic key", b"", key_size), + hkdf_expand_label(algorithm, secret, b"quic iv", b"", 12), + hkdf_expand_label(algorithm, secret, b"quic hp", b"", key_size), + ) class CryptoContext: @@ -108,10 +122,14 @@ def encrypt_packet( def is_valid(self) -> bool: return self.aead is not None - def setup(self, cipher_suite: CipherSuite, secret: bytes, version: int) -> None: + def setup(self, *, cipher_suite: CipherSuite, secret: bytes, version: int) -> None: hp_cipher_name, aead_cipher_name = CIPHER_SUITES[cipher_suite] - key, iv, hp = derive_key_iv_hp(cipher_suite, secret) + key, iv, hp = derive_key_iv_hp( + cipher_suite=cipher_suite, + secret=secret, + version=version, + ) self.aead = AEAD(aead_cipher_name, key, iv) self.cipher_suite = cipher_suite self.hp = HeaderProtection(hp_cipher_name, hp) @@ -190,7 +208,11 @@ def setup_initial(self, cid: bytes, is_client: bool, version: int) -> None: else: recv_label, send_label = b"client in", b"server in" - initial_salt = INITIAL_SALT_VERSION_1 + if version == QuicProtocolVersion.VERSION_2: + initial_salt = INITIAL_SALT_VERSION_2 + else: + initial_salt = INITIAL_SALT_VERSION_1 + algorithm = cipher_suite_hash(INITIAL_CIPHER_SUITE) digest_size = int(algorithm / 8) initial_secret = hkdf_extract(algorithm, initial_salt, cid) diff --git a/qh3/quic/events.py b/qh3/quic/events.py index b6fad53d0..2e08d345c 100644 --- a/qh3/quic/events.py +++ b/qh3/quic/events.py @@ -1,5 +1,6 @@ +from __future__ import annotations + from dataclasses import dataclass -from typing import Optional class QuicEvent: @@ -29,7 +30,7 @@ class ConnectionTerminated(QuicEvent): error_code: int "The error code which was specified when closing the connection." - frame_type: Optional[int] + frame_type: int | None "The frame type which caused the connection to be closed, or `None`." reason_phrase: str @@ -52,7 +53,7 @@ class HandshakeCompleted(QuicEvent): The HandshakeCompleted event is fired when the TLS handshake completes. """ - alpn_protocol: Optional[str] + alpn_protocol: str | None "The protocol which was negotiated using ALPN, or `None`." early_data_accepted: bool @@ -78,7 +79,7 @@ class ProtocolNegotiated(QuicEvent): The ProtocolNegotiated event is fired when ALPN negotiation completes. """ - alpn_protocol: Optional[str] + alpn_protocol: str | None "The protocol which was negotiated using ALPN, or `None`." diff --git a/qh3/quic/logger.py b/qh3/quic/logger.py index 08c3ef42d..4f6c907c8 100644 --- a/qh3/quic/logger.py +++ b/qh3/quic/logger.py @@ -1,30 +1,28 @@ +from __future__ import annotations + import binascii import json import os import time from collections import deque -from typing import Any, Deque, Dict, List, Optional +from typing import Any, Deque from ..h3.events import Headers from .packet import ( - PACKET_TYPE_HANDSHAKE, - PACKET_TYPE_INITIAL, - PACKET_TYPE_MASK, - PACKET_TYPE_ONE_RTT, - PACKET_TYPE_RETRY, - PACKET_TYPE_ZERO_RTT, QuicFrameType, + QuicPacketType, QuicStreamFrame, QuicTransportParameters, ) from .rangeset import RangeSet PACKET_TYPE_NAMES = { - PACKET_TYPE_INITIAL: "initial", - PACKET_TYPE_HANDSHAKE: "handshake", - PACKET_TYPE_ZERO_RTT: "0RTT", - PACKET_TYPE_ONE_RTT: "1RTT", - PACKET_TYPE_RETRY: "retry", + QuicPacketType.INITIAL: "initial", + QuicPacketType.HANDSHAKE: "handshake", + QuicPacketType.ZERO_RTT: "0RTT", + QuicPacketType.ONE_RTT: "1RTT", + QuicPacketType.RETRY: "retry", + QuicPacketType.VERSION_NEGOTIATION: "version_negotiation", } QLOG_VERSION = "0.3" @@ -47,7 +45,7 @@ class QuicLoggerTrace: def __init__(self, *, is_client: bool, odcid: bytes) -> None: self._odcid = odcid - self._events: Deque[Dict[str, Any]] = deque() + self._events: Deque[dict[str, Any]] = deque() self._vantage_point = { "name": "qh3", "type": "client" if is_client else "server", @@ -55,7 +53,7 @@ def __init__(self, *, is_client: bool, odcid: bytes) -> None: # QUIC - def encode_ack_frame(self, ranges: RangeSet, delay: float) -> Dict: + def encode_ack_frame(self, ranges: RangeSet, delay: float) -> dict: return { "ack_delay": self.encode_time(delay), "acked_ranges": [[x.start, x.stop - 1] for x in ranges], @@ -63,8 +61,8 @@ def encode_ack_frame(self, ranges: RangeSet, delay: float) -> Dict: } def encode_connection_close_frame( - self, error_code: int, frame_type: Optional[int], reason_phrase: str - ) -> Dict: + self, error_code: int, frame_type: int | None, reason_phrase: str + ) -> dict: attrs = { "error_code": error_code, "error_space": "application" if frame_type is None else "transport", @@ -77,7 +75,7 @@ def encode_connection_close_frame( return attrs - def encode_connection_limit_frame(self, frame_type: int, maximum: int) -> Dict: + def encode_connection_limit_frame(self, frame_type: int, maximum: int) -> dict: if frame_type == QuicFrameType.MAX_DATA: return {"frame_type": "max_data", "maximum": maximum} else: @@ -91,23 +89,23 @@ def encode_connection_limit_frame(self, frame_type: int, maximum: int) -> Dict: ), } - def encode_crypto_frame(self, frame: QuicStreamFrame) -> Dict: + def encode_crypto_frame(self, frame: QuicStreamFrame) -> dict: return { "frame_type": "crypto", "length": len(frame.data), "offset": frame.offset, } - def encode_data_blocked_frame(self, limit: int) -> Dict: + def encode_data_blocked_frame(self, limit: int) -> dict: return {"frame_type": "data_blocked", "limit": limit} - def encode_datagram_frame(self, length: int) -> Dict: + def encode_datagram_frame(self, length: int) -> dict: return {"frame_type": "datagram", "length": length} - def encode_handshake_done_frame(self) -> Dict: + def encode_handshake_done_frame(self) -> dict: return {"frame_type": "handshake_done"} - def encode_max_stream_data_frame(self, maximum: int, stream_id: int) -> Dict: + def encode_max_stream_data_frame(self, maximum: int, stream_id: int) -> dict: return { "frame_type": "max_stream_data", "maximum": maximum, @@ -120,7 +118,7 @@ def encode_new_connection_id_frame( retire_prior_to: int, sequence_number: int, stateless_reset_token: bytes, - ) -> Dict: + ) -> dict: return { "connection_id": hexdump(connection_id), "frame_type": "new_connection_id", @@ -130,28 +128,28 @@ def encode_new_connection_id_frame( "sequence_number": sequence_number, } - def encode_new_token_frame(self, token: bytes) -> Dict: + def encode_new_token_frame(self, token: bytes) -> dict: return { "frame_type": "new_token", "length": len(token), "token": hexdump(token), } - def encode_padding_frame(self) -> Dict: + def encode_padding_frame(self) -> dict: return {"frame_type": "padding"} - def encode_path_challenge_frame(self, data: bytes) -> Dict: + def encode_path_challenge_frame(self, data: bytes) -> dict: return {"data": hexdump(data), "frame_type": "path_challenge"} - def encode_path_response_frame(self, data: bytes) -> Dict: + def encode_path_response_frame(self, data: bytes) -> dict: return {"data": hexdump(data), "frame_type": "path_response"} - def encode_ping_frame(self) -> Dict: + def encode_ping_frame(self) -> dict: return {"frame_type": "ping"} def encode_reset_stream_frame( self, error_code: int, final_size: int, stream_id: int - ) -> Dict: + ) -> dict: return { "error_code": error_code, "final_size": final_size, @@ -159,27 +157,27 @@ def encode_reset_stream_frame( "stream_id": stream_id, } - def encode_retire_connection_id_frame(self, sequence_number: int) -> Dict: + def encode_retire_connection_id_frame(self, sequence_number: int) -> dict: return { "frame_type": "retire_connection_id", "sequence_number": sequence_number, } - def encode_stream_data_blocked_frame(self, limit: int, stream_id: int) -> Dict: + def encode_stream_data_blocked_frame(self, limit: int, stream_id: int) -> dict: return { "frame_type": "stream_data_blocked", "limit": limit, "stream_id": stream_id, } - def encode_stop_sending_frame(self, error_code: int, stream_id: int) -> Dict: + def encode_stop_sending_frame(self, error_code: int, stream_id: int) -> dict: return { "frame_type": "stop_sending", "error_code": error_code, "stream_id": stream_id, } - def encode_stream_frame(self, frame: QuicStreamFrame, stream_id: int) -> Dict: + def encode_stream_frame(self, frame: QuicStreamFrame, stream_id: int) -> dict: return { "fin": frame.fin, "frame_type": "stream", @@ -188,7 +186,7 @@ def encode_stream_frame(self, frame: QuicStreamFrame, stream_id: int) -> Dict: "stream_id": stream_id, } - def encode_streams_blocked_frame(self, is_unidirectional: bool, limit: int) -> Dict: + def encode_streams_blocked_frame(self, is_unidirectional: bool, limit: int) -> dict: return { "frame_type": "streams_blocked", "limit": limit, @@ -203,8 +201,8 @@ def encode_time(self, seconds: float) -> float: def encode_transport_parameters( self, owner: str, parameters: QuicTransportParameters - ) -> Dict[str, Any]: - data: Dict[str, Any] = {"owner": owner} + ) -> dict[str, Any]: + data: dict[str, Any] = {"owner": owner} for param_name, param_value in parameters.__dict__.items(): if isinstance(param_value, bool): data[param_name] = param_value @@ -214,12 +212,12 @@ def encode_transport_parameters( data[param_name] = param_value return data - def packet_type(self, packet_type: int) -> str: - return PACKET_TYPE_NAMES.get(packet_type & PACKET_TYPE_MASK, "1RTT") + def packet_type(self, packet_type: QuicPacketType) -> str: + return PACKET_TYPE_NAMES[packet_type] # HTTP/3 - def encode_http3_data_frame(self, length: int, stream_id: int) -> Dict: + def encode_http3_data_frame(self, length: int, stream_id: int) -> dict: return { "frame": {"frame_type": "data"}, "length": length, @@ -228,7 +226,7 @@ def encode_http3_data_frame(self, length: int, stream_id: int) -> Dict: def encode_http3_headers_frame( self, length: int, headers: Headers, stream_id: int - ) -> Dict: + ) -> dict: return { "frame": { "frame_type": "headers", @@ -240,7 +238,7 @@ def encode_http3_headers_frame( def encode_http3_push_promise_frame( self, length: int, headers: Headers, push_id: int, stream_id: int - ) -> Dict: + ) -> dict: return { "frame": { "frame_type": "push_promise", @@ -251,14 +249,14 @@ def encode_http3_push_promise_frame( "stream_id": stream_id, } - def _encode_http3_headers(self, headers: Headers) -> List[Dict]: + def _encode_http3_headers(self, headers: Headers) -> list[dict]: return [ {"name": h[0].decode("utf8"), "value": h[1].decode("utf8")} for h in headers ] # CORE - def log_event(self, *, category: str, event: str, data: Dict) -> None: + def log_event(self, *, category: str, event: str, data: dict) -> None: self._events.append( { "data": data, @@ -267,7 +265,7 @@ def log_event(self, *, category: str, event: str, data: Dict) -> None: } ) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """ Return the trace as a dictionary which can be written as JSON. """ @@ -286,7 +284,7 @@ class QuicLogger: """ def __init__(self) -> None: - self._traces: List[QuicLoggerTrace] = [] + self._traces: list[QuicLoggerTrace] = [] def start_trace(self, is_client: bool, odcid: bytes) -> QuicLoggerTrace: trace = QuicLoggerTrace(is_client=is_client, odcid=odcid) @@ -296,7 +294,7 @@ def start_trace(self, is_client: bool, odcid: bytes) -> QuicLoggerTrace: def end_trace(self, trace: QuicLoggerTrace) -> None: assert trace in self._traces, "QuicLoggerTrace does not belong to QuicLogger" - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """ Return the traces as a dictionary which can be written as JSON. """ diff --git a/qh3/quic/packet.py b/qh3/quic/packet.py index 62fb15ece..52ab2d319 100644 --- a/qh3/quic/packet.py +++ b/qh3/quic/packet.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import binascii import ipaddress import os from dataclasses import dataclass -from enum import IntEnum -from typing import List, Optional, Tuple +from enum import Enum, IntEnum from .._hazmat import AeadAes128Gcm from ..buffer import Buffer @@ -13,17 +14,12 @@ PACKET_FIXED_BIT = 0x40 PACKET_SPIN_BIT = 0x20 -PACKET_TYPE_INITIAL = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x00 -PACKET_TYPE_ZERO_RTT = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x10 -PACKET_TYPE_HANDSHAKE = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x20 -PACKET_TYPE_RETRY = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x30 -PACKET_TYPE_ONE_RTT = PACKET_FIXED_BIT -PACKET_TYPE_MASK = 0xF0 - CONNECTION_ID_MAX_SIZE = 20 PACKET_NUMBER_MAX_SIZE = 4 RETRY_AEAD_KEY_VERSION_1 = binascii.unhexlify("be0c690b9f66575a1d766b54e368c84e") +RETRY_AEAD_KEY_VERSION_2 = binascii.unhexlify("8fb4b01b56ac48e260fbcbcead7ccc92") RETRY_AEAD_NONCE_VERSION_1 = binascii.unhexlify("461599d35d632bf2239825bb") +RETRY_AEAD_NONCE_VERSION_2 = binascii.unhexlify("d86969bc2d7c6d9990efb04a") RETRY_INTEGRITY_TAG_SIZE = 16 STATELESS_RESET_TOKEN_SIZE = 16 @@ -45,24 +41,78 @@ class QuicErrorCode(IntEnum): CRYPTO_BUFFER_EXCEEDED = 0xD KEY_UPDATE_ERROR = 0xE AEAD_LIMIT_REACHED = 0xF + VERSION_NEGOTIATION_ERROR = 0x11 CRYPTO_ERROR = 0x100 +class QuicPacketType(Enum): + INITIAL = 0 + ZERO_RTT = 1 + HANDSHAKE = 2 + RETRY = 3 + VERSION_NEGOTIATION = 4 + ONE_RTT = 5 + + +# For backwards compatibility only, use `QuicPacketType` in new code. +PACKET_TYPE_INITIAL = QuicPacketType.INITIAL + +# QUIC version 1 +# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2 +PACKET_LONG_TYPE_ENCODE_VERSION_1 = { + QuicPacketType.INITIAL: 0, + QuicPacketType.ZERO_RTT: 1, + QuicPacketType.HANDSHAKE: 2, + QuicPacketType.RETRY: 3, +} +PACKET_LONG_TYPE_DECODE_VERSION_1 = { + v: i for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_1.items() +} + +# QUIC version 2 +# https://datatracker.ietf.org/doc/html/rfc9369#section-3.2 +PACKET_LONG_TYPE_ENCODE_VERSION_2 = { + QuicPacketType.INITIAL: 1, + QuicPacketType.ZERO_RTT: 2, + QuicPacketType.HANDSHAKE: 3, + QuicPacketType.RETRY: 0, +} +PACKET_LONG_TYPE_DECODE_VERSION_2 = { + v: i for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_2.items() +} + + class QuicProtocolVersion(IntEnum): NEGOTIATION = 0 VERSION_1 = 0x00000001 + VERSION_2 = 0x6B3343CF @dataclass class QuicHeader: - is_long_header: bool - version: Optional[int] - packet_type: int + version: int | None + "The protocol version. Only present in long header packets." + + packet_type: QuicPacketType + "The type of the packet." + + packet_length: int + "The total length of the packet, in bytes." + destination_cid: bytes + "The destination connection ID." + source_cid: bytes - token: bytes = b"" - integrity_tag: bytes = b"" - rest_length: int = 0 + "The destination connection ID." + + token: bytes + "The address verification token. Only present in `INITIAL` and `RETRY` packets." + + integrity_tag: bytes + "The retry integrity tag. Only present in `RETRY` packets." + + supported_versions: list[int] + "Supported protocol versions. Only present in `VERSION_NEGOTIATION` packets." def decode_packet_number(truncated: int, num_bits: int, expected: int) -> int: @@ -95,8 +145,12 @@ def get_retry_integrity_tag( buf.push_bytes(packet_without_tag) assert buf.eof() - aead_key = RETRY_AEAD_KEY_VERSION_1 - aead_nonce = RETRY_AEAD_NONCE_VERSION_1 + if version == QuicProtocolVersion.VERSION_2: + aead_key = RETRY_AEAD_KEY_VERSION_2 + aead_nonce = RETRY_AEAD_NONCE_VERSION_2 + else: + aead_key = RETRY_AEAD_KEY_VERSION_1 + aead_nonce = RETRY_AEAD_NONCE_VERSION_1 # run AES-128-GCM aead = AeadAes128Gcm(aead_key) @@ -113,13 +167,30 @@ def is_long_header(first_byte: int) -> bool: return bool(first_byte & PACKET_LONG_HEADER) -def pull_quic_header(buf: Buffer, host_cid_length: Optional[int] = None) -> QuicHeader: - first_byte = buf.pull_uint8() +def pretty_protocol_version(version: int) -> str: + """ + Return a user-friendly representation of a protocol version. + """ + try: + version_name = QuicProtocolVersion(version).name + except ValueError: + version_name = "UNKNOWN" + return f"0x{version:08x} ({version_name})" + + +def pull_quic_header(buf: Buffer, host_cid_length: int | None = None) -> QuicHeader: + packet_start = buf.tell() + version = None integrity_tag = b"" + supported_versions = [] token = b"" + + first_byte = buf.pull_uint8() + if is_long_header(first_byte): - # long header packet + # Long Header Packets. + # https://datatracker.ietf.org/doc/html/rfc9000#section-17.2 version = buf.pull_uint32() destination_cid_length = buf.pull_uint8() @@ -135,56 +206,84 @@ def pull_quic_header(buf: Buffer, host_cid_length: Optional[int] = None) -> Quic source_cid = buf.pull_bytes(source_cid_length) if version == QuicProtocolVersion.NEGOTIATION: - # version negotiation - packet_type = None - rest_length = buf.capacity - buf.tell() + # Version Negotiation Packet. + # https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.1 + packet_type = QuicPacketType.VERSION_NEGOTIATION + while not buf.eof(): + supported_versions.append(buf.pull_uint32()) + packet_end = buf.tell() else: if not (first_byte & PACKET_FIXED_BIT): raise ValueError("Packet fixed bit is zero") - packet_type = first_byte & PACKET_TYPE_MASK - if packet_type == PACKET_TYPE_INITIAL: + if version == QuicProtocolVersion.VERSION_2: + packet_type = PACKET_LONG_TYPE_DECODE_VERSION_2[ + (first_byte & 0x30) >> 4 + ] + else: + packet_type = PACKET_LONG_TYPE_DECODE_VERSION_1[ + (first_byte & 0x30) >> 4 + ] + + if packet_type == QuicPacketType.INITIAL: token_length = buf.pull_uint_var() token = buf.pull_bytes(token_length) rest_length = buf.pull_uint_var() - elif packet_type == PACKET_TYPE_RETRY: + elif packet_type == QuicPacketType.ZERO_RTT: + rest_length = buf.pull_uint_var() + elif packet_type == QuicPacketType.HANDSHAKE: + rest_length = buf.pull_uint_var() + else: token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_TAG_SIZE token = buf.pull_bytes(token_length) integrity_tag = buf.pull_bytes(RETRY_INTEGRITY_TAG_SIZE) rest_length = 0 - else: - rest_length = buf.pull_uint_var() - # check remainder length - if rest_length > buf.capacity - buf.tell(): + # Check remainder length. + packet_end = buf.tell() + rest_length + + if packet_end > buf.capacity: raise ValueError("Packet payload is truncated") - return QuicHeader( - is_long_header=True, - version=version, - packet_type=packet_type, - destination_cid=destination_cid, - source_cid=source_cid, - token=token, - integrity_tag=integrity_tag, - rest_length=rest_length, - ) else: - # short header packet + # https://datatracker.ietf.org/doc/html/rfc9000#section-17.3 if not (first_byte & PACKET_FIXED_BIT): raise ValueError("Packet fixed bit is zero") - packet_type = first_byte & PACKET_TYPE_MASK + version = None + packet_type = QuicPacketType.ONE_RTT destination_cid = buf.pull_bytes(host_cid_length) - return QuicHeader( - is_long_header=False, - version=None, - packet_type=packet_type, - destination_cid=destination_cid, - source_cid=b"", - token=b"", - rest_length=buf.capacity - buf.tell(), - ) + source_cid = b"" + packet_end = buf.capacity + + return QuicHeader( + version=version, + packet_type=packet_type, + packet_length=packet_end - packet_start, + destination_cid=destination_cid, + source_cid=source_cid, + token=token, + integrity_tag=integrity_tag, + supported_versions=supported_versions, + ) + + +def encode_long_header_first_byte( + version: int, packet_type: QuicPacketType, bits: int +) -> int: + """ + Encode the first byte of a long header packet. + """ + if version == QuicProtocolVersion.VERSION_2: + long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_2 + else: + long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_1 + return ( + PACKET_LONG_HEADER + | PACKET_FIXED_BIT + | long_type_encode[packet_type] << 4 + | bits + ) def encode_quic_retry( @@ -193,6 +292,7 @@ def encode_quic_retry( destination_cid: bytes, original_destination_cid: bytes, retry_token: bytes, + unused: int = 0, ) -> bytes: buf = Buffer( capacity=7 @@ -201,7 +301,7 @@ def encode_quic_retry( + len(retry_token) + RETRY_INTEGRITY_TAG_SIZE ) - buf.push_uint8(PACKET_TYPE_RETRY) + buf.push_uint8(encode_long_header_first_byte(version, QuicPacketType.RETRY, unused)) buf.push_uint32(version) buf.push_uint8(len(destination_cid)) buf.push_bytes(destination_cid) @@ -216,7 +316,7 @@ def encode_quic_retry( def encode_quic_version_negotiation( - source_cid: bytes, destination_cid: bytes, supported_versions: List[int] + source_cid: bytes, destination_cid: bytes, supported_versions: list[int] ) -> bytes: buf = Buffer( capacity=7 @@ -240,33 +340,40 @@ def encode_quic_version_negotiation( @dataclass class QuicPreferredAddress: - ipv4_address: Optional[Tuple[str, int]] - ipv6_address: Optional[Tuple[str, int]] + ipv4_address: tuple[str, int] | None + ipv6_address: tuple[str, int] | None connection_id: bytes stateless_reset_token: bytes +@dataclass +class QuicVersionInformation: + chosen_version: int + available_versions: list[int] + + @dataclass class QuicTransportParameters: - original_destination_connection_id: Optional[bytes] = None - max_idle_timeout: Optional[int] = None - stateless_reset_token: Optional[bytes] = None - max_udp_payload_size: Optional[int] = None - initial_max_data: Optional[int] = None - initial_max_stream_data_bidi_local: Optional[int] = None - initial_max_stream_data_bidi_remote: Optional[int] = None - initial_max_stream_data_uni: Optional[int] = None - initial_max_streams_bidi: Optional[int] = None - initial_max_streams_uni: Optional[int] = None - ack_delay_exponent: Optional[int] = None - max_ack_delay: Optional[int] = None - disable_active_migration: Optional[bool] = False - preferred_address: Optional[QuicPreferredAddress] = None - active_connection_id_limit: Optional[int] = None - initial_source_connection_id: Optional[bytes] = None - retry_source_connection_id: Optional[bytes] = None - max_datagram_frame_size: Optional[int] = None - quantum_readiness: Optional[bytes] = None + original_destination_connection_id: bytes | None = None + max_idle_timeout: int | None = None + stateless_reset_token: bytes | None = None + max_udp_payload_size: int | None = None + initial_max_data: int | None = None + initial_max_stream_data_bidi_local: int | None = None + initial_max_stream_data_bidi_remote: int | None = None + initial_max_stream_data_uni: int | None = None + initial_max_streams_bidi: int | None = None + initial_max_streams_uni: int | None = None + ack_delay_exponent: int | None = None + max_ack_delay: int | None = None + disable_active_migration: bool | None = False + preferred_address: QuicPreferredAddress | None = None + active_connection_id_limit: int | None = None + initial_source_connection_id: bytes | None = None + retry_source_connection_id: bytes | None = None + version_information: QuicVersionInformation | None = None + max_datagram_frame_size: int | None = None + quantum_readiness: bytes | None = None PARAMS = { @@ -287,6 +394,8 @@ class QuicTransportParameters: 0x0E: ("active_connection_id_limit", int), 0x0F: ("initial_source_connection_id", bytes), 0x10: ("retry_source_connection_id", bytes), + # https://datatracker.ietf.org/doc/html/rfc9368#section-3 + 0x11: ("version_information", QuicVersionInformation), # extensions 0x0020: ("max_datagram_frame_size", int), 0x0C37: ("quantum_readiness", bytes), @@ -338,6 +447,33 @@ def push_quic_preferred_address( buf.push_bytes(preferred_address.stateless_reset_token) +def pull_quic_version_information(buf: Buffer, length: int) -> QuicVersionInformation: + chosen_version = buf.pull_uint32() + available_versions = [] + for i in range(length // 4 - 1): + available_versions.append(buf.pull_uint32()) + + # If an endpoint receives a Chosen Version equal to zero, or any Available Version + # equal to zero, it MUST treat it as a parsing failure. + # + # https://datatracker.ietf.org/doc/html/rfc9368#section-4 + if chosen_version == 0 or 0 in available_versions: + raise ValueError("Version Information must not contain version 0") + + return QuicVersionInformation( + chosen_version=chosen_version, + available_versions=available_versions, + ) + + +def push_quic_version_information( + buf: Buffer, version_information: QuicVersionInformation +) -> None: + buf.push_uint32(version_information.chosen_version) + for version in version_information.available_versions: + buf.push_uint32(version) + + def pull_quic_transport_parameters(buf: Buffer) -> QuicTransportParameters: params = QuicTransportParameters() while not buf.eof(): @@ -347,18 +483,26 @@ def pull_quic_transport_parameters(buf: Buffer) -> QuicTransportParameters: if param_id in PARAMS: # parse known parameter param_name, param_type = PARAMS[param_id] - if param_type == int: + if param_type is int: setattr(params, param_name, buf.pull_uint_var()) - elif param_type == bytes: + elif param_type is bytes: setattr(params, param_name, buf.pull_bytes(param_len)) - elif param_type == QuicPreferredAddress: + elif param_type is QuicPreferredAddress: setattr(params, param_name, pull_quic_preferred_address(buf)) + elif param_type is QuicVersionInformation: + setattr( + params, + param_name, + pull_quic_version_information(buf, param_len), + ) else: setattr(params, param_name, True) else: # skip unknown parameter buf.pull_bytes(param_len) - assert buf.tell() == param_start + param_len + + if buf.tell() != param_start + param_len: + raise ValueError("Transport parameter length does not match") return params @@ -370,12 +514,14 @@ def push_quic_transport_parameters( param_value = getattr(params, param_name) if param_value is not None and param_value is not False: param_buf = Buffer(capacity=65536) - if param_type == int: + if param_type is int: param_buf.push_uint_var(param_value) - elif param_type == bytes: + elif param_type is bytes: param_buf.push_bytes(param_value) - elif param_type == QuicPreferredAddress: + elif param_type is QuicPreferredAddress: push_quic_preferred_address(param_buf, param_value) + elif param_type is QuicVersionInformation: + push_quic_version_information(param_buf, param_value) buf.push_uint_var(param_id) buf.push_uint_var(param_buf.tell()) buf.push_bytes(param_buf.data) @@ -461,7 +607,7 @@ class QuicStreamFrame: offset: int = 0 -def pull_ack_frame(buf: Buffer) -> Tuple[RangeSet, int]: +def pull_ack_frame(buf: Buffer) -> tuple[RangeSet, int]: rangeset = RangeSet() end = buf.pull_uint_var() # largest acknowledged delay = buf.pull_uint_var() diff --git a/qh3/quic/packet_builder.py b/qh3/quic/packet_builder.py index d54327ebe..3b33810a4 100644 --- a/qh3/quic/packet_builder.py +++ b/qh3/quic/packet_builder.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple +from typing import Any, Callable, Sequence from ..buffer import Buffer, size_uint_var from ..tls import Epoch @@ -9,12 +11,11 @@ from .packet import ( NON_ACK_ELICITING_FRAME_TYPES, NON_IN_FLIGHT_FRAME_TYPES, + PACKET_FIXED_BIT, PACKET_NUMBER_MAX_SIZE, - PACKET_TYPE_HANDSHAKE, - PACKET_TYPE_INITIAL, - PACKET_TYPE_MASK, QuicFrameType, - is_long_header, + QuicPacketType, + encode_long_header_first_byte, ) PACKET_MAX_SIZE = 1280 @@ -28,7 +29,6 @@ class QuicDeliveryState(Enum): ACKED = 0 LOST = 1 - EXPIRED = 2 @dataclass @@ -38,14 +38,14 @@ class QuicSentPacket: is_ack_eliciting: bool is_crypto_packet: bool packet_number: int - packet_type: int - sent_time: Optional[float] = None + packet_type: QuicPacketType + sent_time: float | None = None sent_bytes: int = 0 - delivery_handlers: List[Tuple[QuicDeliveryHandler, Any]] = field( + delivery_handlers: list[tuple[QuicDeliveryHandler, Any]] = field( default_factory=list ) - quic_logger_frames: List[Dict] = field(default_factory=list) + quic_logger_frames: list[dict] = field(default_factory=list) class QuicPacketBuilderStop(Exception): @@ -66,12 +66,12 @@ def __init__( is_client: bool, packet_number: int = 0, peer_token: bytes = b"", - quic_logger: Optional[QuicLoggerTrace] = None, + quic_logger: QuicLoggerTrace | None = None, spin_bit: bool = False, ): - self.max_flight_bytes: Optional[int] = None - self.max_total_bytes: Optional[int] = None - self.quic_logger_frames: Optional[List[Dict]] = None + self.max_flight_bytes: int | None = None + self.max_total_bytes: int | None = None + self.quic_logger_frames: list[dict] | None = None self._host_cid = host_cid self._is_client = is_client @@ -82,21 +82,22 @@ def __init__( self._version = version # assembled datagrams and packets - self._datagrams: List[bytes] = [] + self._datagrams: list[bytes] = [] self._datagram_flight_bytes = 0 self._datagram_init = True - self._packets: List[QuicSentPacket] = [] + self._datagram_needs_padding = False + self._packets: list[QuicSentPacket] = [] self._flight_bytes = 0 self._total_bytes = 0 # current packet self._header_size = 0 - self._packet: Optional[QuicSentPacket] = None - self._packet_crypto: Optional[CryptoPair] = None + self._packet: QuicSentPacket | None = None + self._packet_crypto: CryptoPair | None = None self._packet_long_header = False self._packet_number = packet_number self._packet_start = 0 - self._packet_type = 0 + self._packet_type: QuicPacketType | None = None self._buffer = Buffer(PACKET_MAX_SIZE) self._buffer_capacity = PACKET_MAX_SIZE @@ -142,7 +143,7 @@ def remaining_flight_space(self) -> int: - self._packet_crypto.aead_tag_size ) - def flush(self) -> Tuple[List[bytes], List[QuicSentPacket]]: + def flush(self) -> tuple[list[bytes], list[QuicSentPacket]]: """ Returns the assembled datagrams. """ @@ -160,7 +161,7 @@ def start_frame( self, frame_type: int, capacity: int = 1, - handler: Optional[QuicDeliveryHandler] = None, + handler: QuicDeliveryHandler | None = None, handler_args: Sequence[Any] = [], ) -> Buffer: """ @@ -183,10 +184,17 @@ def start_frame( self._packet.delivery_handlers.append((handler, handler_args)) return self._buffer - def start_packet(self, packet_type: int, crypto: CryptoPair) -> None: + def start_packet(self, packet_type: QuicPacketType, crypto: CryptoPair) -> None: """ Starts a new packet. """ + assert packet_type in ( + QuicPacketType.INITIAL, + QuicPacketType.HANDSHAKE, + QuicPacketType.ZERO_RTT, + QuicPacketType.ONE_RTT, + ), "Invalid packet type" + buf = self._buffer # finish previous datagram @@ -214,12 +222,12 @@ def start_packet(self, packet_type: int, crypto: CryptoPair) -> None: self._flight_capacity = remaining_flight_bytes self._datagram_flight_bytes = 0 self._datagram_init = False + self._datagram_needs_padding = False # calculate header size - packet_long_header = is_long_header(packet_type) - if packet_long_header: + if packet_type != QuicPacketType.ONE_RTT: header_size = 11 + len(self._peer_cid) + len(self._host_cid) - if (packet_type & PACKET_TYPE_MASK) == PACKET_TYPE_INITIAL: + if packet_type == QuicPacketType.INITIAL: token_length = len(self._peer_token) header_size += size_uint_var(token_length) + token_length else: @@ -230,9 +238,9 @@ def start_packet(self, packet_type: int, crypto: CryptoPair) -> None: raise QuicPacketBuilderStop # determine ack epoch - if packet_type == PACKET_TYPE_INITIAL: + if packet_type == QuicPacketType.INITIAL: epoch = Epoch.INITIAL - elif packet_type == PACKET_TYPE_HANDSHAKE: + elif packet_type == QuicPacketType.HANDSHAKE: epoch = Epoch.HANDSHAKE else: epoch = Epoch.ONE_RTT @@ -247,7 +255,6 @@ def start_packet(self, packet_type: int, crypto: CryptoPair) -> None: packet_type=packet_type, ) self._packet_crypto = crypto - self._packet_long_header = packet_long_header self._packet_start = packet_start self._packet_type = packet_type self.quic_logger_frames = self._packet.quic_logger_frames @@ -269,15 +276,23 @@ def _end_packet(self) -> None: - packet_size ) - # padding for initial datagram + # Padding for datagrams containing initial packets; see RFC 9000 + # section 14.1. + if ( + self._is_client or self._packet.is_ack_eliciting + ) and self._packet_type == QuicPacketType.INITIAL: + self._datagram_needs_padding = True + + # For datagrams containing 1-RTT data, we *must* apply the padding + # inside the packet, we cannot tack bytes onto the end of the + # datagram. if ( - self._is_client - and self._packet_type == PACKET_TYPE_INITIAL - and self._packet.is_ack_eliciting - and self.remaining_flight_space - and self.remaining_flight_space > padding_size + self._datagram_needs_padding + and self._packet_type == QuicPacketType.ONE_RTT ): - padding_size = self.remaining_flight_space + if self.remaining_flight_space > padding_size: + padding_size = self.remaining_flight_space + self._datagram_needs_padding = False # write padding if padding_size > 0: @@ -292,7 +307,7 @@ def _end_packet(self) -> None: ) # write header - if self._packet_long_header: + if self._packet_type != QuicPacketType.ONE_RTT: length = ( packet_size - self._header_size @@ -301,13 +316,17 @@ def _end_packet(self) -> None: ) buf.seek(self._packet_start) - buf.push_uint8(self._packet_type | (PACKET_NUMBER_SEND_SIZE - 1)) + buf.push_uint8( + encode_long_header_first_byte( + self._version, self._packet_type, PACKET_NUMBER_SEND_SIZE - 1 + ) + ) buf.push_uint32(self._version) buf.push_uint8(len(self._peer_cid)) buf.push_bytes(self._peer_cid) buf.push_uint8(len(self._host_cid)) buf.push_bytes(self._host_cid) - if (self._packet_type & PACKET_TYPE_MASK) == PACKET_TYPE_INITIAL: + if self._packet_type == QuicPacketType.INITIAL: buf.push_uint_var(len(self._peer_token)) buf.push_bytes(self._peer_token) buf.push_uint16(length | 0x4000) @@ -315,7 +334,7 @@ def _end_packet(self) -> None: else: buf.seek(self._packet_start) buf.push_uint8( - self._packet_type + PACKET_FIXED_BIT | (self._spin_bit << 5) | (self._packet_crypto.key_phase << 2) | (PACKET_NUMBER_SEND_SIZE - 1) @@ -338,8 +357,8 @@ def _end_packet(self) -> None: if self._packet.in_flight: self._datagram_flight_bytes += self._packet.sent_bytes - # short header packets cannot be coalesced, we need a new datagram - if not self._packet_long_header: + # Short header packets cannot be coalesced, we need a new datagram. + if self._packet_type == QuicPacketType.ONE_RTT: self._flush_current_datagram() self._packet_number += 1 @@ -353,6 +372,15 @@ def _end_packet(self) -> None: def _flush_current_datagram(self) -> None: datagram_bytes = self._buffer.tell() if datagram_bytes: + # Padding for datagrams containing initial packets; see RFC 9000 + # section 14.1. + if self._datagram_needs_padding: + extra_bytes = self._flight_capacity - self._buffer.tell() + if extra_bytes > 0: + self._buffer.push_bytes(bytes(extra_bytes)) + self._datagram_flight_bytes += extra_bytes + datagram_bytes += extra_bytes + self._datagrams.append(self._buffer.data) self._flight_bytes += self._datagram_flight_bytes self._total_bytes += datagram_bytes diff --git a/qh3/quic/rangeset.py b/qh3/quic/rangeset.py index a03180e19..8dcc6c75f 100644 --- a/qh3/quic/rangeset.py +++ b/qh3/quic/rangeset.py @@ -1,15 +1,17 @@ +from __future__ import annotations + from collections.abc import Sequence -from typing import Any, Iterable, List, Optional +from typing import Any, Iterable class RangeSet(Sequence): def __init__(self, ranges: Iterable[range] = []): - self.__ranges: List[range] = [] + self.__ranges: list[range] = [] for r in ranges: assert r.step == 1 self.add(r.start, r.stop) - def add(self, start: int, stop: Optional[int] = None) -> None: + def add(self, start: int, stop: int | None = None) -> None: if stop is None: stop = start + 1 assert stop > start diff --git a/qh3/quic/recovery.py b/qh3/quic/recovery.py index 2a3b4f15e..f07d3238b 100644 --- a/qh3/quic/recovery.py +++ b/qh3/quic/recovery.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import logging import math -from typing import Any, Callable, Dict, Iterable, List, Optional +from typing import Any, Callable, Iterable from .logger import QuicLoggerTrace from .packet_builder import QuicDeliveryState, QuicSentPacket @@ -22,18 +24,18 @@ class QuicPacketSpace: def __init__(self) -> None: - self.ack_at: Optional[float] = None + self.ack_at: float | None = None self.ack_queue = RangeSet() self.discarded = False self.expected_packet_number = 0 self.largest_received_packet = -1 - self.largest_received_time: Optional[float] = None + self.largest_received_time: float | None = None # sent packets and loss self.ack_eliciting_in_flight = 0 self.largest_acked_packet = 0 - self.loss_time: Optional[float] = None - self.sent_packets: Dict[int, QuicSentPacket] = {} + self.loss_time: float | None = None + self.sent_packets: dict[int, QuicSentPacket] = {} class QuicPacketPacer: @@ -41,7 +43,7 @@ def __init__(self) -> None: self.bucket_max: float = 0.0 self.bucket_time: float = 0.0 self.evaluation_time: float = 0.0 - self.packet_time: Optional[float] = None + self.packet_time: float | None = None def next_send_time(self, now: float) -> float: if self.packet_time is not None: @@ -93,7 +95,7 @@ def __init__(self) -> None: self._congestion_recovery_start_time = 0.0 self._congestion_stash = 0 self._rtt_monitor = QuicRttMonitor() - self.ssthresh: Optional[int] = None + self.ssthresh: int | None = None def on_packet_acked(self, packet: QuicSentPacket) -> None: self.bytes_in_flight -= packet.sent_bytes @@ -155,12 +157,12 @@ def __init__( initial_rtt: float, peer_completed_address_validation: bool, send_probe: Callable[[], None], - logger: Optional[logging.LoggerAdapter] = None, - quic_logger: Optional[QuicLoggerTrace] = None, + logger: logging.LoggerAdapter | None = None, + quic_logger: QuicLoggerTrace | None = None, ) -> None: self.max_ack_delay = 0.025 self.peer_completed_address_validation = peer_completed_address_validation - self.spaces: List[QuicPacketSpace] = [] + self.spaces: list[QuicPacketSpace] = [] # callbacks self._logger = logger @@ -385,7 +387,7 @@ def _detect_loss(self, space: QuicPacketSpace, now: float) -> None: self._on_packets_lost(lost_packets, space=space, now=now) - def _get_loss_space(self) -> Optional[QuicPacketSpace]: + def _get_loss_space(self) -> QuicPacketSpace | None: loss_space = None for space in self.spaces: if space.loss_time is not None and ( @@ -395,7 +397,7 @@ def _get_loss_space(self) -> Optional[QuicPacketSpace]: return loss_space def _log_metrics_updated(self, log_rtt=False) -> None: - data: Dict[str, Any] = { + data: dict[str, Any] = { "bytes_in_flight": self._cc.bytes_in_flight, "cwnd": self._cc.congestion_window, } @@ -466,11 +468,11 @@ def __init__(self) -> None: self._ready = False self._size = 5 - self._filtered_min: Optional[float] = None + self._filtered_min: float | None = None self._sample_idx = 0 - self._sample_max: Optional[float] = None - self._sample_min: Optional[float] = None + self._sample_max: float | None = None + self._sample_min: float | None = None self._sample_time = 0.0 self._samples = [0.0 for i in range(self._size)] diff --git a/qh3/quic/retry.py b/qh3/quic/retry.py index 24358090c..6bde73957 100644 --- a/qh3/quic/retry.py +++ b/qh3/quic/retry.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import ipaddress -from typing import Tuple from .._hazmat import Rsa from ..buffer import Buffer @@ -27,7 +28,7 @@ def create_token( push_opaque(buf, 1, retry_source_connection_id) return self._key.encrypt(buf.data) - def validate_token(self, addr: NetworkAddress, token: bytes) -> Tuple[bytes, bytes]: + def validate_token(self, addr: NetworkAddress, token: bytes) -> tuple[bytes, bytes]: if not token or len(token) != 256: raise ValueError("Ciphertext length must be equal to key size.") buf = Buffer(data=self._key.decrypt(token)) diff --git a/qh3/quic/stream.py b/qh3/quic/stream.py index 26e8ae74c..39b143098 100644 --- a/qh3/quic/stream.py +++ b/qh3/quic/stream.py @@ -1,4 +1,4 @@ -from typing import Optional +from __future__ import annotations from . import events from .packet import ( @@ -29,17 +29,17 @@ class QuicStreamReceiver: - upon reception of a data frame with the FIN bit set """ - def __init__(self, stream_id: Optional[int], readable: bool) -> None: + def __init__(self, stream_id: int | None, readable: bool) -> None: self.highest_offset = 0 # the highest offset ever seen self.is_finished = False self.stop_pending = False self._buffer = bytearray() self._buffer_start = 0 # the offset for the start of the buffer - self._final_size: Optional[int] = None + self._final_size: int | None = None self._ranges = RangeSet() self._stream_id = stream_id - self._stop_error_code: Optional[int] = None + self._stop_error_code: int | None = None def get_stop_frame(self) -> QuicStopSendingFrame: self.stop_pending = False @@ -48,9 +48,10 @@ def get_stop_frame(self) -> QuicStopSendingFrame: stream_id=self._stream_id, ) - def handle_frame( - self, frame: QuicStreamFrame - ) -> Optional[events.StreamDataReceived]: + def starting_offset(self) -> int: + return self._buffer_start + + def handle_frame(self, frame: QuicStreamFrame) -> events.StreamDataReceived | None: """ Handle a frame of received data. """ @@ -111,7 +112,7 @@ def handle_frame( def handle_reset( self, *, final_size: int, error_code: int = QuicErrorCode.NO_ERROR - ) -> Optional[events.StreamReset]: + ) -> events.StreamReset | None: """ Handle an abrupt termination of the receiving part of the QUIC stream. """ @@ -166,7 +167,7 @@ class QuicStreamSender: - upon acknowledgement of a data frame with the FIN bit set """ - def __init__(self, stream_id: Optional[int], writable: bool) -> None: + def __init__(self, stream_id: int | None, writable: bool) -> None: self.buffer_is_empty = True self.highest_offset = 0 self.is_finished = not writable @@ -174,12 +175,12 @@ def __init__(self, stream_id: Optional[int], writable: bool) -> None: self._acked = RangeSet() self._buffer = bytearray() - self._buffer_fin: Optional[int] = None + self._buffer_fin: int | None = None self._buffer_start = 0 # the offset for the start of the buffer self._buffer_stop = 0 # the offset for the stop of the buffer self._pending = RangeSet() self._pending_eof = False - self._reset_error_code: Optional[int] = None + self._reset_error_code: int | None = None self._stream_id = stream_id @property @@ -195,8 +196,8 @@ def next_offset(self) -> int: return self._buffer_stop def get_frame( - self, max_size: int, max_offset: Optional[int] = None - ) -> Optional[QuicStreamFrame]: + self, max_size: int, max_offset: int | None = None + ) -> QuicStreamFrame | None: """ Get a frame of data to send. """ @@ -315,7 +316,7 @@ def write(self, data: bytes, end_stream: bool = False) -> None: class QuicStream: def __init__( self, - stream_id: Optional[int] = None, + stream_id: int | None = None, max_stream_data_local: int = 0, max_stream_data_remote: int = 0, readable: bool = True, diff --git a/qh3/tls.py b/qh3/tls.py index 406f0fddb..0085dedcf 100644 --- a/qh3/tls.py +++ b/qh3/tls.py @@ -35,6 +35,7 @@ SignatureError, UnacceptableCertificateError, X25519KeyExchange, + X25519Kyber768Draft00KeyExchange, verify_with_public_key, ) from .buffer import Buffer @@ -52,13 +53,21 @@ ] _HASHED_CERT_FILENAME_RE = re.compile(r"^[0-9a-fA-F]{8}\.[0-9]$") +TLS_VERSION_GREASE = 0x0A0A TLS_VERSION_1_2 = 0x0303 TLS_VERSION_1_3 = 0x0304 T = TypeVar("T") # Maps the length of a digest to a possible hash function producing this digest -HASHFUNC_MAP = {32: hashlib.md5, 40: hashlib.sha1, 64: hashlib.sha256} +HASHFUNC_MAP = { + length: getattr(hashlib, algorithm, None) + for length, algorithm in ( + (32, "md5"), # some algorithm may be unavailable + (40, "sha1"), + (64, "sha256"), + ) +} # facilitate mocking for the test suite @@ -434,6 +443,7 @@ class CipherSuite(IntEnum): AES_256_GCM_SHA384 = 0x1302 CHACHA20_POLY1305_SHA256 = 0x1303 EMPTY_RENEGOTIATION_INFO_SCSV = 0x00FF + GREASE = 0xDADA class CompressionMethod(IntEnum): @@ -455,12 +465,14 @@ class ExtensionType(IntEnum): KEY_SHARE = 51 QUIC_TRANSPORT_PARAMETERS = 0x0039 ENCRYPTED_SERVER_NAME = 65486 + GREASE = 0x0A0A class Group(IntEnum): SECP256R1 = 0x0017 SECP384R1 = 0x0018 SECP521R1 = 0x0019 + X25519KYBER768DRAFT00 = 0x6399 X25519 = 0x001D X448 = 0x001E GREASE = 0xAAAA @@ -491,7 +503,7 @@ class SignatureAlgorithm(IntEnum): ECDSA_SECP384R1_SHA384 = 0x0503 ECDSA_SECP521R1_SHA512 = 0x0603 ED25519 = 0x0807 - ED448 = 0x0808 + ED448 = 0x0808 # unsupported RSA_PKCS1_SHA256 = 0x0401 RSA_PKCS1_SHA384 = 0x0501 RSA_PKCS1_SHA512 = 0x0601 @@ -716,6 +728,8 @@ def pull_extension() -> None: binders=pull_list(buf, 2, partial(pull_psk_binder, buf)), ) after_psk = True + elif extension_type == ExtensionType.GREASE: + pass # simply ignore it! else: hello.other_extensions.append( (extension_type, buf.pull_bytes(extension_length)) @@ -737,6 +751,9 @@ def push_client_hello(buf: Buffer, hello: ClientHello) -> None: # extensions with push_block(buf, 2): + with push_extension(buf, ExtensionType.GREASE): + pass + with push_extension(buf, ExtensionType.KEY_SHARE): push_list(buf, 2, partial(push_key_share, buf), hello.key_share) @@ -1186,11 +1203,16 @@ def cipher_suite_hash(cipher_suite: CipherSuite) -> int: def negotiate( - supported: list[T], offered: list[Any] | None, exc: Alert | None = None + supported: list[T], + offered: list[Any] | None, + exc: Alert | None = None, + excl: T | None = None, ) -> T: if offered is not None: for c in supported: if c in offered: + if excl is not None and excl == c: + continue return c if exc is not None: @@ -1202,7 +1224,7 @@ def signature_algorithm_params(signature_algorithm: int) -> tuple[Any, ...]: if signature_algorithm in (SignatureAlgorithm.ED25519, SignatureAlgorithm.ED448): return () - is_pss, hash_size = SIGNATURE_ALGORITHMS[signature_algorithm] + is_pss, hash_size = SIGNATURE_ALGORITHMS[SignatureAlgorithm(signature_algorithm)] if is_pss is None: return () @@ -1286,7 +1308,7 @@ def __init__( self.certificate: X509Certificate | None = None self.certificate_chain: list[X509Certificate] = [] self.certificate_private_key: ( - RsaPrivateKey | DsaPrivateKey | EcPrivateKey | None + EcPrivateKey | Ed25519PrivateKey | DsaPrivateKey | RsaPrivateKey | None ) = None self.handshake_extensions: list[Extension] = [] self._max_early_data = max_early_data @@ -1311,6 +1333,7 @@ def __init__( self._cipher_suites = cipher_suites else: self._cipher_suites = [ + CipherSuite.GREASE, CipherSuite.AES_128_GCM_SHA256, CipherSuite.CHACHA20_POLY1305_SHA256, CipherSuite.AES_256_GCM_SHA384, @@ -1318,21 +1341,26 @@ def __init__( self._legacy_compression_methods: list[int] = [CompressionMethod.NULL] self._psk_key_exchange_modes: list[int] = [PskKeyExchangeMode.PSK_DHE_KE] self._signature_algorithms: list[int] = [ - SignatureAlgorithm.RSA_PSS_RSAE_SHA256, SignatureAlgorithm.ECDSA_SECP256R1_SHA256, + SignatureAlgorithm.RSA_PSS_RSAE_SHA256, SignatureAlgorithm.RSA_PKCS1_SHA256, SignatureAlgorithm.ECDSA_SECP384R1_SHA384, + SignatureAlgorithm.RSA_PSS_RSAE_SHA384, + SignatureAlgorithm.RSA_PKCS1_SHA384, + SignatureAlgorithm.RSA_PSS_RSAE_SHA512, + SignatureAlgorithm.RSA_PKCS1_SHA512, SignatureAlgorithm.ED25519, ] self._supported_groups = [ + Group.GREASE, + Group.X25519KYBER768DRAFT00, Group.X25519, Group.SECP256R1, Group.SECP384R1, - # Group.SECP521R1, not used by default, but we can serve it. ] - self._supported_versions = [TLS_VERSION_1_3] + self._supported_versions = [TLS_VERSION_GREASE, TLS_VERSION_1_3] # state self.alpn_negotiated: str | None = None @@ -1356,6 +1384,9 @@ def __init__( self._ec_p384_private_key: ECDHP384KeyExchange | None = None self._ec_p521_private_key: ECDHP521KeyExchange | None = None self._x25519_private_key: X25519KeyExchange | None = None + self._x25519_kyber_768_private_key: X25519Kyber768Draft00KeyExchange | None = ( + None + ) if is_client: self.client_random = os.urandom(32) @@ -1514,6 +1545,15 @@ def _client_send_hello(self, output_buf: Buffer) -> None: self._x25519_private_key = X25519KeyExchange() key_share.append((Group.X25519, self._x25519_private_key.public_key())) supported_groups.append(Group.X25519) + elif group == Group.X25519KYBER768DRAFT00: + self._x25519_kyber_768_private_key = X25519Kyber768Draft00KeyExchange() + key_share.append( + ( + Group.X25519KYBER768DRAFT00, + self._x25519_kyber_768_private_key.public_key(), + ) + ) + supported_groups.append(Group.X25519KYBER768DRAFT00) elif group == Group.GREASE: key_share.append((Group.GREASE, b"\x00")) supported_groups.append(Group.GREASE) @@ -1557,7 +1597,7 @@ def _client_send_hello(self, output_buf: Buffer) -> None: ) # serialize hello without binder - tmp_buf = Buffer(capacity=1024) + tmp_buf = Buffer(capacity=2048) push_client_hello(tmp_buf, hello) # calculate binder @@ -1579,7 +1619,9 @@ def _client_send_hello(self, output_buf: Buffer) -> None: early_key, ) - self._key_schedule_proxy = KeyScheduleProxy(self._cipher_suites) + self._key_schedule_proxy = KeyScheduleProxy( + [cs for cs in self._cipher_suites if cs != CipherSuite.GREASE] + ) self._key_schedule_proxy.extract(None) with push_message(self._key_schedule_proxy, output_buf): @@ -1594,6 +1636,7 @@ def _client_handle_hello(self, input_buf: Buffer, output_buf: Buffer) -> None: self._cipher_suites, [peer_hello.cipher_suite], AlertHandshakeFailure("Unsupported cipher suite"), + excl=CipherSuite.GREASE, ) assert peer_hello.compression_method in self._legacy_compression_methods assert peer_hello.supported_version in self._supported_versions @@ -1622,6 +1665,8 @@ def _client_handle_hello(self, input_buf: Buffer, output_buf: Buffer) -> None: and self._x25519_private_key is not None ): shared_key = self._x25519_private_key.exchange(peer_public_key) + elif peer_hello.key_share[0] == Group.X25519KYBER768DRAFT00: + shared_key = self._x25519_kyber_768_private_key.exchange(peer_public_key) elif ( peer_hello.key_share[0] == Group.SECP256R1 and self._ec_p256_private_key is not None @@ -1720,7 +1765,7 @@ def _client_handle_certificate_verify(self, input_buf: Buffer) -> None: if self._assert_fingerprint is not None: fingerprint = self._assert_fingerprint.replace(":", "").lower() digest_length = len(fingerprint) - hashfunc = HASHFUNC_MAP.get(digest_length)() # type: ignore[abstract] + hashfunc = HASHFUNC_MAP.get(digest_length) if not hashfunc: raise AlertBadCertificate( @@ -1728,7 +1773,7 @@ def _client_handle_certificate_verify(self, input_buf: Buffer) -> None: ) expect_fingerprint = unhexlify(fingerprint.encode()) - peer_fingerprint = hashfunc(self._peer_certificate.public_bytes()).digiest() + peer_fingerprint = hashfunc(self._peer_certificate.public_bytes()).digest() if peer_fingerprint != expect_fingerprint: raise AlertBadCertificate( @@ -1874,14 +1919,13 @@ def _server_handle_hello( signature_algorithms = [SignatureAlgorithm.ECDSA_SECP521R1_SHA512] elif isinstance(self.certificate_private_key, Ed25519PrivateKey): signature_algorithms = [SignatureAlgorithm.ED25519] - # elif isinstance(self.certificate_private_key, ed448.Ed448PrivateKey): - # signature_algorithms = [SignatureAlgorithm.ED448] # negotiate parameters cipher_suite = negotiate( self._cipher_suites, peer_hello.cipher_suites, AlertHandshakeFailure("No supported cipher suite"), + excl=CipherSuite.GREASE, ) compression_method = negotiate( self._legacy_compression_methods, @@ -1909,14 +1953,15 @@ def _server_handle_hello( peer_hello.alpn_protocols, AlertHandshakeFailure("No common ALPN protocols"), ) - if self.alpn_cb: - self.alpn_cb(self.alpn_negotiated) self.client_random = peer_hello.random self.server_random = os.urandom(32) self.legacy_session_id = peer_hello.legacy_session_id self.received_extensions = peer_hello.other_extensions + if self.alpn_cb: + self.alpn_cb(self.alpn_negotiated) + # select key schedule pre_shared_key = None if ( @@ -1991,6 +2036,14 @@ def _server_handle_hello( shared_key = self._x25519_private_key.exchange(peer_public_key) group_kx = Group.X25519 break + elif key_share[0] == Group.X25519KYBER768DRAFT00: + self._x25519_kyber_768_private_key = X25519Kyber768Draft00KeyExchange() + public_key = self._x25519_kyber_768_private_key.public_key() + shared_key = self._x25519_kyber_768_private_key.exchange( + peer_public_key + ) + group_kx = Group.X25519KYBER768DRAFT00 + break elif key_share[0] == Group.SECP256R1: self._ec_p256_private_key = ECDHP256KeyExchange() public_key = self._ec_p256_private_key.public_key() diff --git a/src/agreement.rs b/src/agreement.rs index 73f22db80..92d1ac6ba 100644 --- a/src/agreement.rs +++ b/src/agreement.rs @@ -1,10 +1,33 @@ use aws_lc_rs::{agreement, error}; +use aws_lc_rs::kem; +use aws_lc_rs::unstable::kem::{get_algorithm, AlgorithmId}; + +use rustls::crypto::{ + SharedSecret, +}; + use pyo3::Python; use pyo3::types::PyBytes; use pyo3::pymethods; use pyo3::pyclass; +const X25519_LEN: usize = 32; +const KYBER_CIPHERTEXT_LEN: usize = 1088; +const X25519_KYBER_COMBINED_PUBKEY_LEN: usize = X25519_LEN + 1184; +const X25519_KYBER_COMBINED_CIPHERTEXT_LEN: usize = X25519_LEN + KYBER_CIPHERTEXT_LEN; +const X25519_KYBER_COMBINED_SHARED_SECRET_LEN: usize = X25519_LEN + 32; + +struct X25519Kyber768CombinedSecret([u8; X25519_KYBER_COMBINED_SHARED_SECRET_LEN]); + +impl X25519Kyber768CombinedSecret { + fn combine(x25519: SharedSecret, kyber: kem::SharedSecret) -> Self { + let mut out = X25519Kyber768CombinedSecret([0u8; X25519_KYBER_COMBINED_SHARED_SECRET_LEN]); + out.0[..X25519_LEN].copy_from_slice(x25519.secret_bytes()); + out.0[X25519_LEN..].copy_from_slice(kyber.as_ref()); + out + } +} #[pyclass(module = "qh3._hazmat")] pub struct X25519KeyExchange { @@ -28,6 +51,76 @@ pub struct ECDHP521KeyExchange { private: agreement::PrivateKey, } +#[pyclass(module = "qh3._hazmat")] +pub struct X25519Kyber768Draft00KeyExchange { + x25519_private: agreement::PrivateKey, + kyber768_decapsulation_key: kem::DecapsulationKey, +} + +#[pymethods] +impl X25519Kyber768Draft00KeyExchange { + #[new] + pub fn py_new() -> Self { + X25519Kyber768Draft00KeyExchange { + x25519_private: agreement::PrivateKey::generate(&agreement::X25519).expect("FAILURE"), + kyber768_decapsulation_key: kem::DecapsulationKey::generate(get_algorithm(AlgorithmId::Kyber768_R3).expect("Kyber768_R3 not available")).expect("FAILURE") + } + } + + pub fn public_key<'a>(&self, py: Python<'a>) -> &'a PyBytes { + let kyber_pub = self.kyber768_decapsulation_key + .encapsulation_key() + .expect("FAILURE"); + + let mut combined_pub_key = Vec::with_capacity(X25519_KYBER_COMBINED_PUBKEY_LEN); + + combined_pub_key.extend_from_slice(self.x25519_private.compute_public_key().unwrap().as_ref()); + combined_pub_key.extend_from_slice(kyber_pub.key_bytes().unwrap().as_ref()); + + return PyBytes::new( + py, + &combined_pub_key.as_ref() + ); + } + + pub fn exchange<'a>(&self, py: Python<'a>, peer_public_key: &PyBytes) -> &'a PyBytes { + let cipher_text = peer_public_key.as_bytes(); + + if cipher_text.len() != X25519_KYBER_COMBINED_CIPHERTEXT_LEN { + return PyBytes::new(py, &[]); + } + + let (x25519, kyber) = cipher_text.split_at(X25519_LEN); + + let x25519_peer_public_key = agreement::UnparsedPublicKey::new(&agreement::X25519, x25519); + + let x25519_secret = agreement::agree( + &self.x25519_private, + &x25519_peer_public_key, + error::Unspecified, + |_key_material| { + return Ok(_key_material.to_vec()) + }, + ).expect("FAILURE"); + + let kyber_secret = self.kyber768_decapsulation_key + .decapsulate(kyber.into()) + .expect("FAILURE"); + + let combined_secret = X25519Kyber768CombinedSecret::combine( + SharedSecret::from(&x25519_secret[..]), + kyber_secret, + ); + + let key_material = SharedSecret::from(&combined_secret.0[..]); + + return PyBytes::new( + py, + &key_material.secret_bytes() + ); + } +} + #[pymethods] impl X25519KeyExchange { diff --git a/src/lib.rs b/src/lib.rs index 7cdb862d2..85b99cdf8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,7 +17,7 @@ pub use self::aead::{AeadChaCha20Poly1305, AeadAes128Gcm, AeadAes256Gcm}; pub use self::certificate::{ServerVerifier, Certificate, SelfSignedCertificateError, InvalidNameCertificateError, ExpiredCertificateError, UnacceptableCertificateError}; pub use self::rsa::{Rsa}; pub use self::private_key::{RsaPrivateKey, DsaPrivateKey, Ed25519PrivateKey, EcPrivateKey, verify_with_public_key, SignatureError}; -pub use self::agreement::{X25519KeyExchange, ECDHP256KeyExchange, ECDHP384KeyExchange, ECDHP521KeyExchange}; +pub use self::agreement::{X25519KeyExchange, ECDHP256KeyExchange, ECDHP384KeyExchange, ECDHP521KeyExchange, X25519Kyber768Draft00KeyExchange}; pub use self::pkcs8::{PrivateKeyInfo, KeyType}; pub use self::hpk::{QUICHeaderProtection}; pub use self::ocsp::{OCSPResponse, OCSPCertStatus, OCSPResponseStatus, ReasonFlags, OCSPRequest}; @@ -63,6 +63,7 @@ fn _hazmat(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; // General Crypto Error m.add("CryptoError", py.get_type::())?; // Niquests OCSP helper diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 42eb4f470..a34e999ca 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import binascii import contextlib diff --git a/tests/test_buffer.py b/tests/test_buffer.py index bf6d44803..c38b07574 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest import TestCase from qh3.buffer import Buffer, BufferReadError, BufferWriteError, size_uint_var diff --git a/tests/test_connection.py b/tests/test_connection.py index 5c1418790..55aebbe53 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import binascii import contextlib import io @@ -11,20 +13,22 @@ from qh3.quic.configuration import QuicConfiguration from qh3.quic.connection import ( STREAM_COUNT_MAX, + MAX_PENDING_CRYPTO, NetworkAddress, QuicConnection, QuicConnectionError, QuicNetworkPath, - QuicReceiveContext, + QuicReceiveContext, MAX_LOCAL_CHALLENGES, ) from qh3.quic.crypto import CryptoPair from qh3.quic.logger import QuicLogger from qh3.quic.packet import ( - PACKET_TYPE_INITIAL, QuicErrorCode, QuicFrameType, + QuicPacketType, QuicProtocolVersion, QuicTransportParameters, + QuicVersionInformation, encode_quic_retry, encode_quic_version_negotiation, push_quic_transport_parameters, @@ -41,8 +45,21 @@ ) CLIENT_ADDR = ("1.2.3.4", 1234) +CLIENT_HANDSHAKE_DATAGRAM_SIZES = [1280] SERVER_ADDR = ("2.3.4.5", 4433) +SERVER_INITIAL_DATAGRAM_SIZES = [1280, 1280, 986] + +HANDSHAKE_COMPLETED_EVENTS = [ + events.HandshakeCompleted, + events.ConnectionIdIssued, + events.ConnectionIdIssued, + events.ConnectionIdIssued, + events.ConnectionIdIssued, + events.ConnectionIdIssued, + events.ConnectionIdIssued, + events.ConnectionIdIssued, +] TICK = 0.05 # seconds @@ -65,6 +82,7 @@ def client_receive_context(client, epoch=tls.Epoch.ONE_RTT): network_path=client._network_paths[0], quic_logger_frames=[], time=time.time(), + version=None, ) @@ -85,15 +103,45 @@ def create_standalone_client(self, **client_options): # kick-off handshake client.connect(SERVER_ADDR, now=time.time()) - self.assertEqual(drop(client), 1) + self.assertEqual(drop(client), 2) return client +def create_standalone_server(self, original_destination_connection_id=bytes(8)): + server_configuration = QuicConfiguration(is_client=False, quic_logger=QuicLogger()) + server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) + + server = QuicConnection( + configuration=server_configuration, + original_destination_connection_id=original_destination_connection_id, + ) + server._ack_delay = 0 + + return server + + def datagram_sizes(items: List[Tuple[bytes, NetworkAddress]]) -> List[int]: return [len(x[0]) for x in items] +def new_connection_id( + *, + sequence_number: int, + retire_prior_to: int = 0, + connection_id: bytes = bytes(8), + capacity: int = 100, +): + buf = Buffer(capacity=capacity) + buf.push_uint_var(sequence_number) + buf.push_uint_var(retire_prior_to) + buf.push_uint_var(len(connection_id)) + buf.push_bytes(connection_id) + buf.push_bytes(bytes(16)) # stateless reset token + buf.seek(0) + return buf + + @contextlib.contextmanager def client_and_server( client_kwargs={}, @@ -188,6 +236,30 @@ def transfer(sender, receiver): class QuicConnectionTest(TestCase): + def assertEvents(self, connection: QuicConnection, expected: list): + types = [] + while True: + event = connection.next_event() + if event is not None: + types.append(type(event)) + else: + break + + self.assertListEqual(types, expected) + + def assertPacketDropped(self, connection: QuicConnection, trigger: str): + log = connection.configuration.quic_logger.to_dict() + found_trigger = None + for event in log["traces"][0]["events"]: + if event["name"] == "transport:packet_dropped": + found_trigger = event["data"]["trigger"] + break + self.assertEqual(found_trigger, trigger) + + def assertSentPackets(self, connection: QuicConnection, expected: List[int]): + counts = [len(space.sent_packets) for space in connection._loss.spaces] + self.assertEqual(counts, expected) + def check_handshake(self, client, server, alpn_protocol=None): """ Check handshake completed. @@ -348,6 +420,59 @@ def test_connect_with_cipher_suite_chacha20(self): tls.CipherSuite.CHACHA20_POLY1305_SHA256, ) + def test_connect_without_loss(self): + """ + Check connection is established in the absence of loss. + """ + with client_and_server(handshake=False) as (client, server): + # client sends INITIAL + now = 0.0 + client.connect(SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280, 1280]) + self.assertEqual(client.get_timer(), 0.2) + self.assertSentPackets(client, [2, 0, 0]) + self.assertEvents(client, []) + + # server receives INITIAL, sends INITIAL + HANDSHAKE + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + server.receive_datagram(items[1][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) + self.assertAlmostEqual(server.get_timer(), 0.25) + self.assertSentPackets(server, [2, 2, 0]) + self.assertEvents(server, [events.ProtocolNegotiated]) + + # handshake continues normally + now += TICK + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + client.receive_datagram(items[1][0], SERVER_ADDR, now=now) + client.receive_datagram(items[2][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES) + self.assertAlmostEqual(client.get_timer(), 0.425) + self.assertSentPackets(client, [0, 1, 1]) + self.assertEvents( + client, [events.ProtocolNegotiated] + HANDSHAKE_COMPLETED_EVENTS + ) + + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [229]) + self.assertAlmostEqual(server.get_timer(), 0.425) + self.assertSentPackets(server, [0, 0, 1]) + self.assertEvents(server, HANDSHAKE_COMPLETED_EVENTS) + + now += TICK + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [32]) + self.assertAlmostEqual(client.get_timer(), 60.2) # idle timeout + self.assertSentPackets(client, [0, 0, 1]) + self.assertEvents(client, []) + def test_connect_with_loss_1(self): """ Check connection is established even in the client's INITIAL is lost. @@ -355,72 +480,63 @@ def test_connect_with_loss_1(self): The client's PTO fires, triggering retransmission. """ - client_configuration = QuicConfiguration(is_client=True) - client_configuration.load_verify_locations(cafile=SERVER_CACERTFILE) - - client = QuicConnection(configuration=client_configuration) - client._ack_delay = 0 - - server_configuration = QuicConfiguration(is_client=False) - server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) - - server = QuicConnection( - configuration=server_configuration, - original_destination_connection_id=client.original_destination_connection_id, - ) - server._ack_delay = 0 - - # client sends INITIAL - now = 0.0 - client.connect(SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280]) - self.assertEqual(client.get_timer(), 0.2) - - # INITIAL is lost - now = client.get_timer() - client.handle_timer(now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280]) - self.assertAlmostEqual(client.get_timer(), 0.6) - - # server receives INITIAL, sends INITIAL + HANDSHAKE - now += TICK - server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) - items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280, 1019]) - self.assertAlmostEqual(server.get_timer(), 0.45) - self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) - self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) - self.assertEqual(type(server.next_event()), events.ProtocolNegotiated) - self.assertIsNone(server.next_event()) - - # handshake continues normally - now += TICK - client.receive_datagram(items[0][0], SERVER_ADDR, now=now) - client.receive_datagram(items[1][0], SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [360]) - self.assertAlmostEqual(client.get_timer(), 0.625) - self.assertEqual(type(client.next_event()), events.ProtocolNegotiated) - self.assertEqual(type(client.next_event()), events.HandshakeCompleted) - self.assertEqual(type(client.next_event()), events.ConnectionIdIssued) - - now += TICK - server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) - items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [229]) - self.assertAlmostEqual(server.get_timer(), 0.625) - self.assertEqual(len(server._loss.spaces[0].sent_packets), 0) - self.assertEqual(len(server._loss.spaces[1].sent_packets), 0) - self.assertEqual(type(server.next_event()), events.HandshakeCompleted) - self.assertEqual(type(server.next_event()), events.ConnectionIdIssued) - - now += TICK - client.receive_datagram(items[0][0], SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [32]) - self.assertAlmostEqual(client.get_timer(), 60.4) # idle timeout + with client_and_server(handshake=False) as (client, server): + # client sends INITIAL + now = 0.0 + client.connect(SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280, 1280]) + self.assertEqual(client.get_timer(), 0.2) + self.assertSentPackets(client, [2, 0, 0]) + self.assertEvents(client, []) + + # INITIAL is lost and retransmitted + now = client.get_timer() + client.handle_timer(now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280, 1280]) + self.assertAlmostEqual(client.get_timer(), 0.6) + self.assertSentPackets(client, [2, 0, 0]) + self.assertEvents(client, []) + + # server receives INITIAL, sends INITIAL + HANDSHAKE + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + server.receive_datagram(items[1][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) + self.assertAlmostEqual(server.get_timer(), 0.45) + self.assertSentPackets(server, [2, 2, 0]) + self.assertEvents(server, [events.ProtocolNegotiated]) + + # handshake continues normally + now += TICK + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + client.receive_datagram(items[1][0], SERVER_ADDR, now=now) + client.receive_datagram(items[2][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES) + self.assertAlmostEqual(client.get_timer(), 0.625) + self.assertSentPackets(client, [0, 1, 1]) + self.assertEvents( + client, [events.ProtocolNegotiated] + HANDSHAKE_COMPLETED_EVENTS + ) + + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [229]) + self.assertAlmostEqual(server.get_timer(), 0.625) + self.assertSentPackets(server, [0, 0, 1]) + self.assertEvents(server, HANDSHAKE_COMPLETED_EVENTS) + + now += TICK + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [32]) + self.assertAlmostEqual(client.get_timer(), 60.4) # idle timeout + self.assertSentPackets(client, [0, 0, 1]) + self.assertEvents(client, []) def test_connect_with_loss_2(self): """ @@ -430,82 +546,76 @@ def test_connect_with_loss_2(self): and decides to retransmit its own CRYPTO to speedup handshake completion. """ - client_configuration = QuicConfiguration(is_client=True) - client_configuration.load_verify_locations(cafile=SERVER_CACERTFILE) - - client = QuicConnection(configuration=client_configuration) - client._ack_delay = 0 - - server_configuration = QuicConfiguration(is_client=False) - server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) - - server = QuicConnection( - configuration=server_configuration, - original_destination_connection_id=client.original_destination_connection_id, - ) - server._ack_delay = 0 - - # client sends INITIAL - now = 0.0 - client.connect(SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280]) - self.assertEqual(client.get_timer(), 0.2) - - # server receives INITIAL, sends INITIAL + HANDSHAKE but first datagram is lost - now += TICK - server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) - items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280, 1019]) - self.assertEqual(server.get_timer(), 0.25) - self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) - self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) - self.assertEqual(type(server.next_event()), events.ProtocolNegotiated) - self.assertIsNone(server.next_event()) - - # client only receives second datagram, retransmits INITIAL - now += TICK - client.receive_datagram(items[1][0], SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280]) - self.assertAlmostEqual(client.get_timer(), 0.3) - self.assertIsNone(client.next_event()) - - # server receives duplicate INITIAL, retransmits INITIAL + HANDSHAKE - now += TICK - server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) - items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280, 1019]) - self.assertAlmostEqual(server.get_timer(), 0.35) - self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) - self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) - - # handshake continues normally - now += TICK - client.receive_datagram(items[0][0], SERVER_ADDR, now=now) - client.receive_datagram(items[1][0], SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [360]) - self.assertAlmostEqual(client.get_timer(), 0.525) - self.assertEqual(type(client.next_event()), events.ProtocolNegotiated) - self.assertEqual(type(client.next_event()), events.HandshakeCompleted) - self.assertEqual(type(client.next_event()), events.ConnectionIdIssued) - - now += TICK - server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) - items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [229]) - self.assertAlmostEqual(server.get_timer(), 0.525) - self.assertEqual(len(server._loss.spaces[0].sent_packets), 0) - self.assertEqual(len(server._loss.spaces[1].sent_packets), 0) - self.assertEqual(type(server.next_event()), events.HandshakeCompleted) - self.assertEqual(type(server.next_event()), events.ConnectionIdIssued) - - now += TICK - client.receive_datagram(items[0][0], SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [32]) - self.assertAlmostEqual(client.get_timer(), 60.3) # idle timeout + with client_and_server(handshake=False) as (client, server): + # client sends INITIAL + now = 0.0 + client.connect(SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280, 1280]) + self.assertEqual(client.get_timer(), 0.2) + self.assertSentPackets(client, [2, 0, 0]) + self.assertEvents(client, []) + + # server receives INITIAL, sends INITIAL + HANDSHAKE but first datagram + # is lost + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + server.receive_datagram(items[1][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) + self.assertEqual(server.get_timer(), 0.25) + self.assertSentPackets(server, [2, 2, 0]) + self.assertEvents(server, [events.ProtocolNegotiated]) + + # client only receives second datagram, retransmits INITIAL + now += TICK + client.receive_datagram(items[1][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280, 1280]) + self.assertAlmostEqual(client.get_timer(), 0.3) + self.assertSentPackets(client, [2, 0, 0]) + self.assertEvents(client, []) + + self.assertPacketDropped(client, "key_unavailable") + + # server receives duplicate INITIAL, retransmits INITIAL + HANDSHAKE + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + server.receive_datagram(items[1][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280, 1280, 890]) + # self.assertAlmostEqual(server.get_timer(), 0.35) + self.assertSentPackets(server, [1, 2, 0]) + self.assertEvents(server, []) + + # handshake continues normally + now += TICK + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + client.receive_datagram(items[1][0], SERVER_ADDR, now=now) + client.receive_datagram(items[2][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES) + # self.assertAlmostEqual(client.get_timer(), 0.525) + self.assertSentPackets(client, [0, 1, 1]) + self.assertEvents( + client, [events.ProtocolNegotiated] + HANDSHAKE_COMPLETED_EVENTS + ) + + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [229]) + # self.assertAlmostEqual(server.get_timer(), 0.525) + self.assertSentPackets(server, [0, 0, 1]) + self.assertEvents(server, HANDSHAKE_COMPLETED_EVENTS) + + now += TICK + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [32]) + self.assertAlmostEqual(client.get_timer(), 60.3) # idle timeout + self.assertSentPackets(client, [0, 0, 1]) + self.assertEvents(client, []) def test_connect_with_loss_3(self): """ @@ -516,261 +626,253 @@ def test_connect_with_loss_3(self): CRYPTO to speedup handshake completion. """ - client_configuration = QuicConfiguration(is_client=True) - client_configuration.load_verify_locations(cafile=SERVER_CACERTFILE) - - client = QuicConnection(configuration=client_configuration) - client._ack_delay = 0 - - server_configuration = QuicConfiguration(is_client=False) - server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) - - server = QuicConnection( - configuration=server_configuration, - original_destination_connection_id=client.original_destination_connection_id, - ) - server._ack_delay = 0 - - # client sends INITIAL - now = 0.0 - client.connect(SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280]) - self.assertEqual(client.get_timer(), 0.2) - - # server receives INITIAL, sends INITIAL + HANDSHAKE - now += TICK - server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) - items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280, 1019]) - self.assertEqual(server.get_timer(), 0.25) - self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) - self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) - self.assertEqual(type(server.next_event()), events.ProtocolNegotiated) - self.assertIsNone(server.next_event()) - - # INITIAL + HANDSHAKE are lost, client retransmits INITIAL - now = client.get_timer() - client.handle_timer(now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280]) - self.assertAlmostEqual(client.get_timer(), 0.6) - self.assertIsNone(client.next_event()) - - # server receives duplicate INITIAL, retransmits INITIAL + HANDSHAKE - now += TICK - server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) - items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280, 1019]) - self.assertEqual(server.get_timer(), 0.45) - self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) - self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) - - # handshake continues normally - now += TICK - client.receive_datagram(items[0][0], SERVER_ADDR, now=now) - client.receive_datagram(items[1][0], SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [360]) - self.assertAlmostEqual(client.get_timer(), 0.625) - self.assertEqual(type(client.next_event()), events.ProtocolNegotiated) - self.assertEqual(type(client.next_event()), events.HandshakeCompleted) - self.assertEqual(type(client.next_event()), events.ConnectionIdIssued) - - now += TICK - server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) - items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [229]) - self.assertAlmostEqual(server.get_timer(), 0.625) - self.assertEqual(len(server._loss.spaces[0].sent_packets), 0) - self.assertEqual(len(server._loss.spaces[1].sent_packets), 0) - self.assertEqual(type(server.next_event()), events.HandshakeCompleted) - self.assertEqual(type(server.next_event()), events.ConnectionIdIssued) - - now += TICK - client.receive_datagram(items[0][0], SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [32]) - self.assertAlmostEqual(client.get_timer(), 60.4) # idle timeout + with client_and_server(handshake=False) as (client, server): + # client sends INITIAL + now = 0.0 + client.connect(SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280, 1280]) + self.assertEqual(client.get_timer(), 0.2) + self.assertSentPackets(client, [2, 0, 0]) + self.assertEvents(client, []) + + # server receives INITIAL, sends INITIAL + HANDSHAKE + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + server.receive_datagram(items[1][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) + self.assertEqual(server.get_timer(), 0.25) + self.assertSentPackets(server, [2, 2, 0]) + self.assertEvents(server, [events.ProtocolNegotiated]) + + # INITIAL + HANDSHAKE are lost, client retransmits INITIAL + now = client.get_timer() + client.handle_timer(now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280, 1280]) + self.assertAlmostEqual(client.get_timer(), 0.6) + self.assertSentPackets(client, [2, 0, 0]) + self.assertEvents(client, []) + + # server receives duplicate INITIAL, retransmits INITIAL + HANDSHAKE + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + server.receive_datagram(items[1][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) + self.assertEqual(server.get_timer(), 0.45) + self.assertSentPackets(server, [2, 2, 0]) + self.assertEvents(server, []) + + # handshake continues normally + now += TICK + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + client.receive_datagram(items[1][0], SERVER_ADDR, now=now) + client.receive_datagram(items[2][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES) + self.assertGreaterEqual(client.get_timer(), 0.5) + self.assertLessEqual(client.get_timer(), 0.63) + self.assertSentPackets(client, [0, 1, 1]) + self.assertEvents( + client, [events.ProtocolNegotiated] + HANDSHAKE_COMPLETED_EVENTS + ) + + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [229]) + self.assertAlmostEqual(server.get_timer(), 0.625) + self.assertSentPackets(server, [0, 0, 1]) + self.assertEvents(server, HANDSHAKE_COMPLETED_EVENTS) + + now += TICK + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [32]) + self.assertAlmostEqual(client.get_timer(), 60.4) # idle timeout + self.assertSentPackets(client, [0, 0, 1]) + self.assertEvents(client, []) def test_connect_with_loss_4(self): """ Check connection is established even in the server's HANDSHAKE is lost. """ - client_configuration = QuicConfiguration(is_client=True) - client_configuration.load_verify_locations(cafile=SERVER_CACERTFILE) - - client = QuicConnection(configuration=client_configuration) - client._ack_delay = 0 - - server_configuration = QuicConfiguration(is_client=False) - server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) - - server = QuicConnection( - configuration=server_configuration, - original_destination_connection_id=client.original_destination_connection_id, - ) - server._ack_delay = 0 - - # client sends INITIAL - now = 0.0 - client.connect(SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280]) - self.assertEqual(client.get_timer(), 0.2) - - # server receives INITIAL, sends INITIAL + HANDSHAKE but second datagram is lost - now += TICK - server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) - items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280, 1019]) - self.assertEqual(server.get_timer(), 0.25) - self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) - self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) - self.assertEqual(type(server.next_event()), events.ProtocolNegotiated) - self.assertIsNone(server.next_event()) - - # client only receives first datagram and sends ACKS - now += TICK - client.receive_datagram(items[0][0], SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [97]) - self.assertAlmostEqual(client.get_timer(), 0.325) - self.assertEqual(type(client.next_event()), events.ProtocolNegotiated) - self.assertIsNone(client.next_event()) - - # client PTO - HANDSHAKE PING - now = client.get_timer() - client.handle_timer(now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [45]) - self.assertAlmostEqual(client.get_timer(), 0.975) - - # server receives PING, discards INITIAL and sends ACK - now += TICK - server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) - items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [48]) - self.assertAlmostEqual(server.get_timer(), 0.25) - self.assertEqual(len(server._loss.spaces[0].sent_packets), 0) - self.assertEqual(len(server._loss.spaces[1].sent_packets), 3) - self.assertIsNone(server.next_event()) - - # ACKs are lost, server retransmits HANDSHAKE - now = server.get_timer() - server.handle_timer(now=now) - items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280, 876]) - self.assertAlmostEqual(server.get_timer(), 0.65) - self.assertEqual(len(server._loss.spaces[0].sent_packets), 0) - self.assertEqual(len(server._loss.spaces[1].sent_packets), 3) - self.assertIsNone(server.next_event()) - - # handshake continues normally - now += TICK - client.receive_datagram(items[0][0], SERVER_ADDR, now=now) - client.receive_datagram(items[1][0], SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [313]) - self.assertAlmostEqual(client.get_timer(), 0.95) - self.assertEqual(type(client.next_event()), events.HandshakeCompleted) - self.assertEqual(type(client.next_event()), events.ConnectionIdIssued) - - now += TICK - server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) - items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [229]) - self.assertAlmostEqual(server.get_timer(), 0.675) - self.assertEqual(type(server.next_event()), events.HandshakeCompleted) - self.assertEqual(type(server.next_event()), events.ConnectionIdIssued) - - now += TICK - client.receive_datagram(items[0][0], SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [32]) - self.assertAlmostEqual(client.get_timer(), 60.4) # idle timeout + with client_and_server(handshake=False) as (client, server): + # client sends INITIAL + now = 0.0 + client.connect(SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280, 1280]) + self.assertEqual(client.get_timer(), 0.2) + self.assertSentPackets(client, [2, 0, 0]) + self.assertEvents(client, []) + + # server receives INITIAL, sends ACK + INITIAL + HANDSHAKE but third + # datagram is lost + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + server.receive_datagram(items[1][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) + self.assertEqual(server.get_timer(), 0.25) + self.assertSentPackets(server, [2, 2, 0]) + self.assertEvents(server, [events.ProtocolNegotiated]) + + # client only receives the first datagram and sends ACKS + now += TICK + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + client.receive_datagram(items[1][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280]) + self.assertAlmostEqual(client.get_timer(), 0.325) + self.assertSentPackets(client, [0, 1, 0]) + self.assertEvents(client, [events.ProtocolNegotiated]) + + # client PTO - HANDSHAKE PING + now = client.get_timer() + client.handle_timer(now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [45]) + self.assertAlmostEqual(client.get_timer(), 0.975) + self.assertSentPackets(client, [0, 2, 0]) + self.assertEvents(client, []) + + # server receives PING, discards INITIAL and sends ACK + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [48]) + self.assertAlmostEqual(server.get_timer(), 0.25) + self.assertSentPackets(server, [0, 3, 0]) + self.assertEvents(server, []) + + # ACKs are lost, server retransmits HANDSHAKE + now = server.get_timer() + server.handle_timer(now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280, 890]) + self.assertAlmostEqual(server.get_timer(), 0.65) + self.assertSentPackets(server, [0, 3, 0]) + self.assertEvents(server, []) + + # handshake continues normally + now += TICK + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + client.receive_datagram(items[1][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [313]) + self.assertAlmostEqual(client.get_timer(), 0.95) + self.assertSentPackets(client, [0, 3, 1]) + self.assertEvents(client, HANDSHAKE_COMPLETED_EVENTS) + + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [229]) + self.assertAlmostEqual(server.get_timer(), 0.675) + self.assertSentPackets(server, [0, 0, 1]) + self.assertEvents(server, HANDSHAKE_COMPLETED_EVENTS) + + now += TICK + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [32]) + self.assertAlmostEqual(client.get_timer(), 60.4) # idle timeout + self.assertSentPackets(client, [0, 0, 1]) + self.assertEvents(client, []) def test_connect_with_loss_5(self): """ Check connection is established even in the server's HANDSHAKE_DONE is lost. """ - client_configuration = QuicConfiguration(is_client=True) - client_configuration.load_verify_locations(cafile=SERVER_CACERTFILE) - - client = QuicConnection(configuration=client_configuration) - client._ack_delay = 0 - - server_configuration = QuicConfiguration(is_client=False) - server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) - - server = QuicConnection( - configuration=server_configuration, - original_destination_connection_id=client.original_destination_connection_id, - ) - server._ack_delay = 0 - - # client sends INITIAL - now = 0.0 - client.connect(SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280]) - self.assertEqual(client.get_timer(), 0.2) - - # server receives INITIAL, sends INITIAL + HANDSHAKE - now += TICK - server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) - items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1280, 1019]) - self.assertEqual(server.get_timer(), 0.25) - self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) - self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) - self.assertEqual(type(server.next_event()), events.ProtocolNegotiated) - self.assertIsNone(server.next_event()) - - # client receives INITIAL + HANDSHAKE - now += TICK - client.receive_datagram(items[0][0], SERVER_ADDR, now=now) - client.receive_datagram(items[1][0], SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [360]) - self.assertAlmostEqual(client.get_timer(), 0.425) - self.assertEqual(type(client.next_event()), events.ProtocolNegotiated) - self.assertEqual(type(client.next_event()), events.HandshakeCompleted) - self.assertEqual(type(client.next_event()), events.ConnectionIdIssued) - - # server completes handshake, but HANDSHAKE_DONE is lost - now += TICK - server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) - items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [229]) - self.assertAlmostEqual(server.get_timer(), 0.425) - self.assertEqual(len(server._loss.spaces[0].sent_packets), 0) - self.assertEqual(len(server._loss.spaces[1].sent_packets), 0) - self.assertEqual(type(server.next_event()), events.HandshakeCompleted) - self.assertEqual(type(server.next_event()), events.ConnectionIdIssued) - - # server PTO - 1-RTT PING - now = server.get_timer() - server.handle_timer(now=now) - items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [29]) - self.assertAlmostEqual(server.get_timer(), 0.975) - - # client receives PING, sends ACK - now += TICK - client.receive_datagram(items[0][0], SERVER_ADDR, now=now) - items = client.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [32]) - self.assertAlmostEqual(client.get_timer(), 0.425) - - # server receives ACK, retransmits HANDSHAKE_DONE - now += TICK - self.assertFalse(server._handshake_done_pending) - server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) - self.assertTrue(server._handshake_done_pending) - items = server.datagrams_to_send(now=now) - self.assertFalse(server._handshake_done_pending) - self.assertEqual(datagram_sizes(items), [224]) + with client_and_server(handshake=False) as (client, server): + # client sends INITIAL + now = 0.0 + client.connect(SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280, 1280]) + self.assertEqual(client.get_timer(), 0.2) + + # server receives INITIAL, sends INITIAL + HANDSHAKE + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + server.receive_datagram(items[1][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) + self.assertEqual(server.get_timer(), 0.25) + self.assertSentPackets(server, [2, 2, 0]) + self.assertEvents(server, [events.ProtocolNegotiated]) + + # client receives INITIAL + HANDSHAKE + now += TICK + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + client.receive_datagram(items[1][0], SERVER_ADDR, now=now) + client.receive_datagram(items[2][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES) + self.assertAlmostEqual(client.get_timer(), 0.425) + self.assertSentPackets(client, [0, 1, 1]) + self.assertEvents( + client, [events.ProtocolNegotiated] + HANDSHAKE_COMPLETED_EVENTS + ) + + # server completes handshake, but HANDSHAKE_DONE is lost + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [229]) + self.assertAlmostEqual(server.get_timer(), 0.425) + self.assertSentPackets(server, [0, 0, 1]) + self.assertEvents(server, HANDSHAKE_COMPLETED_EVENTS) + + # server PTO - 1-RTT PING + now = server.get_timer() + server.handle_timer(now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [29]) + self.assertAlmostEqual(server.get_timer(), 0.975) + self.assertSentPackets(server, [0, 0, 2]) + self.assertEvents(server, []) + + # client receives PING, sends ACK + now += TICK + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [32]) + self.assertAlmostEqual(client.get_timer(), 0.425) + self.assertSentPackets(client, [0, 1, 2]) + self.assertEvents(client, []) + + # server receives ACK, retransmits HANDSHAKE_DONE + now += TICK + self.assertFalse(server._handshake_done_pending) + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + self.assertTrue(server._handshake_done_pending) + items = server.datagrams_to_send(now=now) + self.assertFalse(server._handshake_done_pending) + self.assertEqual(datagram_sizes(items), [224]) + self.assertAlmostEqual(server.get_timer(), 0.7625) + self.assertSentPackets(server, [0, 0, 1]) + # FIXME: the server re-emits the ConnectionIdIssued events + self.assertEvents(server, HANDSHAKE_COMPLETED_EVENTS[1:]) + + now += TICK + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [32]) + self.assertAlmostEqual(client.get_timer(), 0.425) + self.assertSentPackets(client, [0, 0, 3]) + self.assertEvents(client, []) + + now += TICK + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), []) + self.assertAlmostEqual(server.get_timer(), 60.625) # idle timeout + self.assertSentPackets(server, [0, 0, 0]) + self.assertEvents(server, []) def test_connect_with_no_transport_parameters(self): def patch(client): @@ -791,6 +893,78 @@ def patched_initialize(peer_cid: bytes): "No QUIC transport parameters received", ) + def test_connect_with_compatible_version_negotiation_1(self): + """ + The client only supports version 1. + + The server sets the Negotiated Version to version 1. + """ + with client_and_server( + client_options={ + "supported_versions": [QuicProtocolVersion.VERSION_1], + }, + ) as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server) + self.assertEqual(client._version, QuicProtocolVersion.VERSION_1) + self.assertEqual(server._version, QuicProtocolVersion.VERSION_1) + + def test_connect_with_compatible_version_negotiation_1_to_2(self): + """ + The client originally connects using version 1 but prefers version 2. + + The server sets the Negotiated Version to version 2. + """ + with client_and_server( + client_options={ + "original_version": QuicProtocolVersion.VERSION_1, + "supported_versions": [ + QuicProtocolVersion.VERSION_2, + QuicProtocolVersion.VERSION_1, + ], + }, + ) as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server) + self.assertEqual(client._version, QuicProtocolVersion.VERSION_2) + self.assertEqual(server._version, QuicProtocolVersion.VERSION_2) + + def test_connect_with_compatible_version_negotiation_2(self): + """ + The client only supports version 2. + + The server sets the Negotiated Version to version 2. + """ + with client_and_server( + client_options={ + "supported_versions": [QuicProtocolVersion.VERSION_2], + }, + ) as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server) + self.assertEqual(client._version, QuicProtocolVersion.VERSION_2) + self.assertEqual(server._version, QuicProtocolVersion.VERSION_2) + + def test_connect_with_compatible_version_negotiation_2_to_1(self): + """ + The client originally connects using version 2 but prefers version 1. + + The server sets the Negotiated Version to version 1. + """ + with client_and_server( + client_options={ + "original_version": QuicProtocolVersion.VERSION_2, + "supported_versions": [ + QuicProtocolVersion.VERSION_1, + QuicProtocolVersion.VERSION_2, + ], + }, + ) as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server) + self.assertEqual(client._version, QuicProtocolVersion.VERSION_1) + self.assertEqual(server._version, QuicProtocolVersion.VERSION_1) + def test_connect_with_quantum_readiness(self): with client_and_server(client_options={"quantum_readiness_test": True}) as ( client, @@ -834,7 +1008,7 @@ def save_session_ticket(ticket): stream_id = client.get_next_available_stream_id() client.send_stream_data(stream_id, b"hello") - self.assertEqual(roundtrip(client, server), (2, 1)) + self.assertEqual(roundtrip(client, server), (2, 2)) event = server.next_event() self.assertEqual(type(event), events.ProtocolNegotiated) @@ -1011,7 +1185,7 @@ def test_datagram_frame_2(self): client.send_datagram_frame(payload) # client can only 11 datagrams are sent due to congestion control - self.assertEqual(transfer(client, server), 11) + self.assertEqual(transfer(client, server), 12) for i in range(11): event = server.next_event() self.assertEqual(type(event), events.DatagramFrameReceived) @@ -1021,7 +1195,7 @@ def test_datagram_frame_2(self): self.assertEqual(transfer(server, client), 1) # client sends remaining datagrams - self.assertEqual(transfer(client, server), 9) + self.assertEqual(transfer(client, server), 8) for i in range(9): event = server.next_event() self.assertEqual(type(event), events.DatagramFrameReceived) @@ -1095,7 +1269,7 @@ def encrypt_packet(plain_header, plain_payload, packet_number): crypto.encrypt_packet = encrypt_packet - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) buf = builder.start_frame(QuicFrameType.PADDING) buf.push_bytes(bytes(builder.remaining_flight_space)) @@ -1124,7 +1298,7 @@ def test_receive_datagram_wrong_version(self): crypto.setup_initial( client._peer_cid.cid, is_client=False, version=client._version ) - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) buf = builder.start_frame(QuicFrameType.PADDING) buf.push_bytes(bytes(builder.remaining_flight_space)) @@ -1132,6 +1306,8 @@ def test_receive_datagram_wrong_version(self): client.receive_datagram(datagram, SERVER_ADDR, now=time.time()) self.assertEqual(drop(client), 0) + self.assertPacketDropped(client, "unsupported_version") + def test_receive_datagram_retry(self): client = create_standalone_client(self) @@ -1146,7 +1322,7 @@ def test_receive_datagram_retry(self): SERVER_ADDR, now=time.time(), ) - self.assertEqual(drop(client), 1) + self.assertEqual(drop(client), 2) def test_receive_datagram_retry_wrong_destination_cid(self): client = create_standalone_client(self) @@ -1163,6 +1339,7 @@ def test_receive_datagram_retry_wrong_destination_cid(self): now=time.time(), ) self.assertEqual(drop(client), 0) + self.assertPacketDropped(client, "unknown_connection_id") def test_receive_datagram_retry_wrong_integrity_tag(self): client = create_standalone_client(self) @@ -1484,13 +1661,7 @@ def test_handle_max_streams_uni_frame(self): def test_handle_new_connection_id_duplicate(self): with client_and_server() as (client, server): - buf = Buffer(capacity=100) - buf.push_uint_var(7) # sequence_number - buf.push_uint_var(0) # retire_prior_to - buf.push_uint_var(8) - buf.push_bytes(bytes(8)) - buf.push_bytes(bytes(16)) - buf.seek(0) + buf = new_connection_id(sequence_number=7) # client receives NEW_CONNECTION_ID client._handle_new_connection_id_frame( @@ -1506,13 +1677,7 @@ def test_handle_new_connection_id_duplicate(self): def test_handle_new_connection_id_over_limit(self): with client_and_server() as (client, server): - buf = Buffer(capacity=100) - buf.push_uint_var(8) # sequence_number - buf.push_uint_var(0) # retire_prior_to - buf.push_uint_var(8) - buf.push_bytes(bytes(8)) - buf.push_bytes(bytes(16)) - buf.seek(0) + buf = new_connection_id(sequence_number=8) # client receives NEW_CONNECTION_ID with self.assertRaises(QuicConnectionError) as cm: @@ -1531,13 +1696,7 @@ def test_handle_new_connection_id_over_limit(self): def test_handle_new_connection_id_with_retire_prior_to(self): with client_and_server() as (client, server): - buf = Buffer(capacity=42) - buf.push_uint_var(8) # sequence_number - buf.push_uint_var(2) # retire_prior_to - buf.push_uint_var(8) - buf.push_bytes(bytes(8)) - buf.push_bytes(bytes(16)) - buf.seek(0) + buf = new_connection_id(sequence_number=8, retire_prior_to=2, capacity=42) # client receives NEW_CONNECTION_ID client._handle_new_connection_id_frame( @@ -1551,15 +1710,67 @@ def test_handle_new_connection_id_with_retire_prior_to(self): sequence_numbers(client._peer_cid_available), [3, 4, 5, 6, 7, 8] ) + def test_handle_new_connection_id_with_retire_prior_to_lower(self): + with client_and_server() as (client, server): + buf = new_connection_id(sequence_number=80, retire_prior_to=80) + # client receives NEW_CONNECTION_ID + client._handle_new_connection_id_frame( + client_receive_context(client), + QuicFrameType.NEW_CONNECTION_ID, + buf, + ) + self.assertEqual(client._peer_cid.sequence_number, 80) + self.assertEqual(sequence_numbers(client._peer_cid_available), []) + buf = new_connection_id(sequence_number=30, retire_prior_to=30) + # client receives NEW_CONNECTION_ID + client._handle_new_connection_id_frame( + client_receive_context(client), + QuicFrameType.NEW_CONNECTION_ID, + buf, + ) + self.assertEqual(client._peer_cid.sequence_number, 80) + self.assertEqual(sequence_numbers(client._peer_cid_available), []) + + def test_handle_excessive_new_connection_id_retires(self): + with client_and_server() as (client, server): + for i in range(25): + sequence_number = 8 + i + buf = new_connection_id( + sequence_number=sequence_number, retire_prior_to=sequence_number + ) + # client receives NEW_CONNECTION_ID + client._handle_new_connection_id_frame( + client_receive_context(client), + QuicFrameType.NEW_CONNECTION_ID, + buf, + ) + # So far, so good! We should be at the (default) limit of 4*8 pending + # retirements. + self.assertEqual(len(client._retire_connection_ids), 32) + # Now we will go one too many! + sequence_number = 8 + 25 + buf = new_connection_id( + sequence_number=sequence_number, retire_prior_to=sequence_number + ) + with self.assertRaises(QuicConnectionError) as cm: + client._handle_new_connection_id_frame( + client_receive_context(client), + QuicFrameType.NEW_CONNECTION_ID, + buf, + ) + self.assertEqual( + cm.exception.error_code, QuicErrorCode.CONNECTION_ID_LIMIT_ERROR + ) + self.assertEqual(cm.exception.frame_type, QuicFrameType.NEW_CONNECTION_ID) + self.assertEqual( + cm.exception.reason_phrase, "Too many pending retired connection IDs" + ) + def test_handle_new_connection_id_with_connection_id_invalid(self): with client_and_server() as (client, server): - buf = Buffer(capacity=100) - buf.push_uint_var(8) # sequence_number - buf.push_uint_var(2) # retire_prior_to - buf.push_uint_var(21) - buf.push_bytes(bytes(21)) - buf.push_bytes(bytes(16)) - buf.seek(0) + buf = new_connection_id( + sequence_number=8, retire_prior_to=2, connection_id=bytes(21) + ) # client receives NEW_CONNECTION_ID with self.assertRaises(QuicConnectionError) as cm: @@ -1580,13 +1791,7 @@ def test_handle_new_connection_id_with_connection_id_invalid(self): def test_handle_new_connection_id_with_retire_prior_to_invalid(self): with client_and_server() as (client, server): - buf = Buffer(capacity=100) - buf.push_uint_var(8) # sequence_number - buf.push_uint_var(9) # retire_prior_to - buf.push_uint_var(8) - buf.push_bytes(bytes(8)) - buf.push_bytes(bytes(16)) - buf.seek(0) + buf = new_connection_id(sequence_number=8, retire_prior_to=9) # client receives NEW_CONNECTION_ID with self.assertRaises(QuicConnectionError) as cm: @@ -1655,6 +1860,44 @@ def test_handle_path_challenge_frame(self): self.assertEqual(server._network_paths[1].addr, ("1.2.3.4", 1234)) self.assertTrue(server._network_paths[1].is_validated) + def test_handle_path_challenge_response_on_different_path(self): + with client_and_server() as (client, server): + # client changes address and sends some data + client.send_stream_data(0, b"01234567") + for data, addr in client.datagrams_to_send(now=time.time()): + server.receive_datagram(data, ("1.2.3.4", 2345), now=time.time()) + # check paths + self.assertEqual(len(server._network_paths), 2) + self.assertEqual(server._network_paths[0].addr, ("1.2.3.4", 2345)) + self.assertFalse(server._network_paths[0].is_validated) + self.assertEqual(server._network_paths[1].addr, ("1.2.3.4", 1234)) + self.assertTrue(server._network_paths[1].is_validated) + # server sends PATH_CHALLENGE and receives PATH_RESPONSE on the 1234 + # path instead of the expected 2345 path. + for data, addr in server.datagrams_to_send(now=time.time()): + client.receive_datagram(data, SERVER_ADDR, now=time.time()) + for data, addr in client.datagrams_to_send(now=time.time()): + server.receive_datagram(data, ("1.2.3.4", 1234), now=time.time()) + # check paths; note that the order is backwards from the prior test + # as receiving on 1234 promotes it to first in the list + self.assertEqual(server._network_paths[0].addr, ("1.2.3.4", 1234)) + self.assertTrue(server._network_paths[0].is_validated) + self.assertEqual(server._network_paths[1].addr, ("1.2.3.4", 2345)) + self.assertTrue(server._network_paths[1].is_validated) + + def test_local_path_challenges_are_bounded(self): + with client_and_server() as (client, server): + for i in range(MAX_LOCAL_CHALLENGES + 2): + server._add_local_challenge( + int.to_bytes(i, 8, "big"), QuicNetworkPath(f"1.2.3.{i}") + ) + self.assertEqual(len(server._local_challenges), MAX_LOCAL_CHALLENGES) + for i in range(2, MAX_LOCAL_CHALLENGES + 2): + self.assertEqual( + server._local_challenges[int.to_bytes(i, 8, "big")].addr, + f"1.2.3.{i}", + ) + def test_handle_path_response_frame_bad(self): with client_and_server() as (client, server): # server receives unsolicited PATH_RESPONSE @@ -2255,33 +2498,52 @@ def test_parse_transport_parameters_with_bad_initial_source_connection_id(self): cm.exception.reason_phrase, "initial_source_connection_id does not match" ) - def test_parse_transport_parameters_with_server_only_parameter(self): - server_configuration = QuicConfiguration( - is_client=False, quic_logger=QuicLogger() + def test_parse_transport_parameters_with_bad_version_information_1(self): + server = create_standalone_server(self) + data = encode_transport_parameters( + QuicTransportParameters( + version_information=QuicVersionInformation( + chosen_version=QuicProtocolVersion.VERSION_1, + available_versions=[QuicProtocolVersion.VERSION_2], + ) + ) ) - server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) - - server = QuicConnection( - configuration=server_configuration, - original_destination_connection_id=bytes(8), + with self.assertRaises(QuicConnectionError) as cm: + server._parse_transport_parameters(data) + self.assertEqual( + cm.exception.error_code, QuicErrorCode.TRANSPORT_PARAMETER_ERROR ) - for active_connection_id_limit in [0, 1]: - data = encode_transport_parameters( - QuicTransportParameters( - active_connection_id_limit=active_connection_id_limit, - original_destination_connection_id=bytes(8), + self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) + self.assertEqual( + cm.exception.reason_phrase, + "version_information's chosen_version is not included in " + "available_versions", + ) + + def test_parse_transport_parameters_with_bad_version_information_2(self): + server = create_standalone_server(self) + data = encode_transport_parameters( + QuicTransportParameters( + version_information=QuicVersionInformation( + chosen_version=QuicProtocolVersion.VERSION_1, + available_versions=[ + QuicProtocolVersion.VERSION_1, + QuicProtocolVersion.VERSION_2, + ], ) ) - with self.assertRaises(QuicConnectionError) as cm: - server._parse_transport_parameters(data) - self.assertEqual( - cm.exception.error_code, QuicErrorCode.TRANSPORT_PARAMETER_ERROR - ) - self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) - self.assertEqual( - cm.exception.reason_phrase, - "original_destination_connection_id is not allowed for clients", - ) + ) + server._crypto_packet_version = QuicProtocolVersion.VERSION_2 + with self.assertRaises(QuicConnectionError) as cm: + server._parse_transport_parameters(data) + self.assertEqual( + cm.exception.error_code, QuicErrorCode.VERSION_NEGOTIATION_ERROR + ) + self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) + self.assertEqual( + cm.exception.reason_phrase, + "version_information's chosen_version does not match the version in use", + ) def test_payload_received_empty(self): with client_and_server() as (client, server): @@ -2301,12 +2563,23 @@ def test_payload_received_padding_only(self): self.assertFalse(is_ack_eliciting) self.assertTrue(is_probing) + def test_payload_received_malformed_frame_type(self): + with client_and_server() as (client, server): + # client receives a malformed frame type + with self.assertRaises(QuicConnectionError) as cm: + client._payload_received(client_receive_context(client), b"\xff") + self.assertEqual( + cm.exception.error_code, QuicErrorCode.FRAME_ENCODING_ERROR + ) + self.assertEqual(cm.exception.frame_type, None) + self.assertEqual(cm.exception.reason_phrase, "Malformed frame type") + def test_payload_received_unknown_frame(self): with client_and_server() as (client, server): # client receives unknown frame with self.assertRaises(QuicConnectionError) as cm: client._payload_received(client_receive_context(client), b"\x1f") - self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) + self.assertEqual(cm.exception.error_code, QuicErrorCode.FRAME_ENCODING_ERROR) self.assertEqual(cm.exception.frame_type, 0x1F) self.assertEqual(cm.exception.reason_phrase, "Unknown frame type") @@ -2338,14 +2611,15 @@ def test_send_max_data_blocked_by_cc(self): with client_and_server() as (client, server): # check congestion control self.assertEqual(client._loss.bytes_in_flight, 0) - self.assertEqual(client._loss.congestion_window, 14303) + self.assertGreaterEqual(client._loss.congestion_window, 13530) + self.assertLessEqual(client._loss.congestion_window, 16000) # artificially raise received data counter client._local_max_data_used = client._local_max_data self.assertEqual(server._remote_max_data, 1048576) # artificially raise bytes in flight - client._loss._cc.bytes_in_flight = 14303 + client._loss._cc.bytes_in_flight = client._loss.congestion_window # MAX_DATA is not sent due to congestion control self.assertEqual(drop(client), 0) @@ -2702,6 +2976,21 @@ def test_version_negotiation_ignore(self): ) self.assertEqual(drop(client), 0) + def test_version_negotiation_ignore_server(self): + server = create_standalone_server(self) + + # Servers do not expect version negotiation packets. + server.receive_datagram( + encode_quic_version_negotiation( + source_cid=server._peer_cid.cid, + destination_cid=server.host_cid, + supported_versions=[QuicProtocolVersion.VERSION_1], + ), + CLIENT_ADDR, + now=time.time(), + ) + self.assertPacketDropped(server, "unexpected_packet") + def test_version_negotiation_ok(self): client = create_standalone_client(self) @@ -2728,7 +3017,7 @@ def test_write_connection_close_early(self): ) crypto = CryptoPair() crypto.setup_initial(client.host_cid, is_client=True, version=client._version) - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) client._write_connection_close_frame( builder=builder, epoch=tls.Epoch.INITIAL, @@ -2751,6 +3040,34 @@ def test_write_connection_close_early(self): ], ) + def test_excessive_crypto_buffering(self): + with client_and_server() as (client, server): + # Client receives data that causes more than 512K of buffering; note that + # because the stream buffer is a single buffer and not a set of fragments, + # the total buffering size depends not on how much data is received, but + # how much buffering is needed. We send fragments of only 100 bytes + # at offsets 10000, 20000, 30000 etc. + highest_good_offset = 0 + with self.assertRaises(QuicConnectionError) as cm: + # We don't start at zero as we want to force buffering, not cause + # a TLS error. + for offset in range(10000, 1000000, 10000): + client._handle_crypto_frame( + client_receive_context(client), + QuicFrameType.CRYPTO, + Buffer( + data=encode_uint_var(offset) + + encode_uint_var(100) + + b"\x00" * 100 + ), + ) + highest_good_offset = offset + self.assertEqual( + cm.exception.error_code, QuicErrorCode.CRYPTO_BUFFER_EXCEEDED + ) + self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) + self.assertEqual(highest_good_offset, (MAX_PENDING_CRYPTO // 10000) * 10000) + class QuicNetworkPathTest(TestCase): def test_can_send(self): diff --git a/tests/test_crypto.py b/tests/test_crypto_v1.py similarity index 96% rename from tests/test_crypto.py rename to tests/test_crypto_v1.py index cfacbab12..c130a4916 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto_v1.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import binascii from unittest import TestCase, skipIf @@ -152,7 +154,11 @@ def test_derive_key_iv_hp(self): secret = binascii.unhexlify( "c00cf151ca5be075ed0ebfb5c80323c42d6b7db67881289af4008f1f6c357aea" ) - key, iv, hp = derive_key_iv_hp(INITIAL_CIPHER_SUITE, secret) + key, iv, hp = derive_key_iv_hp( + cipher_suite=INITIAL_CIPHER_SUITE, + secret=secret, + version=PROTOCOL_VERSION, + ) self.assertEqual(key, binascii.unhexlify("1f369613dd76d5467730efcbe3b1a22d")) self.assertEqual(iv, binascii.unhexlify("fa044b2f42a3fd3b46fb255c")) self.assertEqual(hp, binascii.unhexlify("9f50449e04a0e810283a1e9933adedd2")) @@ -161,7 +167,11 @@ def test_derive_key_iv_hp(self): secret = binascii.unhexlify( "3c199828fd139efd216c155ad844cc81fb82fa8d7446fa7d78be803acdda951b" ) - key, iv, hp = derive_key_iv_hp(INITIAL_CIPHER_SUITE, secret) + key, iv, hp = derive_key_iv_hp( + cipher_suite=INITIAL_CIPHER_SUITE, + secret=secret, + version=PROTOCOL_VERSION, + ) self.assertEqual(key, binascii.unhexlify("cf3a5331653c364c88f0f379b6067e37")) self.assertEqual(iv, binascii.unhexlify("0ac1493ca1905853b0bba03e")) self.assertEqual(hp, binascii.unhexlify("c206b8d9b9f0f37644430b490eeaa314")) @@ -174,7 +184,11 @@ def test_derive_key_iv_hp_chacha20(self): secret = binascii.unhexlify( "9ac312a7f877468ebe69422748ad00a15443f18203a07d6060f688f30f21632b" ) - key, iv, hp = derive_key_iv_hp(CipherSuite.CHACHA20_POLY1305_SHA256, secret) + key, iv, hp = derive_key_iv_hp( + cipher_suite=CipherSuite.CHACHA20_POLY1305_SHA256, + secret=secret, + version=PROTOCOL_VERSION, + ) self.assertEqual( key, binascii.unhexlify( diff --git a/tests/test_crypto_v2.py b/tests/test_crypto_v2.py new file mode 100644 index 000000000..7a9b16240 --- /dev/null +++ b/tests/test_crypto_v2.py @@ -0,0 +1,358 @@ +from __future__ import annotations + +import binascii +from unittest import TestCase, skipIf + +from qh3.buffer import Buffer +from qh3.quic.crypto import ( + INITIAL_CIPHER_SUITE, + CryptoError, + CryptoPair, + derive_key_iv_hp, +) +from qh3.quic.packet import PACKET_FIXED_BIT, QuicProtocolVersion +from qh3.tls import CipherSuite + +from .utils import SKIP_TESTS +PROTOCOL_VERSION = QuicProtocolVersion.VERSION_2 + +# https://datatracker.ietf.org/doc/html/rfc9369#appendix-A.5 +CHACHA20_CLIENT_PACKET_NUMBER = 654360564 +CHACHA20_CLIENT_PLAIN_HEADER = binascii.unhexlify("4200bff4") +CHACHA20_CLIENT_PLAIN_PAYLOAD = binascii.unhexlify("01") +CHACHA20_CLIENT_ENCRYPTED_PACKET = binascii.unhexlify( + "5558b1c60ae7b6b932bc27d786f4bc2bb20f2162ba" +) + +# https://datatracker.ietf.org/doc/html/rfc9369#appendix-A.2 +LONG_CLIENT_PACKET_NUMBER = 2 +LONG_CLIENT_PLAIN_HEADER = binascii.unhexlify( + "d36b3343cf088394c8f03e5157080000449e00000002" +) +LONG_CLIENT_PLAIN_PAYLOAD = binascii.unhexlify( + "060040f1010000ed0303ebf8fa56f12939b9584a3896472ec40bb863cfd3e868" + "04fe3a47f06a2b69484c00000413011302010000c000000010000e00000b6578" + "616d706c652e636f6dff01000100000a00080006001d00170018001000070005" + "04616c706e000500050100000000003300260024001d00209370b2c9caa47fba" + "baf4559fedba753de171fa71f50f1ce15d43e994ec74d748002b000302030400" + "0d0010000e0403050306030203080408050806002d00020101001c0002400100" + "3900320408ffffffffffffffff05048000ffff07048000ffff08011001048000" + "75300901100f088394c8f03e51570806048000ffff" +) + bytes(917) +LONG_CLIENT_ENCRYPTED_PACKET = binascii.unhexlify( + "d76b3343cf088394c8f03e5157080000449ea0c95e82ffe67b6abcdb4298b485" + "dd04de806071bf03dceebfa162e75d6c96058bdbfb127cdfcbf903388e99ad04" + "9f9a3dd4425ae4d0992cfff18ecf0fdb5a842d09747052f17ac2053d21f57c5d" + "250f2c4f0e0202b70785b7946e992e58a59ac52dea6774d4f03b55545243cf1a" + "12834e3f249a78d395e0d18f4d766004f1a2674802a747eaa901c3f10cda5500" + "cb9122faa9f1df66c392079a1b40f0de1c6054196a11cbea40afb6ef5253cd68" + "18f6625efce3b6def6ba7e4b37a40f7732e093daa7d52190935b8da58976ff33" + "12ae50b187c1433c0f028edcc4c2838b6a9bfc226ca4b4530e7a4ccee1bfa2a3" + "d396ae5a3fb512384b2fdd851f784a65e03f2c4fbe11a53c7777c023462239dd" + "6f7521a3f6c7d5dd3ec9b3f233773d4b46d23cc375eb198c63301c21801f6520" + "bcfb7966fc49b393f0061d974a2706df8c4a9449f11d7f3d2dcbb90c6b877045" + "636e7c0c0fe4eb0f697545460c806910d2c355f1d253bc9d2452aaa549e27a1f" + "ac7cf4ed77f322e8fa894b6a83810a34b361901751a6f5eb65a0326e07de7c12" + "16ccce2d0193f958bb3850a833f7ae432b65bc5a53975c155aa4bcb4f7b2c4e5" + "4df16efaf6ddea94e2c50b4cd1dfe06017e0e9d02900cffe1935e0491d77ffb4" + "fdf85290fdd893d577b1131a610ef6a5c32b2ee0293617a37cbb08b847741c3b" + "8017c25ca9052ca1079d8b78aebd47876d330a30f6a8c6d61dd1ab5589329de7" + "14d19d61370f8149748c72f132f0fc99f34d766c6938597040d8f9e2bb522ff9" + "9c63a344d6a2ae8aa8e51b7b90a4a806105fcbca31506c446151adfeceb51b91" + "abfe43960977c87471cf9ad4074d30e10d6a7f03c63bd5d4317f68ff325ba3bd" + "80bf4dc8b52a0ba031758022eb025cdd770b44d6d6cf0670f4e990b22347a7db" + "848265e3e5eb72dfe8299ad7481a408322cac55786e52f633b2fb6b614eaed18" + "d703dd84045a274ae8bfa73379661388d6991fe39b0d93debb41700b41f90a15" + "c4d526250235ddcd6776fc77bc97e7a417ebcb31600d01e57f32162a8560cacc" + "7e27a096d37a1a86952ec71bd89a3e9a30a2a26162984d7740f81193e8238e61" + "f6b5b984d4d3dfa033c1bb7e4f0037febf406d91c0dccf32acf423cfa1e70710" + "10d3f270121b493ce85054ef58bada42310138fe081adb04e2bd901f2f13458b" + "3d6758158197107c14ebb193230cd1157380aa79cae1374a7c1e5bbcb80ee23e" + "06ebfde206bfb0fcbc0edc4ebec309661bdd908d532eb0c6adc38b7ca7331dce" + "8dfce39ab71e7c32d318d136b6100671a1ae6a6600e3899f31f0eed19e3417d1" + "34b90c9058f8632c798d4490da4987307cba922d61c39805d072b589bd52fdf1" + "e86215c2d54e6670e07383a27bbffb5addf47d66aa85a0c6f9f32e59d85a44dd" + "5d3b22dc2be80919b490437ae4f36a0ae55edf1d0b5cb4e9a3ecabee93dfc6e3" + "8d209d0fa6536d27a5d6fbb17641cde27525d61093f1b28072d111b2b4ae5f89" + "d5974ee12e5cf7d5da4d6a31123041f33e61407e76cffcdcfd7e19ba58cf4b53" + "6f4c4938ae79324dc402894b44faf8afbab35282ab659d13c93f70412e85cb19" + "9a37ddec600545473cfb5a05e08d0b209973b2172b4d21fb69745a262ccde96b" + "a18b2faa745b6fe189cf772a9f84cbfc" +) + +# https://datatracker.ietf.org/doc/html/rfc9369#appendix-A.3 +LONG_SERVER_PACKET_NUMBER = 1 +LONG_SERVER_PLAIN_HEADER = binascii.unhexlify( + "d16b3343cf0008f067a5502a4262b50040750001" +) +LONG_SERVER_PLAIN_PAYLOAD = binascii.unhexlify( + "02000000000600405a020000560303eefce7f7b37ba1d1632e96677825ddf739" + "88cfc79825df566dc5430b9a045a1200130100002e00330024001d00209d3c94" + "0d89690b84d08a60993c144eca684d1081287c834d5311bcf32bb9da1a002b00" + "020304" +) +LONG_SERVER_ENCRYPTED_PACKET = binascii.unhexlify( + "dc6b3343cf0008f067a5502a4262b5004075d92faaf16f05d8a4398c47089698" + "baeea26b91eb761d9b89237bbf87263017915358230035f7fd3945d88965cf17" + "f9af6e16886c61bfc703106fbaf3cb4cfa52382dd16a393e42757507698075b2" + "c984c707f0a0812d8cd5a6881eaf21ceda98f4bd23f6fe1a3e2c43edd9ce7ca8" + "4bed8521e2e140" +) + +SHORT_SERVER_PACKET_NUMBER = 3 +SHORT_SERVER_PLAIN_HEADER = binascii.unhexlify("41b01fd24a586a9cf30003") +SHORT_SERVER_PLAIN_PAYLOAD = binascii.unhexlify( + "06003904000035000151805a4bebf5000020b098c8dc4183e4c182572e10ac3e" + "2b88897e0524c8461847548bd2dffa2c0ae60008002a0004ffffffff" +) +SHORT_SERVER_ENCRYPTED_PACKET = binascii.unhexlify( + "59b01fd24a586a9cf3be262d3eb9b42ada03644d223dae08cbffd5bddab1cf02" + "c33711d0cf5cdc785ce55a4d95c6a82e117ba937080ac6d063915f8c4ee28bd3" + "d86949197c48e8550aa32612f9af806a6c20d6d10ed08f" +) + + +class CryptoTest(TestCase): + """ + Test vectors from: + + https://datatracker.ietf.org/doc/html/rfc9001#appendix-A + """ + + def create_crypto(self, is_client): + pair = CryptoPair() + pair.setup_initial( + cid=binascii.unhexlify("8394c8f03e515708"), + is_client=is_client, + version=PROTOCOL_VERSION, + ) + return pair + + def test_derive_key_iv_hp(self): + # https://datatracker.ietf.org/doc/html/rfc9369#appendix-A.1 + + # client + secret = binascii.unhexlify( + "14ec9d6eb9fd7af83bf5a668bc17a7e283766aade7ecd0891f70f9ff7f4bf47b" + ) + key, iv, hp = derive_key_iv_hp( + cipher_suite=INITIAL_CIPHER_SUITE, + secret=secret, + version=PROTOCOL_VERSION, + ) + self.assertEqual(key, binascii.unhexlify("8b1a0bc121284290a29e0971b5cd045d")) + self.assertEqual(iv, binascii.unhexlify("91f73e2351d8fa91660e909f")) + self.assertEqual(hp, binascii.unhexlify("45b95e15235d6f45a6b19cbcb0294ba9")) + + # server + secret = binascii.unhexlify( + "0263db1782731bf4588e7e4d93b7463907cb8cd8200b5da55a8bd488eafc37c1" + ) + key, iv, hp = derive_key_iv_hp( + cipher_suite=INITIAL_CIPHER_SUITE, + secret=secret, + version=PROTOCOL_VERSION, + ) + self.assertEqual(key, binascii.unhexlify("82db637861d55e1d011f19ea71d5d2a7")) + self.assertEqual(iv, binascii.unhexlify("dd13c276499c0249d3310652")) + self.assertEqual(hp, binascii.unhexlify("edf6d05c83121201b436e16877593c3a")) + + @skipIf("chacha20" in SKIP_TESTS, "Skipping chacha20 tests") + def test_derive_key_iv_hp_chacha20(self): + # https://datatracker.ietf.org/doc/html/rfc9369#appendix-A.5 + + # server + secret = binascii.unhexlify( + "9ac312a7f877468ebe69422748ad00a15443f18203a07d6060f688f30f21632b" + ) + key, iv, hp = derive_key_iv_hp( + cipher_suite=CipherSuite.CHACHA20_POLY1305_SHA256, + secret=secret, + version=PROTOCOL_VERSION, + ) + self.assertEqual( + key, + binascii.unhexlify( + "3bfcddd72bcf02541d7fa0dd1f5f9eeea817e09a6963a0e6c7df0f9a1bab90f2" + ), + ) + self.assertEqual(iv, binascii.unhexlify("a6b5bc6ab7dafce30ffff5dd")) + self.assertEqual( + hp, + binascii.unhexlify( + "d659760d2ba434a226fd37b35c69e2da8211d10c4f12538787d65645d5d1b8e2" + ), + ) + + @skipIf("chacha20" in SKIP_TESTS, "Skipping chacha20 tests") + def test_decrypt_chacha20(self): + pair = CryptoPair() + pair.recv.setup( + cipher_suite=CipherSuite.CHACHA20_POLY1305_SHA256, + secret=binascii.unhexlify( + "9ac312a7f877468ebe69422748ad00a15443f18203a07d6060f688f30f21632b" + ), + version=PROTOCOL_VERSION, + ) + + plain_header, plain_payload, packet_number = pair.decrypt_packet( + CHACHA20_CLIENT_ENCRYPTED_PACKET, 1, CHACHA20_CLIENT_PACKET_NUMBER + ) + self.assertEqual(plain_header, CHACHA20_CLIENT_PLAIN_HEADER) + self.assertEqual(plain_payload, CHACHA20_CLIENT_PLAIN_PAYLOAD) + self.assertEqual(packet_number, CHACHA20_CLIENT_PACKET_NUMBER) + + def test_decrypt_long_client(self): + pair = self.create_crypto(is_client=False) + + plain_header, plain_payload, packet_number = pair.decrypt_packet( + LONG_CLIENT_ENCRYPTED_PACKET, 18, 0 + ) + self.assertEqual(plain_header, LONG_CLIENT_PLAIN_HEADER) + self.assertEqual(plain_payload, LONG_CLIENT_PLAIN_PAYLOAD) + self.assertEqual(packet_number, LONG_CLIENT_PACKET_NUMBER) + + def test_decrypt_long_server(self): + pair = self.create_crypto(is_client=True) + + plain_header, plain_payload, packet_number = pair.decrypt_packet( + LONG_SERVER_ENCRYPTED_PACKET, 18, 0 + ) + self.assertEqual(plain_header, LONG_SERVER_PLAIN_HEADER) + self.assertEqual(plain_payload, LONG_SERVER_PLAIN_PAYLOAD) + self.assertEqual(packet_number, LONG_SERVER_PACKET_NUMBER) + + def test_decrypt_no_key(self): + pair = CryptoPair() + with self.assertRaises(CryptoError): + pair.decrypt_packet(LONG_SERVER_ENCRYPTED_PACKET, 18, 0) + + def test_decrypt_short_server(self): + pair = CryptoPair() + pair.recv.setup( + cipher_suite=INITIAL_CIPHER_SUITE, + secret=binascii.unhexlify( + "310281977cb8c1c1c1212d784b2d29e5a6489e23de848d370a5a2f9537f3a100" + ), + version=PROTOCOL_VERSION, + ) + + plain_header, plain_payload, packet_number = pair.decrypt_packet( + SHORT_SERVER_ENCRYPTED_PACKET, 9, 0 + ) + self.assertEqual(plain_header, SHORT_SERVER_PLAIN_HEADER) + self.assertEqual(plain_payload, SHORT_SERVER_PLAIN_PAYLOAD) + self.assertEqual(packet_number, SHORT_SERVER_PACKET_NUMBER) + + @skipIf("chacha20" in SKIP_TESTS, "Skipping chacha20 tests") + def test_encrypt_chacha20(self): + pair = CryptoPair() + pair.send.setup( + cipher_suite=CipherSuite.CHACHA20_POLY1305_SHA256, + secret=binascii.unhexlify( + "9ac312a7f877468ebe69422748ad00a15443f18203a07d6060f688f30f21632b" + ), + version=PROTOCOL_VERSION, + ) + + packet = pair.encrypt_packet( + CHACHA20_CLIENT_PLAIN_HEADER, + CHACHA20_CLIENT_PLAIN_PAYLOAD, + CHACHA20_CLIENT_PACKET_NUMBER, + ) + self.assertEqual(packet, CHACHA20_CLIENT_ENCRYPTED_PACKET) + + def test_encrypt_long_client(self): + pair = self.create_crypto(is_client=True) + + packet = pair.encrypt_packet( + LONG_CLIENT_PLAIN_HEADER, + LONG_CLIENT_PLAIN_PAYLOAD, + LONG_CLIENT_PACKET_NUMBER, + ) + self.assertEqual(packet, LONG_CLIENT_ENCRYPTED_PACKET) + + def test_encrypt_long_server(self): + pair = self.create_crypto(is_client=False) + + packet = pair.encrypt_packet( + LONG_SERVER_PLAIN_HEADER, + LONG_SERVER_PLAIN_PAYLOAD, + LONG_SERVER_PACKET_NUMBER, + ) + self.assertEqual(packet, LONG_SERVER_ENCRYPTED_PACKET) + + def test_encrypt_short_server(self): + pair = CryptoPair() + pair.send.setup( + cipher_suite=INITIAL_CIPHER_SUITE, + secret=binascii.unhexlify( + "310281977cb8c1c1c1212d784b2d29e5a6489e23de848d370a5a2f9537f3a100" + ), + version=PROTOCOL_VERSION, + ) + + packet = pair.encrypt_packet( + SHORT_SERVER_PLAIN_HEADER, + SHORT_SERVER_PLAIN_PAYLOAD, + SHORT_SERVER_PACKET_NUMBER, + ) + self.assertEqual(packet, SHORT_SERVER_ENCRYPTED_PACKET) + + def test_key_update(self): + pair1 = self.create_crypto(is_client=True) + pair2 = self.create_crypto(is_client=False) + + def create_packet(key_phase, packet_number): + buf = Buffer(capacity=100) + buf.push_uint8(PACKET_FIXED_BIT | key_phase << 2 | 1) + buf.push_bytes(binascii.unhexlify("8394c8f03e515708")) + buf.push_uint16(packet_number) + return buf.data, b"\x00\x01\x02\x03" + + def send(sender, receiver, packet_number=0): + plain_header, plain_payload = create_packet( + key_phase=sender.key_phase, packet_number=packet_number + ) + encrypted = sender.encrypt_packet( + plain_header, plain_payload, packet_number + ) + recov_header, recov_payload, recov_packet_number = receiver.decrypt_packet( + encrypted, len(plain_header) - 2, 0 + ) + self.assertEqual(recov_header, plain_header) + self.assertEqual(recov_payload, plain_payload) + self.assertEqual(recov_packet_number, packet_number) + + # roundtrip + send(pair1, pair2, 0) + send(pair2, pair1, 0) + self.assertEqual(pair1.key_phase, 0) + self.assertEqual(pair2.key_phase, 0) + + # pair 1 key update + pair1.update_key() + + # roundtrip + send(pair1, pair2, 1) + send(pair2, pair1, 1) + self.assertEqual(pair1.key_phase, 1) + self.assertEqual(pair2.key_phase, 1) + + # pair 2 key update + pair2.update_key() + + # roundtrip + send(pair2, pair1, 2) + send(pair1, pair2, 2) + self.assertEqual(pair1.key_phase, 0) + self.assertEqual(pair2.key_phase, 0) + + # pair 1 key - update, but not next to send + pair1.update_key() + + # roundtrip + send(pair2, pair1, 3) + send(pair1, pair2, 3) + self.assertEqual(pair1.key_phase, 1) + self.assertEqual(pair2.key_phase, 1) diff --git a/tests/test_h3.py b/tests/test_h3.py index 4087a6d3b..53558b42b 100644 --- a/tests/test_h3.py +++ b/tests/test_h3.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import binascii import contextlib import copy diff --git a/tests/test_logger.py b/tests/test_logger.py index b103ec579..075f2530b 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import os import tempfile diff --git a/tests/test_packet.py b/tests/test_packet.py index 522c8e963..605d8d6d6 100644 --- a/tests/test_packet.py +++ b/tests/test_packet.py @@ -1,11 +1,12 @@ +from __future__ import annotations + import binascii from unittest import TestCase from qh3.buffer import Buffer, BufferReadError from qh3.quic import packet from qh3.quic.packet import ( - PACKET_TYPE_INITIAL, - PACKET_TYPE_RETRY, + QuicPacketType, QuicPreferredAddress, QuicProtocolVersion, QuicTransportParameters, @@ -20,7 +21,10 @@ push_quic_transport_parameters, ) -from .utils import load +from .test_crypto_v1 import LONG_CLIENT_ENCRYPTED_PACKET as CLIENT_INITIAL_V1 +from .test_crypto_v1 import LONG_SERVER_ENCRYPTED_PACKET as SERVER_INITIAL_V1 +from .test_crypto_v2 import LONG_CLIENT_ENCRYPTED_PACKET as CLIENT_INITIAL_V2 +from .test_crypto_v2 import LONG_SERVER_ENCRYPTED_PACKET as SERVER_INITIAL_V2 class PacketTest(TestCase): @@ -51,76 +55,129 @@ def test_pull_empty(self): with self.assertRaises(BufferReadError): pull_quic_header(buf, host_cid_length=8) - def test_pull_initial_client(self): - buf = Buffer(data=load("initial_client.bin")) + def test_pull_initial_client_v1(self): + buf = Buffer(data=CLIENT_INITIAL_V1) header = pull_quic_header(buf, host_cid_length=8) - self.assertTrue(header.is_long_header) self.assertEqual(header.version, QuicProtocolVersion.VERSION_1) - self.assertEqual(header.packet_type, PACKET_TYPE_INITIAL) - self.assertEqual(header.destination_cid, binascii.unhexlify("858b39368b8e3c6e")) + self.assertEqual(header.packet_type, QuicPacketType.INITIAL) + self.assertEqual(header.packet_length, 1200) + self.assertEqual(header.destination_cid, binascii.unhexlify("8394c8f03e515708")) self.assertEqual(header.source_cid, b"") self.assertEqual(header.token, b"") self.assertEqual(header.integrity_tag, b"") - self.assertEqual(header.rest_length, 1262) self.assertEqual(buf.tell(), 18) - def test_pull_initial_client_truncated(self): - buf = Buffer(data=load("initial_client.bin")[0:100]) + def test_pull_initial_client_v1_truncated(self): + buf = Buffer(data=CLIENT_INITIAL_V1[0:100]) with self.assertRaises(ValueError) as cm: pull_quic_header(buf, host_cid_length=8) self.assertEqual(str(cm.exception), "Packet payload is truncated") - def test_pull_initial_server(self): - buf = Buffer(data=load("initial_server.bin")) + def test_pull_initial_client_v2(self): + buf = Buffer(data=CLIENT_INITIAL_V2) + header = pull_quic_header(buf, host_cid_length=8) + self.assertEqual(header.version, QuicProtocolVersion.VERSION_2) + self.assertEqual(header.packet_type, QuicPacketType.INITIAL) + self.assertEqual(header.packet_length, 1200) + self.assertEqual(header.destination_cid, binascii.unhexlify("8394c8f03e515708")) + self.assertEqual(header.source_cid, b"") + self.assertEqual(header.token, b"") + self.assertEqual(header.integrity_tag, b"") + self.assertEqual(buf.tell(), 18) + + def test_pull_initial_server_v1(self): + buf = Buffer(data=SERVER_INITIAL_V1) header = pull_quic_header(buf, host_cid_length=8) - self.assertTrue(header.is_long_header) self.assertEqual(header.version, QuicProtocolVersion.VERSION_1) - self.assertEqual(header.packet_type, PACKET_TYPE_INITIAL) + self.assertEqual(header.packet_type, QuicPacketType.INITIAL) + self.assertEqual(header.packet_length, 135) self.assertEqual(header.destination_cid, b"") - self.assertEqual(header.source_cid, binascii.unhexlify("195c68344e28d479")) + self.assertEqual(header.source_cid, binascii.unhexlify("f067a5502a4262b5")) self.assertEqual(header.token, b"") self.assertEqual(header.integrity_tag, b"") - self.assertEqual(header.rest_length, 184) self.assertEqual(buf.tell(), 18) - def test_pull_retry(self): - original_destination_cid = binascii.unhexlify("fbbd219b7363b64b") + def test_pull_initial_server_v2(self): + buf = Buffer(data=SERVER_INITIAL_V2) + header = pull_quic_header(buf, host_cid_length=8) + self.assertEqual(header.version, QuicProtocolVersion.VERSION_2) + self.assertEqual(header.packet_type, QuicPacketType.INITIAL) + self.assertEqual(header.packet_length, 135) + self.assertEqual(header.destination_cid, b"") + self.assertEqual(header.source_cid, binascii.unhexlify("f067a5502a4262b5")) + self.assertEqual(header.token, b"") + self.assertEqual(header.integrity_tag, b"") + self.assertEqual(buf.tell(), 18) + + def test_pull_retry_v1(self): + # https://datatracker.ietf.org/doc/html/rfc9001#appendix-A.4 + original_destination_cid = binascii.unhexlify("8394c8f03e515708") - data = load("retry.bin") + data = binascii.unhexlify( + "ff000000010008f067a5502a4262b5746f6b656e04a265ba2eff4d829058fb3f0f2496ba" + ) buf = Buffer(data=data) - header = pull_quic_header(buf, host_cid_length=8) - self.assertTrue(header.is_long_header) + header = pull_quic_header(buf) self.assertEqual(header.version, QuicProtocolVersion.VERSION_1) - self.assertEqual(header.packet_type, PACKET_TYPE_RETRY) - self.assertEqual(header.destination_cid, binascii.unhexlify("e9d146d8d14cb28e")) + self.assertEqual(header.packet_type, QuicPacketType.RETRY) + self.assertEqual(header.packet_length, 36) + self.assertEqual(header.destination_cid, b"") + self.assertEqual(header.source_cid, binascii.unhexlify("f067a5502a4262b5")) + self.assertEqual(header.token, b"token") self.assertEqual( - header.source_cid, - binascii.unhexlify("0b0a205a648fcf82d85f128b67bbe08053e6"), + header.integrity_tag, binascii.unhexlify("04a265ba2eff4d829058fb3f0f2496ba") ) + self.assertEqual(buf.tell(), 36) + + # check integrity self.assertEqual( - header.token, - binascii.unhexlify( - "44397a35d698393c134b08a932737859f446d3aadd00ed81540c8d8de172" - "906d3e7a111b503f9729b8928e7528f9a86a4581f9ebb4cb3b53c283661e" - "8530741a99192ee56914c5626998ec0f" + get_retry_integrity_tag( + buf.data_slice(0, 20), original_destination_cid, version=header.version ), + header.integrity_tag, + ) + + # serialize + encoded = encode_quic_retry( + version=header.version, + source_cid=header.source_cid, + destination_cid=header.destination_cid, + original_destination_cid=original_destination_cid, + retry_token=header.token, + # This value is arbitrary, we set it to match the value in the RFC. + unused=0xF, ) + with open("bob.bin", "wb") as fp: + fp.write(encoded) + self.assertEqual(encoded, data) + + def test_pull_retry_v2(self): + # https://datatracker.ietf.org/doc/html/rfc9369#appendix-A.4 + original_destination_cid = binascii.unhexlify("8394c8f03e515708") + + data = binascii.unhexlify( + "cf6b3343cf0008f067a5502a4262b5746f6b656ec8646ce8bfe33952d955543665dcc7b6" + ) + buf = Buffer(data=data) + header = pull_quic_header(buf) + self.assertEqual(header.version, QuicProtocolVersion.VERSION_2) + self.assertEqual(header.packet_type, QuicPacketType.RETRY) + self.assertEqual(header.packet_length, 36) + self.assertEqual(header.destination_cid, b"") + self.assertEqual(header.source_cid, binascii.unhexlify("f067a5502a4262b5")) + self.assertEqual(header.token, b"token") self.assertEqual( - header.integrity_tag, binascii.unhexlify("4620aafd42f1d630588b27575a12da5c") + header.integrity_tag, binascii.unhexlify("c8646ce8bfe33952d955543665dcc7b6") ) - self.assertEqual(header.rest_length, 0) - self.assertEqual(buf.tell(), 125) + self.assertEqual(buf.tell(), 36) # check integrity - if False: - self.assertEqual( - get_retry_integrity_tag( - buf.data_slice(0, 109), - original_destination_cid, - version=header.version, - ), - header.integrity_tag, - ) + self.assertEqual( + get_retry_integrity_tag( + buf.data_slice(0, 20), original_destination_cid, version=header.version + ), + header.integrity_tag, + ) # serialize encoded = encode_quic_retry( @@ -129,28 +186,39 @@ def test_pull_retry(self): destination_cid=header.destination_cid, original_destination_cid=original_destination_cid, retry_token=header.token, + # This value is arbitrary, we set it to match the value in the RFC. + unused=0xF, ) with open("bob.bin", "wb") as fp: fp.write(encoded) self.assertEqual(encoded, data) def test_pull_version_negotiation(self): - buf = Buffer(data=load("version_negotiation.bin")) + data = binascii.unhexlify( + "ea00000000089aac5a49ba87a84908f92f4336fa951ba14547471600000001" + ) + buf = Buffer(data=data) header = pull_quic_header(buf, host_cid_length=8) - self.assertTrue(header.is_long_header) self.assertEqual(header.version, QuicProtocolVersion.NEGOTIATION) - self.assertEqual(header.packet_type, None) + self.assertEqual(header.packet_type, QuicPacketType.VERSION_NEGOTIATION) + self.assertEqual(header.packet_length, 31) self.assertEqual(header.destination_cid, binascii.unhexlify("9aac5a49ba87a849")) self.assertEqual(header.source_cid, binascii.unhexlify("f92f4336fa951ba1")) self.assertEqual(header.token, b"") self.assertEqual(header.integrity_tag, b"") - self.assertEqual(header.rest_length, 8) - self.assertEqual(buf.tell(), 23) + self.assertEqual( + header.supported_versions, [0x45474716, QuicProtocolVersion.VERSION_1] + ) + self.assertEqual(buf.tell(), 31) - versions = [] - while not buf.eof(): - versions.append(buf.pull_uint32()) - self.assertEqual(versions, [0x45474716, QuicProtocolVersion.VERSION_1]), + encoded = encode_quic_version_negotiation( + destination_cid=header.destination_cid, + source_cid=header.source_cid, + supported_versions=header.supported_versions, + ) + + # The first byte may differ as it is random. + self.assertEqual(encoded[1:], data[1:]) def test_pull_long_header_dcid_too_long(self): buf = Buffer( @@ -186,16 +254,17 @@ def test_pull_long_header_too_short(self): pull_quic_header(buf, host_cid_length=8) def test_pull_short_header(self): - buf = Buffer(data=load("short_header.bin")) + buf = Buffer( + data=binascii.unhexlify("5df45aa7b59c0e1ad6e668f5304cd4fd1fb3799327") + ) header = pull_quic_header(buf, host_cid_length=8) - self.assertFalse(header.is_long_header) self.assertEqual(header.version, None) - self.assertEqual(header.packet_type, 0x50) + self.assertEqual(header.packet_type, QuicPacketType.ONE_RTT) + self.assertEqual(header.packet_length, 21) self.assertEqual(header.destination_cid, binascii.unhexlify("f45aa7b59c0e1ad6")) self.assertEqual(header.source_cid, b"") self.assertEqual(header.token, b"") self.assertEqual(header.integrity_tag, b"") - self.assertEqual(header.rest_length, 12) self.assertEqual(buf.tell(), 9) def test_pull_short_header_no_fixed_bit(self): @@ -204,14 +273,6 @@ def test_pull_short_header_no_fixed_bit(self): pull_quic_header(buf, host_cid_length=8) self.assertEqual(str(cm.exception), "Packet fixed bit is zero") - def test_encode_quic_version_negotiation(self): - data = encode_quic_version_negotiation( - destination_cid=binascii.unhexlify("9aac5a49ba87a849"), - source_cid=binascii.unhexlify("f92f4336fa951ba1"), - supported_versions=[0x45474716, QuicProtocolVersion.VERSION_1], - ) - self.assertEqual(data[1:], load("version_negotiation.bin")[1:]) - class ParamsTest(TestCase): maxDiff = None diff --git a/tests/test_packet_builder.py b/tests/test_packet_builder.py index 6ad300589..fa62d5369 100644 --- a/tests/test_packet_builder.py +++ b/tests/test_packet_builder.py @@ -1,13 +1,9 @@ +from __future__ import annotations + from unittest import TestCase from qh3.quic.crypto import CryptoPair -from qh3.quic.packet import ( - PACKET_TYPE_HANDSHAKE, - PACKET_TYPE_INITIAL, - PACKET_TYPE_ONE_RTT, - QuicFrameType, - QuicProtocolVersion, -) +from qh3.quic.packet import QuicFrameType, QuicPacketType, QuicProtocolVersion from qh3.quic.packet_builder import ( QuicPacketBuilder, QuicPacketBuilderStop, @@ -16,6 +12,10 @@ from qh3.tls import Epoch +def datagram_sizes(datagrams: list[bytes]) -> list[int]: + return [len(x) for x in datagrams] + + def create_builder(is_client=False): return QuicPacketBuilder( host_cid=bytes(8), @@ -41,7 +41,7 @@ def test_long_header_empty(self): builder = create_builder() crypto = create_crypto() - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1236) self.assertTrue(builder.packet_is_empty) @@ -58,14 +58,14 @@ def test_long_header_padding(self): crypto = create_crypto() # INITIAL, fully padded - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1236) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(100)) self.assertFalse(builder.packet_is_empty) # INITIAL, empty - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertTrue(builder.packet_is_empty) # check datagrams @@ -81,8 +81,8 @@ def test_long_header_padding(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_INITIAL, - sent_bytes=1280, + packet_type=QuicPacketType.INITIAL, + sent_bytes=145, ) ], ) @@ -91,25 +91,26 @@ def test_long_header_padding(self): self.assertEqual(builder.packet_number, 1) def test_long_header_initial_client_2(self): + self.maxDiff = None builder = create_builder(is_client=True) crypto = create_crypto() # INITIAL, full length - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1236) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) # INITIAL - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1236) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(100)) self.assertFalse(builder.packet_is_empty) # INITIAL, empty - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertTrue(builder.packet_is_empty) # check datagrams @@ -126,7 +127,7 @@ def test_long_header_initial_client_2(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_INITIAL, + packet_type=QuicPacketType.INITIAL, sent_bytes=1280, ), QuicSentPacket( @@ -135,8 +136,8 @@ def test_long_header_initial_client_2(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=1, - packet_type=PACKET_TYPE_INITIAL, - sent_bytes=1280, + packet_type=QuicPacketType.INITIAL, + sent_bytes=145, ), ], ) @@ -149,20 +150,20 @@ def test_long_header_initial_server(self): crypto = create_crypto() # INITIAL - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1236) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(100)) self.assertFalse(builder.packet_is_empty) # INITIAL, empty - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertTrue(builder.packet_is_empty) # check datagrams datagrams, packets = builder.flush() self.assertEqual(len(datagrams), 1) - self.assertEqual(len(datagrams[0]), 145) + self.assertEqual(len(datagrams[0]), 1280) self.assertEqual( packets, [ @@ -172,7 +173,7 @@ def test_long_header_initial_server(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_INITIAL, + packet_type=QuicPacketType.INITIAL, sent_bytes=145, ) ], @@ -190,7 +191,7 @@ def test_long_header_ping_only(self): crypto = create_crypto() # HANDSHAKE, with only a PING frame - builder.start_packet(PACKET_TYPE_HANDSHAKE, crypto) + builder.start_packet(QuicPacketType.HANDSHAKE, crypto) builder.start_frame(QuicFrameType.PING) self.assertFalse(builder.packet_is_empty) @@ -207,7 +208,7 @@ def test_long_header_ping_only(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_HANDSHAKE, + packet_type=QuicPacketType.HANDSHAKE, sent_bytes=45, ) ], @@ -218,25 +219,25 @@ def test_long_header_then_short_header(self): crypto = create_crypto() # INITIAL, full length - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1236) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) # INITIAL, empty - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertTrue(builder.packet_is_empty) # ONE_RTT, full length - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 1253) buf = builder.start_frame(QuicFrameType.STREAM_BASE) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) # ONE_RTT, empty - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertTrue(builder.packet_is_empty) # check datagrams @@ -253,7 +254,7 @@ def test_long_header_then_short_header(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_INITIAL, + packet_type=QuicPacketType.INITIAL, sent_bytes=1280, ), QuicSentPacket( @@ -262,7 +263,7 @@ def test_long_header_then_short_header(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=1, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, ), ], @@ -271,26 +272,71 @@ def test_long_header_then_short_header(self): # check builder self.assertEqual(builder.packet_number, 2) + def test_long_header_initial_client_zero_rtt(self): + builder = create_builder(is_client=True) + crypto = create_crypto() + + # INITIAL + builder.start_packet(QuicPacketType.INITIAL, crypto) + self.assertEqual(builder.remaining_flight_space, 1236) + buf = builder.start_frame(QuicFrameType.CRYPTO) + buf.push_bytes(bytes(613)) + self.assertFalse(builder.packet_is_empty) + + # 0-RTT + builder.start_packet(QuicPacketType.ZERO_RTT, crypto) + self.assertEqual(builder.remaining_flight_space, 579) + buf = builder.start_frame(QuicFrameType.STREAM_BASE) + buf.push_bytes(bytes(100)) + self.assertFalse(builder.packet_is_empty) + + # check datagrams + datagrams, packets = builder.flush() + self.assertEqual(datagram_sizes(datagrams), [1280]) + self.assertEqual( + packets, + [ + QuicSentPacket( + epoch=Epoch.INITIAL, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + packet_number=0, + packet_type=QuicPacketType.INITIAL, + sent_bytes=658, + ), + QuicSentPacket( + epoch=Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=False, + packet_number=1, + packet_type=QuicPacketType.ZERO_RTT, + sent_bytes=144, + ), + ], + ) + def test_long_header_then_long_header(self): builder = create_builder() crypto = create_crypto() # INITIAL - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1236) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(199)) self.assertFalse(builder.packet_is_empty) # HANDSHAKE - builder.start_packet(PACKET_TYPE_HANDSHAKE, crypto) + builder.start_packet(QuicPacketType.HANDSHAKE, crypto) self.assertEqual(builder.remaining_flight_space, 993) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(299)) self.assertFalse(builder.packet_is_empty) # ONE_RTT - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 666) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(299)) @@ -299,7 +345,7 @@ def test_long_header_then_long_header(self): # check datagrams datagrams, packets = builder.flush() self.assertEqual(len(datagrams), 1) - self.assertEqual(len(datagrams[0]), 914) + self.assertEqual(len(datagrams[0]), 1280) self.assertEqual( packets, [ @@ -309,7 +355,7 @@ def test_long_header_then_long_header(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_INITIAL, + packet_type=QuicPacketType.INITIAL, sent_bytes=244, ), QuicSentPacket( @@ -318,7 +364,7 @@ def test_long_header_then_long_header(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=1, - packet_type=PACKET_TYPE_HANDSHAKE, + packet_type=QuicPacketType.HANDSHAKE, sent_bytes=343, ), QuicSentPacket( @@ -327,8 +373,8 @@ def test_long_header_then_long_header(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=2, - packet_type=PACKET_TYPE_ONE_RTT, - sent_bytes=327, + packet_type=QuicPacketType.ONE_RTT, + sent_bytes=693, ), ], ) @@ -340,7 +386,7 @@ def test_short_header_empty(self): builder = create_builder() crypto = create_crypto() - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 1253) self.assertTrue(builder.packet_is_empty) @@ -357,7 +403,7 @@ def test_short_header_padding(self): crypto = create_crypto() # ONE_RTT, full length - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 1253) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) @@ -376,7 +422,7 @@ def test_short_header_padding(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, ) ], @@ -394,14 +440,14 @@ def test_short_header_max_flight_bytes(self): crypto = create_crypto() - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 973) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) with self.assertRaises(QuicPacketBuilderStop): - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) builder.start_frame(QuicFrameType.CRYPTO) # check datagrams @@ -417,7 +463,7 @@ def test_short_header_max_flight_bytes(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1000, ), ], @@ -438,7 +484,7 @@ def test_short_header_max_flight_bytes_zero(self): crypto = create_crypto() with self.assertRaises(QuicPacketBuilderStop): - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) builder.start_frame(QuicFrameType.CRYPTO) # check datagrams @@ -459,12 +505,12 @@ def test_short_header_max_flight_bytes_zero_ack(self): crypto = create_crypto() - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) buf = builder.start_frame(QuicFrameType.ACK) buf.push_bytes(bytes(64)) with self.assertRaises(QuicPacketBuilderStop): - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) builder.start_frame(QuicFrameType.CRYPTO) # check datagrams @@ -480,7 +526,7 @@ def test_short_header_max_flight_bytes_zero_ack(self): is_ack_eliciting=False, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=92, ), ], @@ -499,7 +545,7 @@ def test_short_header_max_total_bytes_1(self): crypto = create_crypto() with self.assertRaises(QuicPacketBuilderStop): - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) # check datagrams datagrams, packets = builder.flush() @@ -518,14 +564,14 @@ def test_short_header_max_total_bytes_2(self): crypto = create_crypto() - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 773) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) with self.assertRaises(QuicPacketBuilderStop): - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) # check datagrams datagrams, packets = builder.flush() @@ -540,7 +586,7 @@ def test_short_header_max_total_bytes_2(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=800, ) ], @@ -555,20 +601,20 @@ def test_short_header_max_total_bytes_3(self): crypto = create_crypto() - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 1253) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 693) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) with self.assertRaises(QuicPacketBuilderStop): - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) # check datagrams datagrams, packets = builder.flush() @@ -584,7 +630,7 @@ def test_short_header_max_total_bytes_3(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, ), QuicSentPacket( @@ -593,7 +639,7 @@ def test_short_header_max_total_bytes_3(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=1, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=720, ), ], @@ -611,7 +657,7 @@ def test_short_header_ping_only(self): crypto = create_crypto() # HANDSHAKE, with only a PING frame - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) builder.start_frame(QuicFrameType.PING) self.assertFalse(builder.packet_is_empty) @@ -628,7 +674,7 @@ def test_short_header_ping_only(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=29, ) ], diff --git a/tests/test_rangeset.py b/tests/test_rangeset.py index 734e20b9e..5e0fdf460 100644 --- a/tests/test_rangeset.py +++ b/tests/test_rangeset.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest import TestCase from qh3.quic.rangeset import RangeSet diff --git a/tests/test_recovery.py b/tests/test_recovery.py index 18f374c72..b1312dba6 100644 --- a/tests/test_recovery.py +++ b/tests/test_recovery.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import math from unittest import TestCase from qh3 import tls -from qh3.quic.packet import PACKET_TYPE_INITIAL, PACKET_TYPE_ONE_RTT +from qh3.quic.packet import QuicPacketType from qh3.quic.packet_builder import QuicSentPacket from qh3.quic.rangeset import RangeSet from qh3.quic.recovery import ( @@ -88,7 +90,7 @@ def test_on_ack_received_ack_eliciting(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=0.0, ) @@ -121,7 +123,7 @@ def test_on_ack_received_non_ack_eliciting(self): is_ack_eliciting=False, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=123.45, ) @@ -154,7 +156,7 @@ def test_on_packet_lost_crypto(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_INITIAL, + packet_type=QuicPacketType.INITIAL, sent_bytes=1280, sent_time=0.0, ) diff --git a/tests/test_retry.py b/tests/test_retry.py index b30efbcc1..f05a1355c 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest import TestCase from qh3.quic.retry import QuicRetryTokenHandler diff --git a/tests/test_stream.py b/tests/test_stream.py index 0e606c529..2679eca80 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest import TestCase from qh3.quic.events import StreamDataReceived, StreamReset diff --git a/tests/test_tls.py b/tests/test_tls.py index f6760d541..57b031561 100644 --- a/tests/test_tls.py +++ b/tests/test_tls.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import binascii import ssl from unittest import TestCase @@ -83,9 +85,9 @@ def test_pull_block_truncated(self): def create_buffers(): return { - tls.Epoch.INITIAL: Buffer(capacity=4096), - tls.Epoch.HANDSHAKE: Buffer(capacity=4096), - tls.Epoch.ONE_RTT: Buffer(capacity=4096), + tls.Epoch.INITIAL: Buffer(capacity=8192), + tls.Epoch.HANDSHAKE: Buffer(capacity=8192), + tls.Epoch.ONE_RTT: Buffer(capacity=8192), } @@ -324,7 +326,7 @@ def _handshake(self, client, server): self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) self.assertGreaterEqual(len(server_input), 181) - self.assertLessEqual(len(server_input), 369) + self.assertLessEqual(len(server_input), 1800) reset_buffers(client_buf) # Handle client hello. @@ -336,7 +338,7 @@ def _handshake(self, client, server): self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) client_input = merge_buffers(server_buf) self.assertGreaterEqual(len(client_input), 539) - self.assertLessEqual(len(client_input), 2316) + self.assertLessEqual(len(client_input), 4000) reset_buffers(server_buf) @@ -533,7 +535,7 @@ def second_handshake(): self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) self.assertGreaterEqual(len(server_input), 383) - self.assertLessEqual(len(server_input), 483) + self.assertLessEqual(len(server_input), 1800) reset_buffers(client_buf) # Handle client hello. @@ -543,7 +545,7 @@ def second_handshake(): server.handle_message(server_input, server_buf) self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) client_input = merge_buffers(server_buf) - self.assertEqual(len(client_input), 226) + self.assertEqual(len(client_input), 1410) reset_buffers(server_buf) # Handle server hello, encrypted extensions, certificate, @@ -586,7 +588,7 @@ def second_handshake_bad_binder(): self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) self.assertGreaterEqual(len(server_input), 383) - self.assertLessEqual(len(server_input), 483) + self.assertLessEqual(len(server_input), 1800) reset_buffers(client_buf) # tamper with binder @@ -612,7 +614,7 @@ def second_handshake_bad_pre_shared_key(): self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) self.assertGreaterEqual(len(server_input), 383) - self.assertLessEqual(len(server_input), 483) + self.assertLessEqual(len(server_input), 1800) reset_buffers(client_buf) # handle client hello @@ -626,7 +628,7 @@ def second_handshake_bad_pre_shared_key(): buf.seek(buf.tell() - 1) buf.push_uint8(1) client_input = merge_buffers(server_buf) - self.assertEqual(len(client_input), 226) + self.assertEqual(len(client_input), 1410) reset_buffers(server_buf) # handle server hello and bomb @@ -887,6 +889,7 @@ def test_pull_client_hello_with_sni(self): # serialize buf = Buffer(1000) push_client_hello(buf, hello) + self.assertEqual(buf.data, load("tls_client_hello_with_sni.bin")) def test_push_client_hello(self): diff --git a/tests/test_webtransport.py b/tests/test_webtransport.py index eb7863757..79965ac85 100644 --- a/tests/test_webtransport.py +++ b/tests/test_webtransport.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest import TestCase from qh3.h3.connection import H3_ALPN, ErrorCode, H3Connection diff --git a/tests/tls_client_hello.bin b/tests/tls_client_hello.bin index e6ca00d52..7c7dcc3cf 100644 Binary files a/tests/tls_client_hello.bin and b/tests/tls_client_hello.bin differ diff --git a/tests/tls_client_hello_with_alpn.bin b/tests/tls_client_hello_with_alpn.bin index 9113a3e75..ca78dbfad 100644 Binary files a/tests/tls_client_hello_with_alpn.bin and b/tests/tls_client_hello_with_alpn.bin differ diff --git a/tests/tls_client_hello_with_psk.bin b/tests/tls_client_hello_with_psk.bin index cf695c3a7..f0a0cdac3 100644 Binary files a/tests/tls_client_hello_with_psk.bin and b/tests/tls_client_hello_with_psk.bin differ diff --git a/tests/tls_client_hello_with_sni.bin b/tests/tls_client_hello_with_sni.bin index b568a8616..c9fedcc04 100644 Binary files a/tests/tls_client_hello_with_sni.bin and b/tests/tls_client_hello_with_sni.bin differ diff --git a/tests/utils.py b/tests/utils.py index 374d2320d..da21ce955 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import datetime import functools