From 919fec9318fe6f82a0e792eb89bdb34597befcbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 4 Oct 2023 13:41:21 +0200 Subject: [PATCH 1/9] Add opentelemetry-api as opt-in dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 0f21b70a..030c4ac0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dev = [ "requests-toolbelt", ] subscriptions = ["websockets~=11.0"] +telemetry = ["opentelemetry-api"] [project.scripts] ariadne-codegen = "ariadne_codegen.main:main" From 2b21f364b58e30213bb909423ec531bf25372a0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 4 Oct 2023 13:42:45 +0200 Subject: [PATCH 2/9] Add telemetry options to BaseClient --- .github/workflows/tests.yml | 2 +- .../dependencies/base_client.py | 138 +++++++++++++++--- .../dependencies/test_base_client.py | 98 +++++++++++++ 3 files changed, 220 insertions(+), 18 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ba4d7220..37ce6f88 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -25,7 +25,7 @@ jobs: run: | python -m pip install --upgrade pip pip install wheel - pip install -e .[subscriptions,dev] + pip install -e .[subscriptions,telemetry,dev] - name: Pytest run: | pytest diff --git a/ariadne_codegen/client_generators/dependencies/base_client.py b/ariadne_codegen/client_generators/dependencies/base_client.py index 90993cf1..0241fb04 100644 --- a/ariadne_codegen/client_generators/dependencies/base_client.py +++ b/ariadne_codegen/client_generators/dependencies/base_client.py @@ -1,5 +1,5 @@ import json -from typing import IO, Any, Dict, List, Optional, Tuple, TypeVar, cast +from typing import IO, Any, Dict, List, Optional, Tuple, TypeVar, Union, cast import httpx from pydantic import BaseModel @@ -12,6 +12,26 @@ GraphQlClientInvalidResponseError, ) +try: + from opentelemetry.trace import ( + Context, + Span, + Tracer, + get_tracer, + set_span_in_context, + ) +except ImportError: + Context = Any # type: ignore + Span = Any # type: ignore + Tracer = Any # type: ignore + + def get_tracer(*args, **kwargs) -> Tracer: # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + def set_span_in_context(*args, **kwargs): # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + Self = TypeVar("Self", bound="BaseClient") @@ -21,12 +41,21 @@ def __init__( url: str = "", headers: Optional[Dict[str, str]] = None, http_client: Optional[httpx.Client] = None, + tracer: Optional[Union[str, Tracer]] = None, + root_context: Optional[Context] = None, + root_span_name: Optional[str] = None, ) -> None: self.url = url self.headers = headers self.http_client = http_client if http_client else httpx.Client(headers=headers) + self.tracer: Optional[Tracer] = ( + get_tracer(tracer) if isinstance(tracer, str) else tracer + ) + self.root_context = root_context + self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + def __enter__(self: Self) -> Self: return self @@ -41,17 +70,9 @@ def __exit__( def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - payload: Dict[str, Any] = {"query": query, "variables": processed_variables} - - if files and files_map: - return self._execute_multipart( - payload=payload, - files=files, - files_map=files_map, - ) - - return self._execute_json(payload=payload) + if self.tracer: + return self._execute_with_telemetry(query=query, variables=variables) + return self._execute(query=query, variables=variables) def get_data(self, response: httpx.Response) -> dict[str, Any]: if not response.is_success: @@ -77,6 +98,21 @@ def get_data(self, response: httpx.Response) -> dict[str, Any]: return cast(dict[str, Any], data) + def _execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return self._execute_json(query=query, variables=processed_variables) + def _process_variables( self, variables: Optional[Dict[str, Any]] ) -> Tuple[ @@ -148,19 +184,87 @@ def separate_files(path: str, obj: Any) -> Any: def _execute_multipart( self, - payload: Dict[str, Any], + query: str, + variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], ) -> httpx.Response: data = { - "operations": json.dumps(payload, default=to_jsonable_python), + "operations": json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), "map": json.dumps(files_map, default=to_jsonable_python), } return self.http_client.post(url=self.url, data=data, files=files) - def _execute_json(self, payload: Dict[str, Any]) -> httpx.Response: - content = json.dumps(payload, default=to_jsonable_python) + def _execute_json(self, query: str, variables: Dict[str, Any]) -> httpx.Response: return self.http_client.post( - url=self.url, content=content, headers={"Content-Type": "application/json"} + url=self.url, + content=json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + headers={"Content-Type": "application/json"}, ) + + def _execute_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + self.root_span_name, context=self.root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return self._execute_multipart_with_telemetry( + root_span=root_span, + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return self._execute_json_with_telemetry( + root_span=root_span, query=query, variables=processed_variables + ) + + def _execute_multipart_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "multipart request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + serialized_map = json.dumps(files_map, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + span.set_attribute("map", serialized_map) + + return self._execute_multipart( + query=query, variables=variables, files=files, files_map=files_map + ) + + def _execute_json_with_telemetry( + self, root_span: Span, query: str, variables: Dict[str, Any] + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "json request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + + return self._execute_json(query=query, variables=variables) diff --git a/tests/client_generators/dependencies/test_base_client.py b/tests/client_generators/dependencies/test_base_client.py index 24fd4087..aba53ff7 100644 --- a/tests/client_generators/dependencies/test_base_client.py +++ b/tests/client_generators/dependencies/test_base_client.py @@ -1,6 +1,7 @@ import json from datetime import datetime from typing import Any, Optional +from unittest.mock import ANY import httpx import pytest @@ -445,3 +446,100 @@ def test_base_client_used_as_context_manager_closes_http_client(mocker): base_client.execute("") assert fake_client.close.called + + +@pytest.fixture +def mocker_get_tracer(mocker): + return mocker.patch( + "ariadne_codegen.client_generators.dependencies.base_client.get_tracer" + ) + + +@pytest.fixture +def mocked_start_as_current_span(mocker_get_tracer): + return mocker_get_tracer.return_value.start_as_current_span + + +def test_base_client_with_given_tracker_str_uses_global_tracker(mocker_get_tracer): + BaseClient(url="http://base_url", tracer="tracker name") + + assert mocker_get_tracer.call_count == 1 + + +def test_execute_creates_root_span(httpx_mock, mocked_start_as_current_span): + httpx_mock.add_response() + client = BaseClient(url="http://base_url", tracer="tracker") + + client.execute("query GetHello { hello }") + + mocked_start_as_current_span.assert_any_call("GraphQL Operation", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + + +def test_execute_creates_root_span_with_custom_name( + httpx_mock, mocked_start_as_current_span +): + httpx_mock.add_response() + client = BaseClient( + url="http://base_url", tracer="tracker", root_span_name="root_span" + ) + + client.execute("query GetHello { hello }") + + mocked_start_as_current_span.assert_any_call("root_span", context=ANY) + + +def test_execute_creates_root_span_with_custom_context( + httpx_mock, mocked_start_as_current_span +): + httpx_mock.add_response() + client = BaseClient( + url="http://base_url", tracer="tracker", root_context={"abc": 123} + ) + + client.execute("query GetHello { hello }") + + mocked_start_as_current_span.assert_any_call( + "GraphQL Operation", context={"abc": 123} + ) + + +def test_execute_creates_span_for_json_http_request( + httpx_mock, mocked_start_as_current_span +): + httpx_mock.add_response() + client = BaseClient(url="http://base_url", tracer="tracker") + + client.execute("query GetHello { hello }", variables={"a": 1, "b": {"bb": 2}}) + + mocked_start_as_current_span.assert_any_call("json request", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + span.set_attribute.assert_any_call("query", "query GetHello { hello }") + span.set_attribute.assert_any_call( + "variables", json.dumps({"a": 1, "b": {"bb": 2}}) + ) + + +def test_execute_creates_span_for_multipart_request( + httpx_mock, txt_file, mocked_start_as_current_span +): + httpx_mock.add_response() + client = BaseClient(url="http://base_url", tracer="tracker") + + client.execute( + "query Abc($file: Upload!) { abc(file: $file) }", + {"file": txt_file, "a": 1.0, "b": {"bb": 2}}, + ) + + mocked_start_as_current_span.assert_any_call("multipart request", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + span.set_attribute.assert_any_call( + "query", "query Abc($file: Upload!) { abc(file: $file) }" + ) + span.set_attribute.assert_any_call( + "variables", json.dumps({"file": None, "a": 1.0, "b": {"bb": 2}}) + ) + span.set_attribute.assert_any_call("map", json.dumps({"0": ["variables.file"]})) From 76a1c01c7b821e3d2975739e5eac86c9c3e75b31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 4 Oct 2023 15:42:59 +0200 Subject: [PATCH 3/9] Add telemetry options to AsyncBaseClient's not subscription operations --- .../dependencies/async_base_client.py | 154 ++++++++++++++++-- .../dependencies/base_client.py | 2 +- .../dependencies/test_async_base_client.py | 106 ++++++++++++ .../expected_client/async_base_client.py | 154 ++++++++++++++++-- .../expected_client/async_base_client.py | 154 ++++++++++++++++-- .../expected_client/async_base_client.py | 154 ++++++++++++++++-- .../expected_client/async_base_client.py | 154 ++++++++++++++++-- .../expected_client/async_base_client.py | 154 ++++++++++++++++-- .../expected_client/async_base_client.py | 154 ++++++++++++++++-- .../expected_client/async_base_client.py | 154 ++++++++++++++++-- .../expected_client/async_base_client.py | 154 ++++++++++++++++-- .../expected_client/async_base_client.py | 154 ++++++++++++++++-- .../expected_client/async_base_client.py | 154 ++++++++++++++++-- 13 files changed, 1625 insertions(+), 177 deletions(-) diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client.py b/ariadne_codegen/client_generators/dependencies/async_base_client.py index def14347..6b52985c 100644 --- a/ariadne_codegen/client_generators/dependencies/async_base_client.py +++ b/ariadne_codegen/client_generators/dependencies/async_base_client.py @@ -1,6 +1,17 @@ import enum import json -from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from typing import ( + IO, + Any, + AsyncIterator, + Dict, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from uuid import uuid4 import httpx @@ -32,6 +43,26 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument Subprotocol = Any # type: ignore +try: + from opentelemetry.trace import ( # type: ignore[attr-defined] + Context, + Span, + Tracer, + get_tracer, + set_span_in_context, + ) +except ImportError: + Context = Any # type: ignore + Span = Any # type: ignore + Tracer = Any # type: ignore + + def get_tracer(*args, **kwargs) -> Tracer: # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + def set_span_in_context(*args, **kwargs): # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -58,17 +89,27 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, + tracer: Optional[Union[str, Tracer]] = None, + root_context: Optional[Context] = None, + root_span_name: Optional[str] = None, ) -> None: self.url = url self.headers = headers self.http_client = ( http_client if http_client else httpx.AsyncClient(headers=headers) ) + self.ws_url = ws_url self.ws_headers = ws_headers or {} self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload + self.tracer: Optional[Tracer] = ( + get_tracer(tracer) if isinstance(tracer, str) else tracer + ) + self.root_context = root_context + self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + async def __aenter__(self: Self) -> Self: return self @@ -83,17 +124,10 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - payload: Dict[str, Any] = {"query": query, "variables": processed_variables} + if self.tracer: + return await self._execute_with_telemetry(query=query, variables=variables) - if files and files_map: - return await self._execute_multipart( - payload=payload, - files=files, - files_map=files_map, - ) - - return await self._execute_json(payload=payload) + return await self._execute(query=query, variables=variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -142,6 +176,21 @@ async def execute_ws( if data: yield data + async def _execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json(query=query, variables=processed_variables) + def _process_variables( self, variables: Optional[Dict[str, Any]] ) -> Tuple[ @@ -213,21 +262,31 @@ def separate_files(path: str, obj: Any) -> Any: async def _execute_multipart( self, - payload: Dict[str, Any], + query: str, + variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], ) -> httpx.Response: data = { - "operations": json.dumps(payload, default=to_jsonable_python), + "operations": json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), "map": json.dumps(files_map, default=to_jsonable_python), } return await self.http_client.post(url=self.url, data=data, files=files) - async def _execute_json(self, payload: Dict[str, Any]) -> httpx.Response: - content = json.dumps(payload, default=to_jsonable_python) + async def _execute_json( + self, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: return await self.http_client.post( - url=self.url, content=content, headers={"Content-Type": "application/json"} + url=self.url, + content=json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + headers={"Content-Type": "application/json"}, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: @@ -287,3 +346,66 @@ async def _handle_ws_message( ) return None + + async def _execute_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + self.root_span_name, context=self.root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart_with_telemetry( + root_span=root_span, + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json_with_telemetry( + root_span=root_span, query=query, variables=processed_variables + ) + + async def _execute_multipart_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "multipart request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + serialized_map = json.dumps(files_map, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + span.set_attribute("map", serialized_map) + return await self._execute_multipart( + query=query, variables=variables, files=files, files_map=files_map + ) + + async def _execute_json_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "json request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + return await self._execute_json(query=query, variables=variables) diff --git a/ariadne_codegen/client_generators/dependencies/base_client.py b/ariadne_codegen/client_generators/dependencies/base_client.py index 0241fb04..d6d271b8 100644 --- a/ariadne_codegen/client_generators/dependencies/base_client.py +++ b/ariadne_codegen/client_generators/dependencies/base_client.py @@ -13,7 +13,7 @@ ) try: - from opentelemetry.trace import ( + from opentelemetry.trace import ( # type: ignore[attr-defined] Context, Span, Tracer, diff --git a/tests/client_generators/dependencies/test_async_base_client.py b/tests/client_generators/dependencies/test_async_base_client.py index b7281164..277219a0 100644 --- a/tests/client_generators/dependencies/test_async_base_client.py +++ b/tests/client_generators/dependencies/test_async_base_client.py @@ -1,6 +1,7 @@ import json from datetime import datetime from typing import Any, Optional +from unittest.mock import ANY import httpx import pytest @@ -473,3 +474,108 @@ async def test_base_client_used_as_context_manager_closes_http_client(mocker): await base_client.execute("") assert fake_client.aclose.called + + +@pytest.fixture +def mocker_get_tracer(mocker): + return mocker.patch( + "ariadne_codegen.client_generators.dependencies.async_base_client.get_tracer" + ) + + +@pytest.fixture +def mocked_start_as_current_span(mocker_get_tracer): + return mocker_get_tracer.return_value.start_as_current_span + + +@pytest.mark.asyncio +async def test_async_base_client_with_given_tracker_str_uses_global_tracker( + mocker_get_tracer, +): + AsyncBaseClient(url="http://base_url", tracer="tracker name") + + assert mocker_get_tracer.call_count == 1 + + +@pytest.mark.asyncio +async def test_execute_creates_root_span(httpx_mock, mocked_start_as_current_span): + httpx_mock.add_response() + client = AsyncBaseClient(url="http://base_url", tracer="tracker") + + await client.execute("query GetHello { hello }") + + mocked_start_as_current_span.assert_any_call("GraphQL Operation", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + + +@pytest.mark.asyncio +async def test_execute_creates_root_span_with_custom_name( + httpx_mock, mocked_start_as_current_span +): + httpx_mock.add_response() + client = AsyncBaseClient( + url="http://base_url", tracer="tracker", root_span_name="root_span" + ) + + await client.execute("query GetHello { hello }") + + mocked_start_as_current_span.assert_any_call("root_span", context=ANY) + + +@pytest.mark.asyncio +async def test_execute_creates_root_span_with_custom_context( + httpx_mock, mocked_start_as_current_span +): + httpx_mock.add_response() + client = AsyncBaseClient( + url="http://base_url", tracer="tracker", root_context={"abc": 123} + ) + + await client.execute("query GetHello { hello }") + + mocked_start_as_current_span.assert_any_call( + "GraphQL Operation", context={"abc": 123} + ) + + +@pytest.mark.asyncio +async def test_execute_creates_span_for_json_http_request( + httpx_mock, mocked_start_as_current_span +): + httpx_mock.add_response() + client = AsyncBaseClient(url="http://base_url", tracer="tracker") + + await client.execute("query GetHello { hello }", variables={"a": 1, "b": {"bb": 2}}) + + mocked_start_as_current_span.assert_any_call("json request", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + span.set_attribute.assert_any_call("query", "query GetHello { hello }") + span.set_attribute.assert_any_call( + "variables", json.dumps({"a": 1, "b": {"bb": 2}}) + ) + + +@pytest.mark.asyncio +async def test_execute_creates_span_for_multipart_request( + httpx_mock, txt_file, mocked_start_as_current_span +): + httpx_mock.add_response() + client = AsyncBaseClient(url="http://base_url", tracer="tracker") + + await client.execute( + "query Abc($file: Upload!) { abc(file: $file) }", + {"file": txt_file, "a": 1.0, "b": {"bb": 2}}, + ) + + mocked_start_as_current_span.assert_any_call("multipart request", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + span.set_attribute.assert_any_call( + "query", "query Abc($file: Upload!) { abc(file: $file) }" + ) + span.set_attribute.assert_any_call( + "variables", json.dumps({"file": None, "a": 1.0, "b": {"bb": 2}}) + ) + span.set_attribute.assert_any_call("map", json.dumps({"0": ["variables.file"]})) diff --git a/tests/main/clients/custom_config_file/expected_client/async_base_client.py b/tests/main/clients/custom_config_file/expected_client/async_base_client.py index def14347..6b52985c 100644 --- a/tests/main/clients/custom_config_file/expected_client/async_base_client.py +++ b/tests/main/clients/custom_config_file/expected_client/async_base_client.py @@ -1,6 +1,17 @@ import enum import json -from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from typing import ( + IO, + Any, + AsyncIterator, + Dict, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from uuid import uuid4 import httpx @@ -32,6 +43,26 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument Subprotocol = Any # type: ignore +try: + from opentelemetry.trace import ( # type: ignore[attr-defined] + Context, + Span, + Tracer, + get_tracer, + set_span_in_context, + ) +except ImportError: + Context = Any # type: ignore + Span = Any # type: ignore + Tracer = Any # type: ignore + + def get_tracer(*args, **kwargs) -> Tracer: # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + def set_span_in_context(*args, **kwargs): # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -58,17 +89,27 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, + tracer: Optional[Union[str, Tracer]] = None, + root_context: Optional[Context] = None, + root_span_name: Optional[str] = None, ) -> None: self.url = url self.headers = headers self.http_client = ( http_client if http_client else httpx.AsyncClient(headers=headers) ) + self.ws_url = ws_url self.ws_headers = ws_headers or {} self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload + self.tracer: Optional[Tracer] = ( + get_tracer(tracer) if isinstance(tracer, str) else tracer + ) + self.root_context = root_context + self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + async def __aenter__(self: Self) -> Self: return self @@ -83,17 +124,10 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - payload: Dict[str, Any] = {"query": query, "variables": processed_variables} + if self.tracer: + return await self._execute_with_telemetry(query=query, variables=variables) - if files and files_map: - return await self._execute_multipart( - payload=payload, - files=files, - files_map=files_map, - ) - - return await self._execute_json(payload=payload) + return await self._execute(query=query, variables=variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -142,6 +176,21 @@ async def execute_ws( if data: yield data + async def _execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json(query=query, variables=processed_variables) + def _process_variables( self, variables: Optional[Dict[str, Any]] ) -> Tuple[ @@ -213,21 +262,31 @@ def separate_files(path: str, obj: Any) -> Any: async def _execute_multipart( self, - payload: Dict[str, Any], + query: str, + variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], ) -> httpx.Response: data = { - "operations": json.dumps(payload, default=to_jsonable_python), + "operations": json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), "map": json.dumps(files_map, default=to_jsonable_python), } return await self.http_client.post(url=self.url, data=data, files=files) - async def _execute_json(self, payload: Dict[str, Any]) -> httpx.Response: - content = json.dumps(payload, default=to_jsonable_python) + async def _execute_json( + self, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: return await self.http_client.post( - url=self.url, content=content, headers={"Content-Type": "application/json"} + url=self.url, + content=json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + headers={"Content-Type": "application/json"}, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: @@ -287,3 +346,66 @@ async def _handle_ws_message( ) return None + + async def _execute_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + self.root_span_name, context=self.root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart_with_telemetry( + root_span=root_span, + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json_with_telemetry( + root_span=root_span, query=query, variables=processed_variables + ) + + async def _execute_multipart_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "multipart request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + serialized_map = json.dumps(files_map, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + span.set_attribute("map", serialized_map) + return await self._execute_multipart( + query=query, variables=variables, files=files, files_map=files_map + ) + + async def _execute_json_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "json request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + return await self._execute_json(query=query, variables=variables) diff --git a/tests/main/clients/custom_files_names/expected_client/async_base_client.py b/tests/main/clients/custom_files_names/expected_client/async_base_client.py index def14347..6b52985c 100644 --- a/tests/main/clients/custom_files_names/expected_client/async_base_client.py +++ b/tests/main/clients/custom_files_names/expected_client/async_base_client.py @@ -1,6 +1,17 @@ import enum import json -from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from typing import ( + IO, + Any, + AsyncIterator, + Dict, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from uuid import uuid4 import httpx @@ -32,6 +43,26 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument Subprotocol = Any # type: ignore +try: + from opentelemetry.trace import ( # type: ignore[attr-defined] + Context, + Span, + Tracer, + get_tracer, + set_span_in_context, + ) +except ImportError: + Context = Any # type: ignore + Span = Any # type: ignore + Tracer = Any # type: ignore + + def get_tracer(*args, **kwargs) -> Tracer: # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + def set_span_in_context(*args, **kwargs): # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -58,17 +89,27 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, + tracer: Optional[Union[str, Tracer]] = None, + root_context: Optional[Context] = None, + root_span_name: Optional[str] = None, ) -> None: self.url = url self.headers = headers self.http_client = ( http_client if http_client else httpx.AsyncClient(headers=headers) ) + self.ws_url = ws_url self.ws_headers = ws_headers or {} self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload + self.tracer: Optional[Tracer] = ( + get_tracer(tracer) if isinstance(tracer, str) else tracer + ) + self.root_context = root_context + self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + async def __aenter__(self: Self) -> Self: return self @@ -83,17 +124,10 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - payload: Dict[str, Any] = {"query": query, "variables": processed_variables} + if self.tracer: + return await self._execute_with_telemetry(query=query, variables=variables) - if files and files_map: - return await self._execute_multipart( - payload=payload, - files=files, - files_map=files_map, - ) - - return await self._execute_json(payload=payload) + return await self._execute(query=query, variables=variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -142,6 +176,21 @@ async def execute_ws( if data: yield data + async def _execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json(query=query, variables=processed_variables) + def _process_variables( self, variables: Optional[Dict[str, Any]] ) -> Tuple[ @@ -213,21 +262,31 @@ def separate_files(path: str, obj: Any) -> Any: async def _execute_multipart( self, - payload: Dict[str, Any], + query: str, + variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], ) -> httpx.Response: data = { - "operations": json.dumps(payload, default=to_jsonable_python), + "operations": json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), "map": json.dumps(files_map, default=to_jsonable_python), } return await self.http_client.post(url=self.url, data=data, files=files) - async def _execute_json(self, payload: Dict[str, Any]) -> httpx.Response: - content = json.dumps(payload, default=to_jsonable_python) + async def _execute_json( + self, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: return await self.http_client.post( - url=self.url, content=content, headers={"Content-Type": "application/json"} + url=self.url, + content=json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + headers={"Content-Type": "application/json"}, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: @@ -287,3 +346,66 @@ async def _handle_ws_message( ) return None + + async def _execute_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + self.root_span_name, context=self.root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart_with_telemetry( + root_span=root_span, + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json_with_telemetry( + root_span=root_span, query=query, variables=processed_variables + ) + + async def _execute_multipart_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "multipart request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + serialized_map = json.dumps(files_map, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + span.set_attribute("map", serialized_map) + return await self._execute_multipart( + query=query, variables=variables, files=files, files_map=files_map + ) + + async def _execute_json_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "json request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + return await self._execute_json(query=query, variables=variables) diff --git a/tests/main/clients/custom_scalars/expected_client/async_base_client.py b/tests/main/clients/custom_scalars/expected_client/async_base_client.py index def14347..6b52985c 100644 --- a/tests/main/clients/custom_scalars/expected_client/async_base_client.py +++ b/tests/main/clients/custom_scalars/expected_client/async_base_client.py @@ -1,6 +1,17 @@ import enum import json -from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from typing import ( + IO, + Any, + AsyncIterator, + Dict, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from uuid import uuid4 import httpx @@ -32,6 +43,26 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument Subprotocol = Any # type: ignore +try: + from opentelemetry.trace import ( # type: ignore[attr-defined] + Context, + Span, + Tracer, + get_tracer, + set_span_in_context, + ) +except ImportError: + Context = Any # type: ignore + Span = Any # type: ignore + Tracer = Any # type: ignore + + def get_tracer(*args, **kwargs) -> Tracer: # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + def set_span_in_context(*args, **kwargs): # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -58,17 +89,27 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, + tracer: Optional[Union[str, Tracer]] = None, + root_context: Optional[Context] = None, + root_span_name: Optional[str] = None, ) -> None: self.url = url self.headers = headers self.http_client = ( http_client if http_client else httpx.AsyncClient(headers=headers) ) + self.ws_url = ws_url self.ws_headers = ws_headers or {} self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload + self.tracer: Optional[Tracer] = ( + get_tracer(tracer) if isinstance(tracer, str) else tracer + ) + self.root_context = root_context + self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + async def __aenter__(self: Self) -> Self: return self @@ -83,17 +124,10 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - payload: Dict[str, Any] = {"query": query, "variables": processed_variables} + if self.tracer: + return await self._execute_with_telemetry(query=query, variables=variables) - if files and files_map: - return await self._execute_multipart( - payload=payload, - files=files, - files_map=files_map, - ) - - return await self._execute_json(payload=payload) + return await self._execute(query=query, variables=variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -142,6 +176,21 @@ async def execute_ws( if data: yield data + async def _execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json(query=query, variables=processed_variables) + def _process_variables( self, variables: Optional[Dict[str, Any]] ) -> Tuple[ @@ -213,21 +262,31 @@ def separate_files(path: str, obj: Any) -> Any: async def _execute_multipart( self, - payload: Dict[str, Any], + query: str, + variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], ) -> httpx.Response: data = { - "operations": json.dumps(payload, default=to_jsonable_python), + "operations": json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), "map": json.dumps(files_map, default=to_jsonable_python), } return await self.http_client.post(url=self.url, data=data, files=files) - async def _execute_json(self, payload: Dict[str, Any]) -> httpx.Response: - content = json.dumps(payload, default=to_jsonable_python) + async def _execute_json( + self, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: return await self.http_client.post( - url=self.url, content=content, headers={"Content-Type": "application/json"} + url=self.url, + content=json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + headers={"Content-Type": "application/json"}, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: @@ -287,3 +346,66 @@ async def _handle_ws_message( ) return None + + async def _execute_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + self.root_span_name, context=self.root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart_with_telemetry( + root_span=root_span, + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json_with_telemetry( + root_span=root_span, query=query, variables=processed_variables + ) + + async def _execute_multipart_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "multipart request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + serialized_map = json.dumps(files_map, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + span.set_attribute("map", serialized_map) + return await self._execute_multipart( + query=query, variables=variables, files=files, files_map=files_map + ) + + async def _execute_json_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "json request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + return await self._execute_json(query=query, variables=variables) diff --git a/tests/main/clients/example/expected_client/async_base_client.py b/tests/main/clients/example/expected_client/async_base_client.py index def14347..6b52985c 100644 --- a/tests/main/clients/example/expected_client/async_base_client.py +++ b/tests/main/clients/example/expected_client/async_base_client.py @@ -1,6 +1,17 @@ import enum import json -from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from typing import ( + IO, + Any, + AsyncIterator, + Dict, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from uuid import uuid4 import httpx @@ -32,6 +43,26 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument Subprotocol = Any # type: ignore +try: + from opentelemetry.trace import ( # type: ignore[attr-defined] + Context, + Span, + Tracer, + get_tracer, + set_span_in_context, + ) +except ImportError: + Context = Any # type: ignore + Span = Any # type: ignore + Tracer = Any # type: ignore + + def get_tracer(*args, **kwargs) -> Tracer: # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + def set_span_in_context(*args, **kwargs): # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -58,17 +89,27 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, + tracer: Optional[Union[str, Tracer]] = None, + root_context: Optional[Context] = None, + root_span_name: Optional[str] = None, ) -> None: self.url = url self.headers = headers self.http_client = ( http_client if http_client else httpx.AsyncClient(headers=headers) ) + self.ws_url = ws_url self.ws_headers = ws_headers or {} self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload + self.tracer: Optional[Tracer] = ( + get_tracer(tracer) if isinstance(tracer, str) else tracer + ) + self.root_context = root_context + self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + async def __aenter__(self: Self) -> Self: return self @@ -83,17 +124,10 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - payload: Dict[str, Any] = {"query": query, "variables": processed_variables} + if self.tracer: + return await self._execute_with_telemetry(query=query, variables=variables) - if files and files_map: - return await self._execute_multipart( - payload=payload, - files=files, - files_map=files_map, - ) - - return await self._execute_json(payload=payload) + return await self._execute(query=query, variables=variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -142,6 +176,21 @@ async def execute_ws( if data: yield data + async def _execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json(query=query, variables=processed_variables) + def _process_variables( self, variables: Optional[Dict[str, Any]] ) -> Tuple[ @@ -213,21 +262,31 @@ def separate_files(path: str, obj: Any) -> Any: async def _execute_multipart( self, - payload: Dict[str, Any], + query: str, + variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], ) -> httpx.Response: data = { - "operations": json.dumps(payload, default=to_jsonable_python), + "operations": json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), "map": json.dumps(files_map, default=to_jsonable_python), } return await self.http_client.post(url=self.url, data=data, files=files) - async def _execute_json(self, payload: Dict[str, Any]) -> httpx.Response: - content = json.dumps(payload, default=to_jsonable_python) + async def _execute_json( + self, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: return await self.http_client.post( - url=self.url, content=content, headers={"Content-Type": "application/json"} + url=self.url, + content=json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + headers={"Content-Type": "application/json"}, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: @@ -287,3 +346,66 @@ async def _handle_ws_message( ) return None + + async def _execute_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + self.root_span_name, context=self.root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart_with_telemetry( + root_span=root_span, + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json_with_telemetry( + root_span=root_span, query=query, variables=processed_variables + ) + + async def _execute_multipart_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "multipart request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + serialized_map = json.dumps(files_map, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + span.set_attribute("map", serialized_map) + return await self._execute_multipart( + query=query, variables=variables, files=files, files_map=files_map + ) + + async def _execute_json_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "json request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + return await self._execute_json(query=query, variables=variables) diff --git a/tests/main/clients/extended_models/expected_client/async_base_client.py b/tests/main/clients/extended_models/expected_client/async_base_client.py index def14347..6b52985c 100644 --- a/tests/main/clients/extended_models/expected_client/async_base_client.py +++ b/tests/main/clients/extended_models/expected_client/async_base_client.py @@ -1,6 +1,17 @@ import enum import json -from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from typing import ( + IO, + Any, + AsyncIterator, + Dict, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from uuid import uuid4 import httpx @@ -32,6 +43,26 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument Subprotocol = Any # type: ignore +try: + from opentelemetry.trace import ( # type: ignore[attr-defined] + Context, + Span, + Tracer, + get_tracer, + set_span_in_context, + ) +except ImportError: + Context = Any # type: ignore + Span = Any # type: ignore + Tracer = Any # type: ignore + + def get_tracer(*args, **kwargs) -> Tracer: # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + def set_span_in_context(*args, **kwargs): # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -58,17 +89,27 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, + tracer: Optional[Union[str, Tracer]] = None, + root_context: Optional[Context] = None, + root_span_name: Optional[str] = None, ) -> None: self.url = url self.headers = headers self.http_client = ( http_client if http_client else httpx.AsyncClient(headers=headers) ) + self.ws_url = ws_url self.ws_headers = ws_headers or {} self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload + self.tracer: Optional[Tracer] = ( + get_tracer(tracer) if isinstance(tracer, str) else tracer + ) + self.root_context = root_context + self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + async def __aenter__(self: Self) -> Self: return self @@ -83,17 +124,10 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - payload: Dict[str, Any] = {"query": query, "variables": processed_variables} + if self.tracer: + return await self._execute_with_telemetry(query=query, variables=variables) - if files and files_map: - return await self._execute_multipart( - payload=payload, - files=files, - files_map=files_map, - ) - - return await self._execute_json(payload=payload) + return await self._execute(query=query, variables=variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -142,6 +176,21 @@ async def execute_ws( if data: yield data + async def _execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json(query=query, variables=processed_variables) + def _process_variables( self, variables: Optional[Dict[str, Any]] ) -> Tuple[ @@ -213,21 +262,31 @@ def separate_files(path: str, obj: Any) -> Any: async def _execute_multipart( self, - payload: Dict[str, Any], + query: str, + variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], ) -> httpx.Response: data = { - "operations": json.dumps(payload, default=to_jsonable_python), + "operations": json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), "map": json.dumps(files_map, default=to_jsonable_python), } return await self.http_client.post(url=self.url, data=data, files=files) - async def _execute_json(self, payload: Dict[str, Any]) -> httpx.Response: - content = json.dumps(payload, default=to_jsonable_python) + async def _execute_json( + self, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: return await self.http_client.post( - url=self.url, content=content, headers={"Content-Type": "application/json"} + url=self.url, + content=json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + headers={"Content-Type": "application/json"}, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: @@ -287,3 +346,66 @@ async def _handle_ws_message( ) return None + + async def _execute_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + self.root_span_name, context=self.root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart_with_telemetry( + root_span=root_span, + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json_with_telemetry( + root_span=root_span, query=query, variables=processed_variables + ) + + async def _execute_multipart_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "multipart request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + serialized_map = json.dumps(files_map, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + span.set_attribute("map", serialized_map) + return await self._execute_multipart( + query=query, variables=variables, files=files, files_map=files_map + ) + + async def _execute_json_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "json request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + return await self._execute_json(query=query, variables=variables) diff --git a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py index def14347..6b52985c 100644 --- a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py +++ b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py @@ -1,6 +1,17 @@ import enum import json -from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from typing import ( + IO, + Any, + AsyncIterator, + Dict, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from uuid import uuid4 import httpx @@ -32,6 +43,26 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument Subprotocol = Any # type: ignore +try: + from opentelemetry.trace import ( # type: ignore[attr-defined] + Context, + Span, + Tracer, + get_tracer, + set_span_in_context, + ) +except ImportError: + Context = Any # type: ignore + Span = Any # type: ignore + Tracer = Any # type: ignore + + def get_tracer(*args, **kwargs) -> Tracer: # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + def set_span_in_context(*args, **kwargs): # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -58,17 +89,27 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, + tracer: Optional[Union[str, Tracer]] = None, + root_context: Optional[Context] = None, + root_span_name: Optional[str] = None, ) -> None: self.url = url self.headers = headers self.http_client = ( http_client if http_client else httpx.AsyncClient(headers=headers) ) + self.ws_url = ws_url self.ws_headers = ws_headers or {} self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload + self.tracer: Optional[Tracer] = ( + get_tracer(tracer) if isinstance(tracer, str) else tracer + ) + self.root_context = root_context + self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + async def __aenter__(self: Self) -> Self: return self @@ -83,17 +124,10 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - payload: Dict[str, Any] = {"query": query, "variables": processed_variables} + if self.tracer: + return await self._execute_with_telemetry(query=query, variables=variables) - if files and files_map: - return await self._execute_multipart( - payload=payload, - files=files, - files_map=files_map, - ) - - return await self._execute_json(payload=payload) + return await self._execute(query=query, variables=variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -142,6 +176,21 @@ async def execute_ws( if data: yield data + async def _execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json(query=query, variables=processed_variables) + def _process_variables( self, variables: Optional[Dict[str, Any]] ) -> Tuple[ @@ -213,21 +262,31 @@ def separate_files(path: str, obj: Any) -> Any: async def _execute_multipart( self, - payload: Dict[str, Any], + query: str, + variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], ) -> httpx.Response: data = { - "operations": json.dumps(payload, default=to_jsonable_python), + "operations": json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), "map": json.dumps(files_map, default=to_jsonable_python), } return await self.http_client.post(url=self.url, data=data, files=files) - async def _execute_json(self, payload: Dict[str, Any]) -> httpx.Response: - content = json.dumps(payload, default=to_jsonable_python) + async def _execute_json( + self, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: return await self.http_client.post( - url=self.url, content=content, headers={"Content-Type": "application/json"} + url=self.url, + content=json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + headers={"Content-Type": "application/json"}, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: @@ -287,3 +346,66 @@ async def _handle_ws_message( ) return None + + async def _execute_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + self.root_span_name, context=self.root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart_with_telemetry( + root_span=root_span, + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json_with_telemetry( + root_span=root_span, query=query, variables=processed_variables + ) + + async def _execute_multipart_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "multipart request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + serialized_map = json.dumps(files_map, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + span.set_attribute("map", serialized_map) + return await self._execute_multipart( + query=query, variables=variables, files=files, files_map=files_map + ) + + async def _execute_json_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "json request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + return await self._execute_json(query=query, variables=variables) diff --git a/tests/main/clients/inline_fragments/expected_client/async_base_client.py b/tests/main/clients/inline_fragments/expected_client/async_base_client.py index def14347..6b52985c 100644 --- a/tests/main/clients/inline_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/inline_fragments/expected_client/async_base_client.py @@ -1,6 +1,17 @@ import enum import json -from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from typing import ( + IO, + Any, + AsyncIterator, + Dict, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from uuid import uuid4 import httpx @@ -32,6 +43,26 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument Subprotocol = Any # type: ignore +try: + from opentelemetry.trace import ( # type: ignore[attr-defined] + Context, + Span, + Tracer, + get_tracer, + set_span_in_context, + ) +except ImportError: + Context = Any # type: ignore + Span = Any # type: ignore + Tracer = Any # type: ignore + + def get_tracer(*args, **kwargs) -> Tracer: # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + def set_span_in_context(*args, **kwargs): # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -58,17 +89,27 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, + tracer: Optional[Union[str, Tracer]] = None, + root_context: Optional[Context] = None, + root_span_name: Optional[str] = None, ) -> None: self.url = url self.headers = headers self.http_client = ( http_client if http_client else httpx.AsyncClient(headers=headers) ) + self.ws_url = ws_url self.ws_headers = ws_headers or {} self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload + self.tracer: Optional[Tracer] = ( + get_tracer(tracer) if isinstance(tracer, str) else tracer + ) + self.root_context = root_context + self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + async def __aenter__(self: Self) -> Self: return self @@ -83,17 +124,10 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - payload: Dict[str, Any] = {"query": query, "variables": processed_variables} + if self.tracer: + return await self._execute_with_telemetry(query=query, variables=variables) - if files and files_map: - return await self._execute_multipart( - payload=payload, - files=files, - files_map=files_map, - ) - - return await self._execute_json(payload=payload) + return await self._execute(query=query, variables=variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -142,6 +176,21 @@ async def execute_ws( if data: yield data + async def _execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json(query=query, variables=processed_variables) + def _process_variables( self, variables: Optional[Dict[str, Any]] ) -> Tuple[ @@ -213,21 +262,31 @@ def separate_files(path: str, obj: Any) -> Any: async def _execute_multipart( self, - payload: Dict[str, Any], + query: str, + variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], ) -> httpx.Response: data = { - "operations": json.dumps(payload, default=to_jsonable_python), + "operations": json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), "map": json.dumps(files_map, default=to_jsonable_python), } return await self.http_client.post(url=self.url, data=data, files=files) - async def _execute_json(self, payload: Dict[str, Any]) -> httpx.Response: - content = json.dumps(payload, default=to_jsonable_python) + async def _execute_json( + self, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: return await self.http_client.post( - url=self.url, content=content, headers={"Content-Type": "application/json"} + url=self.url, + content=json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + headers={"Content-Type": "application/json"}, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: @@ -287,3 +346,66 @@ async def _handle_ws_message( ) return None + + async def _execute_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + self.root_span_name, context=self.root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart_with_telemetry( + root_span=root_span, + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json_with_telemetry( + root_span=root_span, query=query, variables=processed_variables + ) + + async def _execute_multipart_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "multipart request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + serialized_map = json.dumps(files_map, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + span.set_attribute("map", serialized_map) + return await self._execute_multipart( + query=query, variables=variables, files=files, files_map=files_map + ) + + async def _execute_json_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "json request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + return await self._execute_json(query=query, variables=variables) diff --git a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py index def14347..6b52985c 100644 --- a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py @@ -1,6 +1,17 @@ import enum import json -from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from typing import ( + IO, + Any, + AsyncIterator, + Dict, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from uuid import uuid4 import httpx @@ -32,6 +43,26 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument Subprotocol = Any # type: ignore +try: + from opentelemetry.trace import ( # type: ignore[attr-defined] + Context, + Span, + Tracer, + get_tracer, + set_span_in_context, + ) +except ImportError: + Context = Any # type: ignore + Span = Any # type: ignore + Tracer = Any # type: ignore + + def get_tracer(*args, **kwargs) -> Tracer: # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + def set_span_in_context(*args, **kwargs): # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -58,17 +89,27 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, + tracer: Optional[Union[str, Tracer]] = None, + root_context: Optional[Context] = None, + root_span_name: Optional[str] = None, ) -> None: self.url = url self.headers = headers self.http_client = ( http_client if http_client else httpx.AsyncClient(headers=headers) ) + self.ws_url = ws_url self.ws_headers = ws_headers or {} self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload + self.tracer: Optional[Tracer] = ( + get_tracer(tracer) if isinstance(tracer, str) else tracer + ) + self.root_context = root_context + self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + async def __aenter__(self: Self) -> Self: return self @@ -83,17 +124,10 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - payload: Dict[str, Any] = {"query": query, "variables": processed_variables} + if self.tracer: + return await self._execute_with_telemetry(query=query, variables=variables) - if files and files_map: - return await self._execute_multipart( - payload=payload, - files=files, - files_map=files_map, - ) - - return await self._execute_json(payload=payload) + return await self._execute(query=query, variables=variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -142,6 +176,21 @@ async def execute_ws( if data: yield data + async def _execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json(query=query, variables=processed_variables) + def _process_variables( self, variables: Optional[Dict[str, Any]] ) -> Tuple[ @@ -213,21 +262,31 @@ def separate_files(path: str, obj: Any) -> Any: async def _execute_multipart( self, - payload: Dict[str, Any], + query: str, + variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], ) -> httpx.Response: data = { - "operations": json.dumps(payload, default=to_jsonable_python), + "operations": json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), "map": json.dumps(files_map, default=to_jsonable_python), } return await self.http_client.post(url=self.url, data=data, files=files) - async def _execute_json(self, payload: Dict[str, Any]) -> httpx.Response: - content = json.dumps(payload, default=to_jsonable_python) + async def _execute_json( + self, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: return await self.http_client.post( - url=self.url, content=content, headers={"Content-Type": "application/json"} + url=self.url, + content=json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + headers={"Content-Type": "application/json"}, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: @@ -287,3 +346,66 @@ async def _handle_ws_message( ) return None + + async def _execute_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + self.root_span_name, context=self.root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart_with_telemetry( + root_span=root_span, + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json_with_telemetry( + root_span=root_span, query=query, variables=processed_variables + ) + + async def _execute_multipart_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "multipart request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + serialized_map = json.dumps(files_map, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + span.set_attribute("map", serialized_map) + return await self._execute_multipart( + query=query, variables=variables, files=files, files_map=files_map + ) + + async def _execute_json_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "json request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + return await self._execute_json(query=query, variables=variables) diff --git a/tests/main/clients/remote_schema/expected_client/async_base_client.py b/tests/main/clients/remote_schema/expected_client/async_base_client.py index def14347..6b52985c 100644 --- a/tests/main/clients/remote_schema/expected_client/async_base_client.py +++ b/tests/main/clients/remote_schema/expected_client/async_base_client.py @@ -1,6 +1,17 @@ import enum import json -from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from typing import ( + IO, + Any, + AsyncIterator, + Dict, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from uuid import uuid4 import httpx @@ -32,6 +43,26 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument Subprotocol = Any # type: ignore +try: + from opentelemetry.trace import ( # type: ignore[attr-defined] + Context, + Span, + Tracer, + get_tracer, + set_span_in_context, + ) +except ImportError: + Context = Any # type: ignore + Span = Any # type: ignore + Tracer = Any # type: ignore + + def get_tracer(*args, **kwargs) -> Tracer: # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + def set_span_in_context(*args, **kwargs): # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -58,17 +89,27 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, + tracer: Optional[Union[str, Tracer]] = None, + root_context: Optional[Context] = None, + root_span_name: Optional[str] = None, ) -> None: self.url = url self.headers = headers self.http_client = ( http_client if http_client else httpx.AsyncClient(headers=headers) ) + self.ws_url = ws_url self.ws_headers = ws_headers or {} self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload + self.tracer: Optional[Tracer] = ( + get_tracer(tracer) if isinstance(tracer, str) else tracer + ) + self.root_context = root_context + self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + async def __aenter__(self: Self) -> Self: return self @@ -83,17 +124,10 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - payload: Dict[str, Any] = {"query": query, "variables": processed_variables} + if self.tracer: + return await self._execute_with_telemetry(query=query, variables=variables) - if files and files_map: - return await self._execute_multipart( - payload=payload, - files=files, - files_map=files_map, - ) - - return await self._execute_json(payload=payload) + return await self._execute(query=query, variables=variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -142,6 +176,21 @@ async def execute_ws( if data: yield data + async def _execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json(query=query, variables=processed_variables) + def _process_variables( self, variables: Optional[Dict[str, Any]] ) -> Tuple[ @@ -213,21 +262,31 @@ def separate_files(path: str, obj: Any) -> Any: async def _execute_multipart( self, - payload: Dict[str, Any], + query: str, + variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], ) -> httpx.Response: data = { - "operations": json.dumps(payload, default=to_jsonable_python), + "operations": json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), "map": json.dumps(files_map, default=to_jsonable_python), } return await self.http_client.post(url=self.url, data=data, files=files) - async def _execute_json(self, payload: Dict[str, Any]) -> httpx.Response: - content = json.dumps(payload, default=to_jsonable_python) + async def _execute_json( + self, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: return await self.http_client.post( - url=self.url, content=content, headers={"Content-Type": "application/json"} + url=self.url, + content=json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + headers={"Content-Type": "application/json"}, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: @@ -287,3 +346,66 @@ async def _handle_ws_message( ) return None + + async def _execute_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + self.root_span_name, context=self.root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart_with_telemetry( + root_span=root_span, + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json_with_telemetry( + root_span=root_span, query=query, variables=processed_variables + ) + + async def _execute_multipart_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "multipart request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + serialized_map = json.dumps(files_map, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + span.set_attribute("map", serialized_map) + return await self._execute_multipart( + query=query, variables=variables, files=files, files_map=files_map + ) + + async def _execute_json_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "json request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + return await self._execute_json(query=query, variables=variables) diff --git a/tests/main/clients/shorter_results/expected_client/async_base_client.py b/tests/main/clients/shorter_results/expected_client/async_base_client.py index def14347..6b52985c 100644 --- a/tests/main/clients/shorter_results/expected_client/async_base_client.py +++ b/tests/main/clients/shorter_results/expected_client/async_base_client.py @@ -1,6 +1,17 @@ import enum import json -from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from typing import ( + IO, + Any, + AsyncIterator, + Dict, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from uuid import uuid4 import httpx @@ -32,6 +43,26 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument Subprotocol = Any # type: ignore +try: + from opentelemetry.trace import ( # type: ignore[attr-defined] + Context, + Span, + Tracer, + get_tracer, + set_span_in_context, + ) +except ImportError: + Context = Any # type: ignore + Span = Any # type: ignore + Tracer = Any # type: ignore + + def get_tracer(*args, **kwargs) -> Tracer: # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + def set_span_in_context(*args, **kwargs): # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -58,17 +89,27 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, + tracer: Optional[Union[str, Tracer]] = None, + root_context: Optional[Context] = None, + root_span_name: Optional[str] = None, ) -> None: self.url = url self.headers = headers self.http_client = ( http_client if http_client else httpx.AsyncClient(headers=headers) ) + self.ws_url = ws_url self.ws_headers = ws_headers or {} self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload + self.tracer: Optional[Tracer] = ( + get_tracer(tracer) if isinstance(tracer, str) else tracer + ) + self.root_context = root_context + self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + async def __aenter__(self: Self) -> Self: return self @@ -83,17 +124,10 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - payload: Dict[str, Any] = {"query": query, "variables": processed_variables} + if self.tracer: + return await self._execute_with_telemetry(query=query, variables=variables) - if files and files_map: - return await self._execute_multipart( - payload=payload, - files=files, - files_map=files_map, - ) - - return await self._execute_json(payload=payload) + return await self._execute(query=query, variables=variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -142,6 +176,21 @@ async def execute_ws( if data: yield data + async def _execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json(query=query, variables=processed_variables) + def _process_variables( self, variables: Optional[Dict[str, Any]] ) -> Tuple[ @@ -213,21 +262,31 @@ def separate_files(path: str, obj: Any) -> Any: async def _execute_multipart( self, - payload: Dict[str, Any], + query: str, + variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], ) -> httpx.Response: data = { - "operations": json.dumps(payload, default=to_jsonable_python), + "operations": json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), "map": json.dumps(files_map, default=to_jsonable_python), } return await self.http_client.post(url=self.url, data=data, files=files) - async def _execute_json(self, payload: Dict[str, Any]) -> httpx.Response: - content = json.dumps(payload, default=to_jsonable_python) + async def _execute_json( + self, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: return await self.http_client.post( - url=self.url, content=content, headers={"Content-Type": "application/json"} + url=self.url, + content=json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + headers={"Content-Type": "application/json"}, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: @@ -287,3 +346,66 @@ async def _handle_ws_message( ) return None + + async def _execute_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + self.root_span_name, context=self.root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart_with_telemetry( + root_span=root_span, + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json_with_telemetry( + root_span=root_span, query=query, variables=processed_variables + ) + + async def _execute_multipart_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "multipart request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + serialized_map = json.dumps(files_map, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + span.set_attribute("map", serialized_map) + return await self._execute_multipart( + query=query, variables=variables, files=files, files_map=files_map + ) + + async def _execute_json_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "json request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + return await self._execute_json(query=query, variables=variables) From 2c3558337ae61f4352ff804793d49ebed3fc392e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Fri, 6 Oct 2023 12:42:00 +0200 Subject: [PATCH 4/9] Add telemetry to subscriptions --- .../dependencies/async_base_client.py | 178 ++++++++++++++++-- .../dependencies/test_websockets.py | 129 +++++++++++++ .../expected_client/async_base_client.py | 178 ++++++++++++++++-- .../expected_client/async_base_client.py | 178 ++++++++++++++++-- .../expected_client/async_base_client.py | 178 ++++++++++++++++-- .../expected_client/async_base_client.py | 178 ++++++++++++++++-- .../expected_client/async_base_client.py | 178 ++++++++++++++++-- .../expected_client/async_base_client.py | 178 ++++++++++++++++-- .../expected_client/async_base_client.py | 178 ++++++++++++++++-- .../expected_client/async_base_client.py | 178 ++++++++++++++++-- .../expected_client/async_base_client.py | 178 ++++++++++++++++-- .../expected_client/async_base_client.py | 178 ++++++++++++++++-- 12 files changed, 1867 insertions(+), 220 deletions(-) diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client.py b/ariadne_codegen/client_generators/dependencies/async_base_client.py index 6b52985c..49efb1a2 100644 --- a/ariadne_codegen/client_generators/dependencies/async_base_client.py +++ b/ariadne_codegen/client_generators/dependencies/async_base_client.py @@ -40,7 +40,9 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument WebSocketClientProtocol = Any # type: ignore Data = Any # type: ignore Origin = Any # type: ignore - Subprotocol = Any # type: ignore + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") try: @@ -91,7 +93,9 @@ def __init__( ws_connection_init_payload: Optional[Dict[str, Any]] = None, tracer: Optional[Union[str, Tracer]] = None, root_context: Optional[Context] = None, - root_span_name: Optional[str] = None, + root_span_name: str = "GraphQL Operation", + ws_root_context: Optional[Context] = None, + ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,7 +112,9 @@ def __init__( get_tracer(tracer) if isinstance(tracer, str) else tracer ) self.root_context = root_context - self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + self.root_span_name = root_span_name + self.ws_root_context = ws_root_context + self.ws_root_span_name = ws_root_span_name async def __aenter__(self: Self) -> Self: return self @@ -156,25 +162,15 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, + if self.tracer: + generator = self._execute_ws_with_telemetry( + query=query, variables=variables ) + else: + generator = self._execute_ws(query=query, variables=variables) - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data + async for message in generator: + yield message async def _execute( self, query: str, variables: Optional[Dict[str, Any]] = None @@ -289,6 +285,29 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) + async def _execute_ws( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -409,3 +428,122 @@ async def _execute_json_with_telemetry( span.set_attribute("query", query) span.set_attribute("variables", serialized_variables) return await self._execute_json(query=query, variables=variables) + + async def _execute_ws_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + self.ws_root_span_name, context=self.ws_root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init_with_telemetry( + root_span=root_span, + websocket=websocket, + ) + await self._send_subscribe_with_telemetry( + root_span=root_span, + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message_with_telemetry( + root_span=root_span, message=message, websocket=websocket + ) + if data: + yield data + + async def _send_connection_init_with_telemetry( + self, root_span: Span, websocket: WebSocketClientProtocol + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "connection init", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute( + "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value + ) + if self.ws_connection_init_payload: + span.set_attribute( + "payload", json.dumps(self.ws_connection_init_payload) + ) + + await self._send_connection_init(websocket=websocket) + + async def _send_subscribe_with_telemetry( + self, + root_span: Span, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "subscribe", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute("id", operation_id) + span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) + span.set_attribute("query", query) + if variables: + span.set_attribute( + "variables", + json.dumps(self._convert_dict_to_json_serializable(variables)), + ) + + await self._send_subscribe( + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async def _handle_ws_message_with_telemetry( + self, root_span: Span, message: Data, websocket: WebSocketClientProtocol + ) -> Optional[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + "received message", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + span.set_attribute("type", type_) + + if not type_ or type_ not in { + t.value for t in GraphQLTransportWSMessageType + }: + raise GraphQLClientInvalidMessageFormat(message=message) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/client_generators/dependencies/test_websockets.py b/tests/client_generators/dependencies/test_websockets.py index 995e2a30..d6e9dd10 100644 --- a/tests/client_generators/dependencies/test_websockets.py +++ b/tests/client_generators/dependencies/test_websockets.py @@ -1,4 +1,5 @@ import json +from unittest.mock import ANY import pytest @@ -213,3 +214,131 @@ async def test_execute_ws_raises_graphql_multi_error_for_message_with_error_type with pytest.raises(GraphQLClientGraphQLMultiError): async for _ in AsyncBaseClient().execute_ws(""): pass + + +@pytest.fixture +def mocked_start_as_current_span(mocker): + mocker_get_tracer = mocker.patch( + "ariadne_codegen.client_generators.dependencies.async_base_client.get_tracer" + ) + return mocker_get_tracer.return_value.start_as_current_span + + +@pytest.mark.asyncio +async def test_execute_ws_creates_root_span( + mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument +): + client = AsyncBaseClient(ws_url="ws://test_url", tracer="tracker") + + async for _ in client.execute_ws(""): + pass + + mocked_start_as_current_span.assert_any_call("GraphQL Subscription", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + + +@pytest.mark.asyncio +async def test_execute_ws_creates_root_span_with_custom_name( + mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument +): + client = AsyncBaseClient( + ws_url="ws://test_url", tracer="tracker", ws_root_span_name="ws root span" + ) + + async for _ in client.execute_ws(""): + pass + + mocked_start_as_current_span.assert_any_call("ws root span", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + + +@pytest.mark.asyncio +async def test_execute_ws_creates_root_span_with_custom_context( + mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument +): + client = AsyncBaseClient( + ws_url="ws://test_url", tracer="tracker", ws_root_context={"value": "test"} + ) + + async for _ in client.execute_ws(""): + pass + + mocked_start_as_current_span.assert_any_call( + "GraphQL Subscription", context={"value": "test"} + ) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + + +@pytest.mark.asyncio +async def test_execute_ws_creates_span_for_init_message( + mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument): +): + client = AsyncBaseClient(ws_url="ws://test_url", tracer="tracker") + + async for _ in client.execute_ws(""): + pass + + mocked_start_as_current_span.assert_any_call("connection init", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("type", "connection_init") + + +@pytest.mark.asyncio +async def test_execute_ws_creates_span_for_subscribe_message( + mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument): +): + client = AsyncBaseClient(ws_url="ws://test_url", tracer="tracker") + + async for _ in client.execute_ws( + "subscription Abc(a: String, b: InputB) { value }", + variables={"a": "AAA", "b": {"valueB": 21}}, + ): + pass + + mocked_start_as_current_span.assert_any_call("connection init", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("type", "connection_init") + + +@pytest.mark.parametrize( + "received_message", + [ + {"type": "next", "payload": {"data": "test_data"}}, + {"type": "complete"}, + {"type": "ping"}, + ], +) +@pytest.mark.asyncio +async def test_execute_ws_creates_span_for_received_message( + received_message, mocked_websocket, mocked_start_as_current_span +): + mocked_websocket.__aiter__.return_value.append(json.dumps(received_message)) + client = AsyncBaseClient(ws_url="ws://test_url", tracer="tracker") + + async for _ in client.execute_ws(""): + pass + + mocked_start_as_current_span.assert_any_call("received message", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("type", received_message["type"]) + + +@pytest.mark.asyncio +async def test_execute_ws_creates_span_for_received_error_message( + mocked_websocket, mocked_start_as_current_span +): + mocked_websocket.__aiter__.return_value.append( + json.dumps({"type": "error", "payload": [{"message": "error_message"}]}) + ) + client = AsyncBaseClient(ws_url="ws://test_url", tracer="tracker") + + with pytest.raises(GraphQLClientGraphQLMultiError): + async for _ in client.execute_ws(""): + pass + + mocked_start_as_current_span.assert_any_call("received message", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("type", "error") diff --git a/tests/main/clients/custom_config_file/expected_client/async_base_client.py b/tests/main/clients/custom_config_file/expected_client/async_base_client.py index 6b52985c..49efb1a2 100644 --- a/tests/main/clients/custom_config_file/expected_client/async_base_client.py +++ b/tests/main/clients/custom_config_file/expected_client/async_base_client.py @@ -40,7 +40,9 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument WebSocketClientProtocol = Any # type: ignore Data = Any # type: ignore Origin = Any # type: ignore - Subprotocol = Any # type: ignore + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") try: @@ -91,7 +93,9 @@ def __init__( ws_connection_init_payload: Optional[Dict[str, Any]] = None, tracer: Optional[Union[str, Tracer]] = None, root_context: Optional[Context] = None, - root_span_name: Optional[str] = None, + root_span_name: str = "GraphQL Operation", + ws_root_context: Optional[Context] = None, + ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,7 +112,9 @@ def __init__( get_tracer(tracer) if isinstance(tracer, str) else tracer ) self.root_context = root_context - self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + self.root_span_name = root_span_name + self.ws_root_context = ws_root_context + self.ws_root_span_name = ws_root_span_name async def __aenter__(self: Self) -> Self: return self @@ -156,25 +162,15 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, + if self.tracer: + generator = self._execute_ws_with_telemetry( + query=query, variables=variables ) + else: + generator = self._execute_ws(query=query, variables=variables) - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data + async for message in generator: + yield message async def _execute( self, query: str, variables: Optional[Dict[str, Any]] = None @@ -289,6 +285,29 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) + async def _execute_ws( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -409,3 +428,122 @@ async def _execute_json_with_telemetry( span.set_attribute("query", query) span.set_attribute("variables", serialized_variables) return await self._execute_json(query=query, variables=variables) + + async def _execute_ws_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + self.ws_root_span_name, context=self.ws_root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init_with_telemetry( + root_span=root_span, + websocket=websocket, + ) + await self._send_subscribe_with_telemetry( + root_span=root_span, + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message_with_telemetry( + root_span=root_span, message=message, websocket=websocket + ) + if data: + yield data + + async def _send_connection_init_with_telemetry( + self, root_span: Span, websocket: WebSocketClientProtocol + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "connection init", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute( + "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value + ) + if self.ws_connection_init_payload: + span.set_attribute( + "payload", json.dumps(self.ws_connection_init_payload) + ) + + await self._send_connection_init(websocket=websocket) + + async def _send_subscribe_with_telemetry( + self, + root_span: Span, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "subscribe", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute("id", operation_id) + span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) + span.set_attribute("query", query) + if variables: + span.set_attribute( + "variables", + json.dumps(self._convert_dict_to_json_serializable(variables)), + ) + + await self._send_subscribe( + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async def _handle_ws_message_with_telemetry( + self, root_span: Span, message: Data, websocket: WebSocketClientProtocol + ) -> Optional[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + "received message", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + span.set_attribute("type", type_) + + if not type_ or type_ not in { + t.value for t in GraphQLTransportWSMessageType + }: + raise GraphQLClientInvalidMessageFormat(message=message) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/custom_files_names/expected_client/async_base_client.py b/tests/main/clients/custom_files_names/expected_client/async_base_client.py index 6b52985c..49efb1a2 100644 --- a/tests/main/clients/custom_files_names/expected_client/async_base_client.py +++ b/tests/main/clients/custom_files_names/expected_client/async_base_client.py @@ -40,7 +40,9 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument WebSocketClientProtocol = Any # type: ignore Data = Any # type: ignore Origin = Any # type: ignore - Subprotocol = Any # type: ignore + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") try: @@ -91,7 +93,9 @@ def __init__( ws_connection_init_payload: Optional[Dict[str, Any]] = None, tracer: Optional[Union[str, Tracer]] = None, root_context: Optional[Context] = None, - root_span_name: Optional[str] = None, + root_span_name: str = "GraphQL Operation", + ws_root_context: Optional[Context] = None, + ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,7 +112,9 @@ def __init__( get_tracer(tracer) if isinstance(tracer, str) else tracer ) self.root_context = root_context - self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + self.root_span_name = root_span_name + self.ws_root_context = ws_root_context + self.ws_root_span_name = ws_root_span_name async def __aenter__(self: Self) -> Self: return self @@ -156,25 +162,15 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, + if self.tracer: + generator = self._execute_ws_with_telemetry( + query=query, variables=variables ) + else: + generator = self._execute_ws(query=query, variables=variables) - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data + async for message in generator: + yield message async def _execute( self, query: str, variables: Optional[Dict[str, Any]] = None @@ -289,6 +285,29 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) + async def _execute_ws( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -409,3 +428,122 @@ async def _execute_json_with_telemetry( span.set_attribute("query", query) span.set_attribute("variables", serialized_variables) return await self._execute_json(query=query, variables=variables) + + async def _execute_ws_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + self.ws_root_span_name, context=self.ws_root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init_with_telemetry( + root_span=root_span, + websocket=websocket, + ) + await self._send_subscribe_with_telemetry( + root_span=root_span, + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message_with_telemetry( + root_span=root_span, message=message, websocket=websocket + ) + if data: + yield data + + async def _send_connection_init_with_telemetry( + self, root_span: Span, websocket: WebSocketClientProtocol + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "connection init", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute( + "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value + ) + if self.ws_connection_init_payload: + span.set_attribute( + "payload", json.dumps(self.ws_connection_init_payload) + ) + + await self._send_connection_init(websocket=websocket) + + async def _send_subscribe_with_telemetry( + self, + root_span: Span, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "subscribe", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute("id", operation_id) + span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) + span.set_attribute("query", query) + if variables: + span.set_attribute( + "variables", + json.dumps(self._convert_dict_to_json_serializable(variables)), + ) + + await self._send_subscribe( + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async def _handle_ws_message_with_telemetry( + self, root_span: Span, message: Data, websocket: WebSocketClientProtocol + ) -> Optional[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + "received message", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + span.set_attribute("type", type_) + + if not type_ or type_ not in { + t.value for t in GraphQLTransportWSMessageType + }: + raise GraphQLClientInvalidMessageFormat(message=message) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/custom_scalars/expected_client/async_base_client.py b/tests/main/clients/custom_scalars/expected_client/async_base_client.py index 6b52985c..49efb1a2 100644 --- a/tests/main/clients/custom_scalars/expected_client/async_base_client.py +++ b/tests/main/clients/custom_scalars/expected_client/async_base_client.py @@ -40,7 +40,9 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument WebSocketClientProtocol = Any # type: ignore Data = Any # type: ignore Origin = Any # type: ignore - Subprotocol = Any # type: ignore + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") try: @@ -91,7 +93,9 @@ def __init__( ws_connection_init_payload: Optional[Dict[str, Any]] = None, tracer: Optional[Union[str, Tracer]] = None, root_context: Optional[Context] = None, - root_span_name: Optional[str] = None, + root_span_name: str = "GraphQL Operation", + ws_root_context: Optional[Context] = None, + ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,7 +112,9 @@ def __init__( get_tracer(tracer) if isinstance(tracer, str) else tracer ) self.root_context = root_context - self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + self.root_span_name = root_span_name + self.ws_root_context = ws_root_context + self.ws_root_span_name = ws_root_span_name async def __aenter__(self: Self) -> Self: return self @@ -156,25 +162,15 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, + if self.tracer: + generator = self._execute_ws_with_telemetry( + query=query, variables=variables ) + else: + generator = self._execute_ws(query=query, variables=variables) - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data + async for message in generator: + yield message async def _execute( self, query: str, variables: Optional[Dict[str, Any]] = None @@ -289,6 +285,29 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) + async def _execute_ws( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -409,3 +428,122 @@ async def _execute_json_with_telemetry( span.set_attribute("query", query) span.set_attribute("variables", serialized_variables) return await self._execute_json(query=query, variables=variables) + + async def _execute_ws_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + self.ws_root_span_name, context=self.ws_root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init_with_telemetry( + root_span=root_span, + websocket=websocket, + ) + await self._send_subscribe_with_telemetry( + root_span=root_span, + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message_with_telemetry( + root_span=root_span, message=message, websocket=websocket + ) + if data: + yield data + + async def _send_connection_init_with_telemetry( + self, root_span: Span, websocket: WebSocketClientProtocol + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "connection init", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute( + "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value + ) + if self.ws_connection_init_payload: + span.set_attribute( + "payload", json.dumps(self.ws_connection_init_payload) + ) + + await self._send_connection_init(websocket=websocket) + + async def _send_subscribe_with_telemetry( + self, + root_span: Span, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "subscribe", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute("id", operation_id) + span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) + span.set_attribute("query", query) + if variables: + span.set_attribute( + "variables", + json.dumps(self._convert_dict_to_json_serializable(variables)), + ) + + await self._send_subscribe( + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async def _handle_ws_message_with_telemetry( + self, root_span: Span, message: Data, websocket: WebSocketClientProtocol + ) -> Optional[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + "received message", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + span.set_attribute("type", type_) + + if not type_ or type_ not in { + t.value for t in GraphQLTransportWSMessageType + }: + raise GraphQLClientInvalidMessageFormat(message=message) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/example/expected_client/async_base_client.py b/tests/main/clients/example/expected_client/async_base_client.py index 6b52985c..49efb1a2 100644 --- a/tests/main/clients/example/expected_client/async_base_client.py +++ b/tests/main/clients/example/expected_client/async_base_client.py @@ -40,7 +40,9 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument WebSocketClientProtocol = Any # type: ignore Data = Any # type: ignore Origin = Any # type: ignore - Subprotocol = Any # type: ignore + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") try: @@ -91,7 +93,9 @@ def __init__( ws_connection_init_payload: Optional[Dict[str, Any]] = None, tracer: Optional[Union[str, Tracer]] = None, root_context: Optional[Context] = None, - root_span_name: Optional[str] = None, + root_span_name: str = "GraphQL Operation", + ws_root_context: Optional[Context] = None, + ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,7 +112,9 @@ def __init__( get_tracer(tracer) if isinstance(tracer, str) else tracer ) self.root_context = root_context - self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + self.root_span_name = root_span_name + self.ws_root_context = ws_root_context + self.ws_root_span_name = ws_root_span_name async def __aenter__(self: Self) -> Self: return self @@ -156,25 +162,15 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, + if self.tracer: + generator = self._execute_ws_with_telemetry( + query=query, variables=variables ) + else: + generator = self._execute_ws(query=query, variables=variables) - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data + async for message in generator: + yield message async def _execute( self, query: str, variables: Optional[Dict[str, Any]] = None @@ -289,6 +285,29 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) + async def _execute_ws( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -409,3 +428,122 @@ async def _execute_json_with_telemetry( span.set_attribute("query", query) span.set_attribute("variables", serialized_variables) return await self._execute_json(query=query, variables=variables) + + async def _execute_ws_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + self.ws_root_span_name, context=self.ws_root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init_with_telemetry( + root_span=root_span, + websocket=websocket, + ) + await self._send_subscribe_with_telemetry( + root_span=root_span, + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message_with_telemetry( + root_span=root_span, message=message, websocket=websocket + ) + if data: + yield data + + async def _send_connection_init_with_telemetry( + self, root_span: Span, websocket: WebSocketClientProtocol + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "connection init", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute( + "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value + ) + if self.ws_connection_init_payload: + span.set_attribute( + "payload", json.dumps(self.ws_connection_init_payload) + ) + + await self._send_connection_init(websocket=websocket) + + async def _send_subscribe_with_telemetry( + self, + root_span: Span, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "subscribe", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute("id", operation_id) + span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) + span.set_attribute("query", query) + if variables: + span.set_attribute( + "variables", + json.dumps(self._convert_dict_to_json_serializable(variables)), + ) + + await self._send_subscribe( + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async def _handle_ws_message_with_telemetry( + self, root_span: Span, message: Data, websocket: WebSocketClientProtocol + ) -> Optional[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + "received message", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + span.set_attribute("type", type_) + + if not type_ or type_ not in { + t.value for t in GraphQLTransportWSMessageType + }: + raise GraphQLClientInvalidMessageFormat(message=message) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/extended_models/expected_client/async_base_client.py b/tests/main/clients/extended_models/expected_client/async_base_client.py index 6b52985c..49efb1a2 100644 --- a/tests/main/clients/extended_models/expected_client/async_base_client.py +++ b/tests/main/clients/extended_models/expected_client/async_base_client.py @@ -40,7 +40,9 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument WebSocketClientProtocol = Any # type: ignore Data = Any # type: ignore Origin = Any # type: ignore - Subprotocol = Any # type: ignore + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") try: @@ -91,7 +93,9 @@ def __init__( ws_connection_init_payload: Optional[Dict[str, Any]] = None, tracer: Optional[Union[str, Tracer]] = None, root_context: Optional[Context] = None, - root_span_name: Optional[str] = None, + root_span_name: str = "GraphQL Operation", + ws_root_context: Optional[Context] = None, + ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,7 +112,9 @@ def __init__( get_tracer(tracer) if isinstance(tracer, str) else tracer ) self.root_context = root_context - self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + self.root_span_name = root_span_name + self.ws_root_context = ws_root_context + self.ws_root_span_name = ws_root_span_name async def __aenter__(self: Self) -> Self: return self @@ -156,25 +162,15 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, + if self.tracer: + generator = self._execute_ws_with_telemetry( + query=query, variables=variables ) + else: + generator = self._execute_ws(query=query, variables=variables) - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data + async for message in generator: + yield message async def _execute( self, query: str, variables: Optional[Dict[str, Any]] = None @@ -289,6 +285,29 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) + async def _execute_ws( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -409,3 +428,122 @@ async def _execute_json_with_telemetry( span.set_attribute("query", query) span.set_attribute("variables", serialized_variables) return await self._execute_json(query=query, variables=variables) + + async def _execute_ws_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + self.ws_root_span_name, context=self.ws_root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init_with_telemetry( + root_span=root_span, + websocket=websocket, + ) + await self._send_subscribe_with_telemetry( + root_span=root_span, + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message_with_telemetry( + root_span=root_span, message=message, websocket=websocket + ) + if data: + yield data + + async def _send_connection_init_with_telemetry( + self, root_span: Span, websocket: WebSocketClientProtocol + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "connection init", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute( + "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value + ) + if self.ws_connection_init_payload: + span.set_attribute( + "payload", json.dumps(self.ws_connection_init_payload) + ) + + await self._send_connection_init(websocket=websocket) + + async def _send_subscribe_with_telemetry( + self, + root_span: Span, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "subscribe", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute("id", operation_id) + span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) + span.set_attribute("query", query) + if variables: + span.set_attribute( + "variables", + json.dumps(self._convert_dict_to_json_serializable(variables)), + ) + + await self._send_subscribe( + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async def _handle_ws_message_with_telemetry( + self, root_span: Span, message: Data, websocket: WebSocketClientProtocol + ) -> Optional[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + "received message", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + span.set_attribute("type", type_) + + if not type_ or type_ not in { + t.value for t in GraphQLTransportWSMessageType + }: + raise GraphQLClientInvalidMessageFormat(message=message) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py index 6b52985c..49efb1a2 100644 --- a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py +++ b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py @@ -40,7 +40,9 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument WebSocketClientProtocol = Any # type: ignore Data = Any # type: ignore Origin = Any # type: ignore - Subprotocol = Any # type: ignore + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") try: @@ -91,7 +93,9 @@ def __init__( ws_connection_init_payload: Optional[Dict[str, Any]] = None, tracer: Optional[Union[str, Tracer]] = None, root_context: Optional[Context] = None, - root_span_name: Optional[str] = None, + root_span_name: str = "GraphQL Operation", + ws_root_context: Optional[Context] = None, + ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,7 +112,9 @@ def __init__( get_tracer(tracer) if isinstance(tracer, str) else tracer ) self.root_context = root_context - self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + self.root_span_name = root_span_name + self.ws_root_context = ws_root_context + self.ws_root_span_name = ws_root_span_name async def __aenter__(self: Self) -> Self: return self @@ -156,25 +162,15 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, + if self.tracer: + generator = self._execute_ws_with_telemetry( + query=query, variables=variables ) + else: + generator = self._execute_ws(query=query, variables=variables) - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data + async for message in generator: + yield message async def _execute( self, query: str, variables: Optional[Dict[str, Any]] = None @@ -289,6 +285,29 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) + async def _execute_ws( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -409,3 +428,122 @@ async def _execute_json_with_telemetry( span.set_attribute("query", query) span.set_attribute("variables", serialized_variables) return await self._execute_json(query=query, variables=variables) + + async def _execute_ws_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + self.ws_root_span_name, context=self.ws_root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init_with_telemetry( + root_span=root_span, + websocket=websocket, + ) + await self._send_subscribe_with_telemetry( + root_span=root_span, + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message_with_telemetry( + root_span=root_span, message=message, websocket=websocket + ) + if data: + yield data + + async def _send_connection_init_with_telemetry( + self, root_span: Span, websocket: WebSocketClientProtocol + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "connection init", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute( + "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value + ) + if self.ws_connection_init_payload: + span.set_attribute( + "payload", json.dumps(self.ws_connection_init_payload) + ) + + await self._send_connection_init(websocket=websocket) + + async def _send_subscribe_with_telemetry( + self, + root_span: Span, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "subscribe", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute("id", operation_id) + span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) + span.set_attribute("query", query) + if variables: + span.set_attribute( + "variables", + json.dumps(self._convert_dict_to_json_serializable(variables)), + ) + + await self._send_subscribe( + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async def _handle_ws_message_with_telemetry( + self, root_span: Span, message: Data, websocket: WebSocketClientProtocol + ) -> Optional[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + "received message", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + span.set_attribute("type", type_) + + if not type_ or type_ not in { + t.value for t in GraphQLTransportWSMessageType + }: + raise GraphQLClientInvalidMessageFormat(message=message) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/inline_fragments/expected_client/async_base_client.py b/tests/main/clients/inline_fragments/expected_client/async_base_client.py index 6b52985c..49efb1a2 100644 --- a/tests/main/clients/inline_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/inline_fragments/expected_client/async_base_client.py @@ -40,7 +40,9 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument WebSocketClientProtocol = Any # type: ignore Data = Any # type: ignore Origin = Any # type: ignore - Subprotocol = Any # type: ignore + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") try: @@ -91,7 +93,9 @@ def __init__( ws_connection_init_payload: Optional[Dict[str, Any]] = None, tracer: Optional[Union[str, Tracer]] = None, root_context: Optional[Context] = None, - root_span_name: Optional[str] = None, + root_span_name: str = "GraphQL Operation", + ws_root_context: Optional[Context] = None, + ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,7 +112,9 @@ def __init__( get_tracer(tracer) if isinstance(tracer, str) else tracer ) self.root_context = root_context - self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + self.root_span_name = root_span_name + self.ws_root_context = ws_root_context + self.ws_root_span_name = ws_root_span_name async def __aenter__(self: Self) -> Self: return self @@ -156,25 +162,15 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, + if self.tracer: + generator = self._execute_ws_with_telemetry( + query=query, variables=variables ) + else: + generator = self._execute_ws(query=query, variables=variables) - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data + async for message in generator: + yield message async def _execute( self, query: str, variables: Optional[Dict[str, Any]] = None @@ -289,6 +285,29 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) + async def _execute_ws( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -409,3 +428,122 @@ async def _execute_json_with_telemetry( span.set_attribute("query", query) span.set_attribute("variables", serialized_variables) return await self._execute_json(query=query, variables=variables) + + async def _execute_ws_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + self.ws_root_span_name, context=self.ws_root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init_with_telemetry( + root_span=root_span, + websocket=websocket, + ) + await self._send_subscribe_with_telemetry( + root_span=root_span, + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message_with_telemetry( + root_span=root_span, message=message, websocket=websocket + ) + if data: + yield data + + async def _send_connection_init_with_telemetry( + self, root_span: Span, websocket: WebSocketClientProtocol + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "connection init", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute( + "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value + ) + if self.ws_connection_init_payload: + span.set_attribute( + "payload", json.dumps(self.ws_connection_init_payload) + ) + + await self._send_connection_init(websocket=websocket) + + async def _send_subscribe_with_telemetry( + self, + root_span: Span, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "subscribe", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute("id", operation_id) + span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) + span.set_attribute("query", query) + if variables: + span.set_attribute( + "variables", + json.dumps(self._convert_dict_to_json_serializable(variables)), + ) + + await self._send_subscribe( + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async def _handle_ws_message_with_telemetry( + self, root_span: Span, message: Data, websocket: WebSocketClientProtocol + ) -> Optional[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + "received message", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + span.set_attribute("type", type_) + + if not type_ or type_ not in { + t.value for t in GraphQLTransportWSMessageType + }: + raise GraphQLClientInvalidMessageFormat(message=message) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py index 6b52985c..49efb1a2 100644 --- a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py @@ -40,7 +40,9 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument WebSocketClientProtocol = Any # type: ignore Data = Any # type: ignore Origin = Any # type: ignore - Subprotocol = Any # type: ignore + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") try: @@ -91,7 +93,9 @@ def __init__( ws_connection_init_payload: Optional[Dict[str, Any]] = None, tracer: Optional[Union[str, Tracer]] = None, root_context: Optional[Context] = None, - root_span_name: Optional[str] = None, + root_span_name: str = "GraphQL Operation", + ws_root_context: Optional[Context] = None, + ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,7 +112,9 @@ def __init__( get_tracer(tracer) if isinstance(tracer, str) else tracer ) self.root_context = root_context - self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + self.root_span_name = root_span_name + self.ws_root_context = ws_root_context + self.ws_root_span_name = ws_root_span_name async def __aenter__(self: Self) -> Self: return self @@ -156,25 +162,15 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, + if self.tracer: + generator = self._execute_ws_with_telemetry( + query=query, variables=variables ) + else: + generator = self._execute_ws(query=query, variables=variables) - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data + async for message in generator: + yield message async def _execute( self, query: str, variables: Optional[Dict[str, Any]] = None @@ -289,6 +285,29 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) + async def _execute_ws( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -409,3 +428,122 @@ async def _execute_json_with_telemetry( span.set_attribute("query", query) span.set_attribute("variables", serialized_variables) return await self._execute_json(query=query, variables=variables) + + async def _execute_ws_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + self.ws_root_span_name, context=self.ws_root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init_with_telemetry( + root_span=root_span, + websocket=websocket, + ) + await self._send_subscribe_with_telemetry( + root_span=root_span, + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message_with_telemetry( + root_span=root_span, message=message, websocket=websocket + ) + if data: + yield data + + async def _send_connection_init_with_telemetry( + self, root_span: Span, websocket: WebSocketClientProtocol + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "connection init", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute( + "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value + ) + if self.ws_connection_init_payload: + span.set_attribute( + "payload", json.dumps(self.ws_connection_init_payload) + ) + + await self._send_connection_init(websocket=websocket) + + async def _send_subscribe_with_telemetry( + self, + root_span: Span, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "subscribe", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute("id", operation_id) + span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) + span.set_attribute("query", query) + if variables: + span.set_attribute( + "variables", + json.dumps(self._convert_dict_to_json_serializable(variables)), + ) + + await self._send_subscribe( + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async def _handle_ws_message_with_telemetry( + self, root_span: Span, message: Data, websocket: WebSocketClientProtocol + ) -> Optional[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + "received message", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + span.set_attribute("type", type_) + + if not type_ or type_ not in { + t.value for t in GraphQLTransportWSMessageType + }: + raise GraphQLClientInvalidMessageFormat(message=message) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/remote_schema/expected_client/async_base_client.py b/tests/main/clients/remote_schema/expected_client/async_base_client.py index 6b52985c..49efb1a2 100644 --- a/tests/main/clients/remote_schema/expected_client/async_base_client.py +++ b/tests/main/clients/remote_schema/expected_client/async_base_client.py @@ -40,7 +40,9 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument WebSocketClientProtocol = Any # type: ignore Data = Any # type: ignore Origin = Any # type: ignore - Subprotocol = Any # type: ignore + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") try: @@ -91,7 +93,9 @@ def __init__( ws_connection_init_payload: Optional[Dict[str, Any]] = None, tracer: Optional[Union[str, Tracer]] = None, root_context: Optional[Context] = None, - root_span_name: Optional[str] = None, + root_span_name: str = "GraphQL Operation", + ws_root_context: Optional[Context] = None, + ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,7 +112,9 @@ def __init__( get_tracer(tracer) if isinstance(tracer, str) else tracer ) self.root_context = root_context - self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + self.root_span_name = root_span_name + self.ws_root_context = ws_root_context + self.ws_root_span_name = ws_root_span_name async def __aenter__(self: Self) -> Self: return self @@ -156,25 +162,15 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, + if self.tracer: + generator = self._execute_ws_with_telemetry( + query=query, variables=variables ) + else: + generator = self._execute_ws(query=query, variables=variables) - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data + async for message in generator: + yield message async def _execute( self, query: str, variables: Optional[Dict[str, Any]] = None @@ -289,6 +285,29 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) + async def _execute_ws( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -409,3 +428,122 @@ async def _execute_json_with_telemetry( span.set_attribute("query", query) span.set_attribute("variables", serialized_variables) return await self._execute_json(query=query, variables=variables) + + async def _execute_ws_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + self.ws_root_span_name, context=self.ws_root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init_with_telemetry( + root_span=root_span, + websocket=websocket, + ) + await self._send_subscribe_with_telemetry( + root_span=root_span, + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message_with_telemetry( + root_span=root_span, message=message, websocket=websocket + ) + if data: + yield data + + async def _send_connection_init_with_telemetry( + self, root_span: Span, websocket: WebSocketClientProtocol + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "connection init", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute( + "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value + ) + if self.ws_connection_init_payload: + span.set_attribute( + "payload", json.dumps(self.ws_connection_init_payload) + ) + + await self._send_connection_init(websocket=websocket) + + async def _send_subscribe_with_telemetry( + self, + root_span: Span, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "subscribe", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute("id", operation_id) + span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) + span.set_attribute("query", query) + if variables: + span.set_attribute( + "variables", + json.dumps(self._convert_dict_to_json_serializable(variables)), + ) + + await self._send_subscribe( + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async def _handle_ws_message_with_telemetry( + self, root_span: Span, message: Data, websocket: WebSocketClientProtocol + ) -> Optional[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + "received message", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + span.set_attribute("type", type_) + + if not type_ or type_ not in { + t.value for t in GraphQLTransportWSMessageType + }: + raise GraphQLClientInvalidMessageFormat(message=message) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/shorter_results/expected_client/async_base_client.py b/tests/main/clients/shorter_results/expected_client/async_base_client.py index 6b52985c..49efb1a2 100644 --- a/tests/main/clients/shorter_results/expected_client/async_base_client.py +++ b/tests/main/clients/shorter_results/expected_client/async_base_client.py @@ -40,7 +40,9 @@ async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument WebSocketClientProtocol = Any # type: ignore Data = Any # type: ignore Origin = Any # type: ignore - Subprotocol = Any # type: ignore + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") try: @@ -91,7 +93,9 @@ def __init__( ws_connection_init_payload: Optional[Dict[str, Any]] = None, tracer: Optional[Union[str, Tracer]] = None, root_context: Optional[Context] = None, - root_span_name: Optional[str] = None, + root_span_name: str = "GraphQL Operation", + ws_root_context: Optional[Context] = None, + ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,7 +112,9 @@ def __init__( get_tracer(tracer) if isinstance(tracer, str) else tracer ) self.root_context = root_context - self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + self.root_span_name = root_span_name + self.ws_root_context = ws_root_context + self.ws_root_span_name = ws_root_span_name async def __aenter__(self: Self) -> Self: return self @@ -156,25 +162,15 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, + if self.tracer: + generator = self._execute_ws_with_telemetry( + query=query, variables=variables ) + else: + generator = self._execute_ws(query=query, variables=variables) - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data + async for message in generator: + yield message async def _execute( self, query: str, variables: Optional[Dict[str, Any]] = None @@ -289,6 +285,29 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) + async def _execute_ws( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -409,3 +428,122 @@ async def _execute_json_with_telemetry( span.set_attribute("query", query) span.set_attribute("variables", serialized_variables) return await self._execute_json(query=query, variables=variables) + + async def _execute_ws_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + self.ws_root_span_name, context=self.ws_root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init_with_telemetry( + root_span=root_span, + websocket=websocket, + ) + await self._send_subscribe_with_telemetry( + root_span=root_span, + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message_with_telemetry( + root_span=root_span, message=message, websocket=websocket + ) + if data: + yield data + + async def _send_connection_init_with_telemetry( + self, root_span: Span, websocket: WebSocketClientProtocol + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "connection init", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute( + "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value + ) + if self.ws_connection_init_payload: + span.set_attribute( + "payload", json.dumps(self.ws_connection_init_payload) + ) + + await self._send_connection_init(websocket=websocket) + + async def _send_subscribe_with_telemetry( + self, + root_span: Span, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "subscribe", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute("id", operation_id) + span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) + span.set_attribute("query", query) + if variables: + span.set_attribute( + "variables", + json.dumps(self._convert_dict_to_json_serializable(variables)), + ) + + await self._send_subscribe( + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async def _handle_ws_message_with_telemetry( + self, root_span: Span, message: Data, websocket: WebSocketClientProtocol + ) -> Optional[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + "received message", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + span.set_attribute("type", type_) + + if not type_ or type_ not in { + t.value for t in GraphQLTransportWSMessageType + }: + raise GraphQLClientInvalidMessageFormat(message=message) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None From 6ab391df656f8ce3f78470194067629c09186b0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Fri, 6 Oct 2023 12:42:19 +0200 Subject: [PATCH 5/9] Add information about opentelemetry to readme --- README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.md b/README.md index f530182b..ba352a70 100644 --- a/README.md +++ b/README.md @@ -142,6 +142,20 @@ type = "Upload" ``` +### Opentelemetry + +Both default base clients support opt-in telemetry options. By default, it's disabled, but when the `opentelemetry-api` package is installed and the `tracer` argument is provided then the client will create spans with data about performed requests. + +Telemetry arguments handled by `BaseClient`: +- `tracer`: `Optional[Union[str, Tracer]] = None` - tracer object or name which will be passed to the `get_tracer` method +- `root_context`: `Optional[Context] = None` - optional context added to root span +- `root_span_name`: `str = "GraphQL Operation"` - name of root span + +`AsyncBaseClient` supports all arguments which `BaseClient` does, but also exposes additional arguments regarding websockets: +- `ws_root_context`: `Optional[Context] = None` - optional context added to root span for websocket connection +- `ws_root_span_name`: `str = "GraphQL Subscription"` - name of root span for websocket connection + + ## Custom scalars By default, not built-in scalars are represented as `typing.Any` in generated client. From 2a4b2eb6336550c24487d7e19c69e2fba04e6504 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Fri, 6 Oct 2023 12:42:45 +0200 Subject: [PATCH 6/9] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c09c6bf..494d5a94 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ - Fixed parsing of unions and interfaces to always add `__typename` to generated result models. - Added escaping of enum values which are Python keywords by appending `_` to them. - Fixed `enums_module_name` option not being passed to generators. +- Added opentelemetry support to default base clients. ## 0.9.0 (2023-09-11) From 60ae660802da902dbd623409c7d5b259b8637602 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Mon, 9 Oct 2023 18:20:52 +0200 Subject: [PATCH 7/9] Move opentelemetry support to new separate base clients. Add telemetry_client config option --- CHANGELOG.md | 2 +- README.md | 7 +- .../client_generators/constants.py | 15 + .../dependencies/async_base_client.py | 301 +-------- .../async_base_client_with_telemetry.py | 549 ++++++++++++++++ .../dependencies/base_client.py | 122 +--- .../base_client_with_telemetry.py | 270 ++++++++ ariadne_codegen/client_generators/package.py | 4 + ariadne_codegen/settings.py | 34 +- .../dependencies/test_async_base_client.py | 106 ---- .../test_async_base_client_with_telemetry.py | 594 ++++++++++++++++++ .../dependencies/test_base_client.py | 98 --- .../test_base_client_with_telemetry.py | 552 ++++++++++++++++ .../dependencies/test_websockets.py | 129 ---- .../test_websockets_with_telemetry.py | 348 ++++++++++ .../expected_client/async_base_client.py | 301 +-------- .../expected_client/async_base_client.py | 301 +-------- .../expected_client/async_base_client.py | 301 +-------- .../expected_client/async_base_client.py | 301 +-------- .../expected_client/async_base_client.py | 301 +-------- .../expected_client/async_base_client.py | 301 +-------- .../expected_client/async_base_client.py | 301 +-------- .../expected_client/async_base_client.py | 301 +-------- .../expected_client/async_base_client.py | 301 +-------- .../expected_client/async_base_client.py | 301 +-------- tests/test_settings.py | 32 +- 26 files changed, 2693 insertions(+), 3480 deletions(-) create mode 100644 ariadne_codegen/client_generators/dependencies/async_base_client_with_telemetry.py create mode 100644 ariadne_codegen/client_generators/dependencies/base_client_with_telemetry.py create mode 100644 tests/client_generators/dependencies/test_async_base_client_with_telemetry.py create mode 100644 tests/client_generators/dependencies/test_base_client_with_telemetry.py create mode 100644 tests/client_generators/dependencies/test_websockets_with_telemetry.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 494d5a94..a7253bc6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ - Fixed parsing of unions and interfaces to always add `__typename` to generated result models. - Added escaping of enum values which are Python keywords by appending `_` to them. - Fixed `enums_module_name` option not being passed to generators. -- Added opentelemetry support to default base clients. +- Added additional base clients with opentelemetry support. Added `telemetry_client` config option. ## 0.9.0 (2023-09-11) diff --git a/README.md b/README.md index ba352a70..b9ff8586 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,7 @@ Optional settings: - `include_comments` (defaults to `"stable"`) - option which sets content of comments included at the top of every generated file. Valid choices are: `"none"` (no comments), `"timestamp"` (comment with generation timestamp), `"stable"` (comment contains a message that this is a generated file) - `convert_to_snake_case` (defaults to `true`) - a flag that specifies whether to convert fields and arguments names to snake case - `async_client` (defaults to `true`) - default generated client is `async`, change this to option `false` to generate synchronous client instead +- `telemetry_client` (defaults to `false`) - default base clients doesn't support opentelemetry, change this option to `true` to use base client with opentelemetry - `files_to_include` (defaults to `[]`) - list of files which will be copied into generated package - `plugins` (defaults to `[]`) - list of plugins to use during generation @@ -144,14 +145,14 @@ type = "Upload" ### Opentelemetry -Both default base clients support opt-in telemetry options. By default, it's disabled, but when the `opentelemetry-api` package is installed and the `tracer` argument is provided then the client will create spans with data about performed requests. +When config option `telemetry_client` is set to `true` then default, included base clients support opt-in telemetry options. By default, it's disabled, but when the `opentelemetry-api` package is installed and the `tracer` argument is provided then the client will create spans with data about performed requests. -Telemetry arguments handled by `BaseClient`: +Telemetry arguments handled by `BaseClientWithTelemetry`: - `tracer`: `Optional[Union[str, Tracer]] = None` - tracer object or name which will be passed to the `get_tracer` method - `root_context`: `Optional[Context] = None` - optional context added to root span - `root_span_name`: `str = "GraphQL Operation"` - name of root span -`AsyncBaseClient` supports all arguments which `BaseClient` does, but also exposes additional arguments regarding websockets: +`AsyncBaseClientWithTelemetry` supports all arguments which `BaseClientWithTelemetry` does, but also exposes additional arguments regarding websockets: - `ws_root_context`: `Optional[Context] = None` - optional context added to root span for websocket connection - `ws_root_span_name`: `str = "GraphQL Subscription"` - name of root span for websocket connection diff --git a/ariadne_codegen/client_generators/constants.py b/ariadne_codegen/client_generators/constants.py index fd66599e..324588d5 100644 --- a/ariadne_codegen/client_generators/constants.py +++ b/ariadne_codegen/client_generators/constants.py @@ -68,10 +68,25 @@ SKIP_DIRECTIVE_NAME = "skip" INCLUDE_DIRECTIVE_NAME = "include" + DEFAULT_ASYNC_BASE_CLIENT_PATH = ( Path(__file__).parent / "dependencies" / "async_base_client.py" ) +DEFAULT_ASYNC_BASE_CLIENT_NAME = "AsyncBaseClient" + +DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_PATH = ( + Path(__file__).parent / "dependencies" / "async_base_client_with_telemetry.py" +) +DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_NAME = "AsyncBaseClientWithTelemetry" + DEFAULT_BASE_CLIENT_PATH = Path(__file__).parent / "dependencies" / "base_client.py" +DEFAULT_BASE_CLIENT_NAME = "BaseClient" + +DEFAULT_BASE_CLIENT_WITH_TELEMETRY_PATH = ( + Path(__file__).parent / "dependencies" / "base_client_with_telemetry.py" +) +DEFAULT_BASE_CLIENT_WITH_TELEMETRY_NAME = "BaseClientWithTelemetry" + GRAPHQL_CLIENT_EXCEPTIONS_NAMES = [ "GraphQLClientError", diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client.py b/ariadne_codegen/client_generators/dependencies/async_base_client.py index 49efb1a2..3d33b247 100644 --- a/ariadne_codegen/client_generators/dependencies/async_base_client.py +++ b/ariadne_codegen/client_generators/dependencies/async_base_client.py @@ -1,17 +1,6 @@ import enum import json -from typing import ( - IO, - Any, - AsyncIterator, - Dict, - List, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast from uuid import uuid4 import httpx @@ -45,26 +34,6 @@ def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name raise NotImplementedError("Subscriptions require 'websockets' package.") -try: - from opentelemetry.trace import ( # type: ignore[attr-defined] - Context, - Span, - Tracer, - get_tracer, - set_span_in_context, - ) -except ImportError: - Context = Any # type: ignore - Span = Any # type: ignore - Tracer = Any # type: ignore - - def get_tracer(*args, **kwargs) -> Tracer: # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - def set_span_in_context(*args, **kwargs): # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -91,11 +60,6 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, - tracer: Optional[Union[str, Tracer]] = None, - root_context: Optional[Context] = None, - root_span_name: str = "GraphQL Operation", - ws_root_context: Optional[Context] = None, - ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,14 +72,6 @@ def __init__( self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload - self.tracer: Optional[Tracer] = ( - get_tracer(tracer) if isinstance(tracer, str) else tracer - ) - self.root_context = root_context - self.root_span_name = root_span_name - self.ws_root_context = ws_root_context - self.ws_root_span_name = ws_root_span_name - async def __aenter__(self: Self) -> Self: return self @@ -130,10 +86,17 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - if self.tracer: - return await self._execute_with_telemetry(query=query, variables=variables) + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) - return await self._execute(query=query, variables=variables) + return await self._execute_json(query=query, variables=processed_variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -162,30 +125,25 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - if self.tracer: - generator = self._execute_ws_with_telemetry( - query=query, variables=variables - ) - else: - generator = self._execute_ws(query=query, variables=variables) - - async for message in generator: - yield message - - async def _execute( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart( + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, query=query, - variables=processed_variables, - files=files, - files_map=files_map, + variables=variables, ) - return await self._execute_json(query=query, variables=processed_variables) + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data def _process_variables( self, variables: Optional[Dict[str, Any]] @@ -285,29 +243,6 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) - async def _execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data - async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -365,185 +300,3 @@ async def _handle_ws_message( ) return None - - async def _execute_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - self.root_span_name, context=self.root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart_with_telemetry( - root_span=root_span, - query=query, - variables=processed_variables, - files=files, - files_map=files_map, - ) - - return await self._execute_json_with_telemetry( - root_span=root_span, query=query, variables=processed_variables - ) - - async def _execute_multipart_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - files: Dict[str, Tuple[str, IO[bytes], str]], - files_map: Dict[str, List[str]], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "multipart request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - serialized_map = json.dumps(files_map, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - span.set_attribute("map", serialized_map) - return await self._execute_multipart( - query=query, variables=variables, files=files, files_map=files_map - ) - - async def _execute_json_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "json request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - return await self._execute_json(query=query, variables=variables) - - async def _execute_ws_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - self.ws_root_span_name, context=self.ws_root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init_with_telemetry( - root_span=root_span, - websocket=websocket, - ) - await self._send_subscribe_with_telemetry( - root_span=root_span, - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message_with_telemetry( - root_span=root_span, message=message, websocket=websocket - ) - if data: - yield data - - async def _send_connection_init_with_telemetry( - self, root_span: Span, websocket: WebSocketClientProtocol - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "connection init", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute( - "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value - ) - if self.ws_connection_init_payload: - span.set_attribute( - "payload", json.dumps(self.ws_connection_init_payload) - ) - - await self._send_connection_init(websocket=websocket) - - async def _send_subscribe_with_telemetry( - self, - root_span: Span, - websocket: WebSocketClientProtocol, - operation_id: str, - query: str, - variables: Optional[Dict[str, Any]] = None, - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "subscribe", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute("id", operation_id) - span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) - span.set_attribute("query", query) - if variables: - span.set_attribute( - "variables", - json.dumps(self._convert_dict_to_json_serializable(variables)), - ) - - await self._send_subscribe( - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async def _handle_ws_message_with_telemetry( - self, root_span: Span, message: Data, websocket: WebSocketClientProtocol - ) -> Optional[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - "received message", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - try: - message_dict = json.loads(message) - except json.JSONDecodeError as exc: - raise GraphQLClientInvalidMessageFormat(message=message) from exc - - type_ = message_dict.get("type") - payload = message_dict.get("payload", {}) - - span.set_attribute("type", type_) - - if not type_ or type_ not in { - t.value for t in GraphQLTransportWSMessageType - }: - raise GraphQLClientInvalidMessageFormat(message=message) - - if type_ == GraphQLTransportWSMessageType.NEXT: - if "data" not in payload: - raise GraphQLClientInvalidMessageFormat(message=message) - return cast(Dict[str, Any], payload["data"]) - - if type_ == GraphQLTransportWSMessageType.COMPLETE: - await websocket.close() - elif type_ == GraphQLTransportWSMessageType.PING: - await websocket.send( - json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) - ) - elif type_ == GraphQLTransportWSMessageType.ERROR: - raise GraphQLClientGraphQLMultiError.from_errors_dicts( - errors_dicts=payload, data=message_dict - ) - - return None diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client_with_telemetry.py b/ariadne_codegen/client_generators/dependencies/async_base_client_with_telemetry.py new file mode 100644 index 00000000..ce41d597 --- /dev/null +++ b/ariadne_codegen/client_generators/dependencies/async_base_client_with_telemetry.py @@ -0,0 +1,549 @@ +import enum +import json +from typing import ( + IO, + Any, + AsyncIterator, + Dict, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) +from uuid import uuid4 + +import httpx +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +from .base_model import UNSET, Upload +from .exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidMessageFormat, + GraphQlClientInvalidResponseError, +) + +try: + from websockets.client import WebSocketClientProtocol, connect as ws_connect + from websockets.typing import Data, Origin, Subprotocol +except ImportError: + from contextlib import asynccontextmanager + + @asynccontextmanager # type: ignore + async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument + raise NotImplementedError("Subscriptions require 'websockets' package.") + yield # pylint: disable=unreachable + + WebSocketClientProtocol = Any # type: ignore + Data = Any # type: ignore + Origin = Any # type: ignore + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") + + +try: + from opentelemetry.trace import ( # type: ignore[attr-defined] + Context, + Span, + Tracer, + get_tracer, + set_span_in_context, + ) +except ImportError: + Context = Any # type: ignore + Span = Any # type: ignore + Tracer = Any # type: ignore + + def get_tracer(*args, **kwargs) -> Tracer: # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + def set_span_in_context(*args, **kwargs): # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + +Self = TypeVar("Self", bound="AsyncBaseClientWithTelemetry") + +GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" + + +class GraphQLTransportWSMessageType(str, enum.Enum): + CONNECTION_INIT = "connection_init" + CONNECTION_ACK = "connection_ack" + PING = "ping" + PONG = "pong" + SUBSCRIBE = "subscribe" + NEXT = "next" + ERROR = "error" + COMPLETE = "complete" + + +class AsyncBaseClientWithTelemetry: + def __init__( + self, + url: str = "", + headers: Optional[Dict[str, str]] = None, + http_client: Optional[httpx.AsyncClient] = None, + ws_url: str = "", + ws_headers: Optional[Dict[str, Any]] = None, + ws_origin: Optional[str] = None, + ws_connection_init_payload: Optional[Dict[str, Any]] = None, + tracer: Optional[Union[str, Tracer]] = None, + root_context: Optional[Context] = None, + root_span_name: str = "GraphQL Operation", + ws_root_context: Optional[Context] = None, + ws_root_span_name: str = "GraphQL Subscription", + ) -> None: + self.url = url + self.headers = headers + self.http_client = ( + http_client if http_client else httpx.AsyncClient(headers=headers) + ) + + self.ws_url = ws_url + self.ws_headers = ws_headers or {} + self.ws_origin = Origin(ws_origin) if ws_origin else None + self.ws_connection_init_payload = ws_connection_init_payload + + self.tracer: Optional[Tracer] = ( + get_tracer(tracer) if isinstance(tracer, str) else tracer + ) + self.root_context = root_context + self.root_span_name = root_span_name + self.ws_root_context = ws_root_context + self.ws_root_span_name = ws_root_span_name + + async def __aenter__(self: Self) -> Self: + return self + + async def __aexit__( + self, + exc_type: object, + exc_val: object, + exc_tb: object, + ) -> None: + await self.http_client.aclose() + + async def execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + if self.tracer: + return await self._execute_with_telemetry(query=query, variables=variables) + + return await self._execute(query=query, variables=variables) + + def get_data(self, response: httpx.Response) -> Dict[str, Any]: + if not response.is_success: + raise GraphQLClientHttpError( + status_code=response.status_code, response=response + ) + + try: + response_json = response.json() + except ValueError as exc: + raise GraphQlClientInvalidResponseError(response=response) from exc + + if (not isinstance(response_json, dict)) or ("data" not in response_json): + raise GraphQlClientInvalidResponseError(response=response) + + data = response_json["data"] + errors = response_json.get("errors") + + if errors: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=errors, data=data + ) + + return cast(Dict[str, Any], data) + + async def execute_ws( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + if self.tracer: + generator = self._execute_ws_with_telemetry( + query=query, variables=variables + ) + else: + generator = self._execute_ws(query=query, variables=variables) + + async for message in generator: + yield message + + async def _execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json(query=query, variables=processed_variables) + + def _process_variables( + self, variables: Optional[Dict[str, Any]] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + if not variables: + return {}, {}, {} + + serializable_variables = self._convert_dict_to_json_serializable(variables) + return self._get_files_from_variables(serializable_variables) + + def _convert_dict_to_json_serializable( + self, dict_: Dict[str, Any] + ) -> Dict[str, Any]: + return { + key: self._convert_value(value) + for key, value in dict_.items() + if value is not UNSET + } + + def _convert_value(self, value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(by_alias=True, exclude_unset=True) + if isinstance(value, list): + return [self._convert_value(item) for item in value] + return value + + def _get_files_from_variables( + self, variables: Dict[str, Any] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + files_map: Dict[str, List[str]] = {} + files_list: List[Upload] = [] + + def separate_files(path: str, obj: Any) -> Any: + if isinstance(obj, list): + nulled_list = [] + for index, value in enumerate(obj): + value = separate_files(f"{path}.{index}", value) + nulled_list.append(value) + return nulled_list + + if isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = separate_files(f"{path}.{key}", value) + nulled_dict[key] = value + return nulled_dict + + if isinstance(obj, Upload): + if obj in files_list: + file_index = files_list.index(obj) + files_map[str(file_index)].append(path) + else: + file_index = len(files_list) + files_list.append(obj) + files_map[str(file_index)] = [path] + return None + + return obj + + nulled_variables = separate_files("variables", variables) + files: Dict[str, Tuple[str, IO[bytes], str]] = { + str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) + for i, file_ in enumerate(files_list) + } + return nulled_variables, files, files_map + + async def _execute_multipart( + self, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + data = { + "operations": json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + "map": json.dumps(files_map, default=to_jsonable_python), + } + + return await self.http_client.post(url=self.url, data=data, files=files) + + async def _execute_json( + self, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: + return await self.http_client.post( + url=self.url, + content=json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + headers={"Content-Type": "application/json"}, + ) + + async def _execute_ws( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: + payload: Dict[str, Any] = { + "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value + } + if self.ws_connection_init_payload: + payload["payload"] = self.ws_connection_init_payload + await websocket.send(json.dumps(payload)) + + async def _send_subscribe( + self, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + payload: Dict[str, Any] = { + "id": operation_id, + "type": GraphQLTransportWSMessageType.SUBSCRIBE.value, + "payload": {"query": query}, + } + if variables: + payload["payload"]["variables"] = self._convert_dict_to_json_serializable( + variables + ) + await websocket.send(json.dumps(payload)) + + async def _handle_ws_message( + self, message: Data, websocket: WebSocketClientProtocol + ) -> Optional[Dict[str, Any]]: + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: + raise GraphQLClientInvalidMessageFormat(message=message) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None + + async def _execute_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + self.root_span_name, context=self.root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart_with_telemetry( + root_span=root_span, + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return await self._execute_json_with_telemetry( + root_span=root_span, query=query, variables=processed_variables + ) + + async def _execute_multipart_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "multipart request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + serialized_map = json.dumps(files_map, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + span.set_attribute("map", serialized_map) + return await self._execute_multipart( + query=query, variables=variables, files=files, files_map=files_map + ) + + async def _execute_json_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "json request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + return await self._execute_json(query=query, variables=variables) + + async def _execute_ws_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> AsyncIterator[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + self.ws_root_span_name, context=self.ws_root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init_with_telemetry( + root_span=root_span, + websocket=websocket, + ) + await self._send_subscribe_with_telemetry( + root_span=root_span, + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message_with_telemetry( + root_span=root_span, message=message, websocket=websocket + ) + if data: + yield data + + async def _send_connection_init_with_telemetry( + self, root_span: Span, websocket: WebSocketClientProtocol + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "connection init", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute( + "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value + ) + if self.ws_connection_init_payload: + span.set_attribute( + "payload", json.dumps(self.ws_connection_init_payload) + ) + + await self._send_connection_init(websocket=websocket) + + async def _send_subscribe_with_telemetry( + self, + root_span: Span, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + with self.tracer.start_as_current_span( # type: ignore + "subscribe", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + span.set_attribute("id", operation_id) + span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) + span.set_attribute("query", query) + if variables: + span.set_attribute( + "variables", + json.dumps(self._convert_dict_to_json_serializable(variables)), + ) + + await self._send_subscribe( + websocket=websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async def _handle_ws_message_with_telemetry( + self, root_span: Span, message: Data, websocket: WebSocketClientProtocol + ) -> Optional[Dict[str, Any]]: + with self.tracer.start_as_current_span( # type: ignore + "received message", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + span.set_attribute("type", type_) + + if not type_ or type_ not in { + t.value for t in GraphQLTransportWSMessageType + }: + raise GraphQLClientInvalidMessageFormat(message=message) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/ariadne_codegen/client_generators/dependencies/base_client.py b/ariadne_codegen/client_generators/dependencies/base_client.py index d6d271b8..fa981789 100644 --- a/ariadne_codegen/client_generators/dependencies/base_client.py +++ b/ariadne_codegen/client_generators/dependencies/base_client.py @@ -1,5 +1,5 @@ import json -from typing import IO, Any, Dict, List, Optional, Tuple, TypeVar, Union, cast +from typing import IO, Any, Dict, List, Optional, Tuple, TypeVar, cast import httpx from pydantic import BaseModel @@ -12,26 +12,6 @@ GraphQlClientInvalidResponseError, ) -try: - from opentelemetry.trace import ( # type: ignore[attr-defined] - Context, - Span, - Tracer, - get_tracer, - set_span_in_context, - ) -except ImportError: - Context = Any # type: ignore - Span = Any # type: ignore - Tracer = Any # type: ignore - - def get_tracer(*args, **kwargs) -> Tracer: # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - def set_span_in_context(*args, **kwargs): # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - Self = TypeVar("Self", bound="BaseClient") @@ -41,21 +21,12 @@ def __init__( url: str = "", headers: Optional[Dict[str, str]] = None, http_client: Optional[httpx.Client] = None, - tracer: Optional[Union[str, Tracer]] = None, - root_context: Optional[Context] = None, - root_span_name: Optional[str] = None, ) -> None: self.url = url self.headers = headers self.http_client = http_client if http_client else httpx.Client(headers=headers) - self.tracer: Optional[Tracer] = ( - get_tracer(tracer) if isinstance(tracer, str) else tracer - ) - self.root_context = root_context - self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" - def __enter__(self: Self) -> Self: return self @@ -70,9 +41,17 @@ def __exit__( def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - if self.tracer: - return self._execute_with_telemetry(query=query, variables=variables) - return self._execute(query=query, variables=variables) + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return self._execute_json(query=query, variables=processed_variables) def get_data(self, response: httpx.Response) -> dict[str, Any]: if not response.is_success: @@ -98,21 +77,6 @@ def get_data(self, response: httpx.Response) -> dict[str, Any]: return cast(dict[str, Any], data) - def _execute( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return self._execute_multipart( - query=query, - variables=processed_variables, - files=files, - files_map=files_map, - ) - - return self._execute_json(query=query, variables=processed_variables) - def _process_variables( self, variables: Optional[Dict[str, Any]] ) -> Tuple[ @@ -206,65 +170,3 @@ def _execute_json(self, query: str, variables: Dict[str, Any]) -> httpx.Response ), headers={"Content-Type": "application/json"}, ) - - def _execute_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - self.root_span_name, context=self.root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return self._execute_multipart_with_telemetry( - root_span=root_span, - query=query, - variables=processed_variables, - files=files, - files_map=files_map, - ) - - return self._execute_json_with_telemetry( - root_span=root_span, query=query, variables=processed_variables - ) - - def _execute_multipart_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - files: Dict[str, Tuple[str, IO[bytes], str]], - files_map: Dict[str, List[str]], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "multipart request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - serialized_map = json.dumps(files_map, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - span.set_attribute("map", serialized_map) - - return self._execute_multipart( - query=query, variables=variables, files=files, files_map=files_map - ) - - def _execute_json_with_telemetry( - self, root_span: Span, query: str, variables: Dict[str, Any] - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "json request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - - return self._execute_json(query=query, variables=variables) diff --git a/ariadne_codegen/client_generators/dependencies/base_client_with_telemetry.py b/ariadne_codegen/client_generators/dependencies/base_client_with_telemetry.py new file mode 100644 index 00000000..996ec374 --- /dev/null +++ b/ariadne_codegen/client_generators/dependencies/base_client_with_telemetry.py @@ -0,0 +1,270 @@ +import json +from typing import IO, Any, Dict, List, Optional, Tuple, TypeVar, Union, cast + +import httpx +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +from .base_model import UNSET, Upload +from .exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQlClientInvalidResponseError, +) + +try: + from opentelemetry.trace import ( # type: ignore[attr-defined] + Context, + Span, + Tracer, + get_tracer, + set_span_in_context, + ) +except ImportError: + Context = Any # type: ignore + Span = Any # type: ignore + Tracer = Any # type: ignore + + def get_tracer(*args, **kwargs) -> Tracer: # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + def set_span_in_context(*args, **kwargs): # type: ignore + raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") + + +Self = TypeVar("Self", bound="BaseClientWithTelemetry") + + +class BaseClientWithTelemetry: + def __init__( + self, + url: str = "", + headers: Optional[Dict[str, str]] = None, + http_client: Optional[httpx.Client] = None, + tracer: Optional[Union[str, Tracer]] = None, + root_context: Optional[Context] = None, + root_span_name: Optional[str] = None, + ) -> None: + self.url = url + self.headers = headers + + self.http_client = http_client if http_client else httpx.Client(headers=headers) + + self.tracer: Optional[Tracer] = ( + get_tracer(tracer) if isinstance(tracer, str) else tracer + ) + self.root_context = root_context + self.root_span_name = root_span_name if root_span_name else "GraphQL Operation" + + def __enter__(self: Self) -> Self: + return self + + def __exit__( + self, + exc_type: object, + exc_val: object, + exc_tb: object, + ) -> None: + self.http_client.close() + + def execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + if self.tracer: + return self._execute_with_telemetry(query=query, variables=variables) + return self._execute(query=query, variables=variables) + + def get_data(self, response: httpx.Response) -> dict[str, Any]: + if not response.is_success: + raise GraphQLClientHttpError( + status_code=response.status_code, response=response + ) + + try: + response_json = response.json() + except ValueError as exc: + raise GraphQlClientInvalidResponseError(response=response) from exc + + if (not isinstance(response_json, dict)) or ("data" not in response_json): + raise GraphQlClientInvalidResponseError(response=response) + + data = response_json["data"] + errors = response_json.get("errors") + + if errors: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=errors, data=data + ) + + return cast(dict[str, Any], data) + + def _execute( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return self._execute_json(query=query, variables=processed_variables) + + def _process_variables( + self, variables: Optional[Dict[str, Any]] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + if not variables: + return {}, {}, {} + + serializable_variables = self._convert_dict_to_json_serializable(variables) + return self._get_files_from_variables(serializable_variables) + + def _convert_dict_to_json_serializable( + self, dict_: Dict[str, Any] + ) -> Dict[str, Any]: + return { + key: self._convert_value(value) + for key, value in dict_.items() + if value is not UNSET + } + + def _convert_value(self, value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(by_alias=True, exclude_unset=True) + if isinstance(value, list): + return [self._convert_value(item) for item in value] + return value + + def _get_files_from_variables( + self, variables: Dict[str, Any] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + files_map: Dict[str, List[str]] = {} + files_list: List[Upload] = [] + + def separate_files(path: str, obj: Any) -> Any: + if isinstance(obj, list): + nulled_list = [] + for index, value in enumerate(obj): + value = separate_files(f"{path}.{index}", value) + nulled_list.append(value) + return nulled_list + + if isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = separate_files(f"{path}.{key}", value) + nulled_dict[key] = value + return nulled_dict + + if isinstance(obj, Upload): + if obj in files_list: + file_index = files_list.index(obj) + files_map[str(file_index)].append(path) + else: + file_index = len(files_list) + files_list.append(obj) + files_map[str(file_index)] = [path] + return None + + return obj + + nulled_variables = separate_files("variables", variables) + files: Dict[str, Tuple[str, IO[bytes], str]] = { + str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) + for i, file_ in enumerate(files_list) + } + return nulled_variables, files, files_map + + def _execute_multipart( + self, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + data = { + "operations": json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + "map": json.dumps(files_map, default=to_jsonable_python), + } + + return self.http_client.post(url=self.url, data=data, files=files) + + def _execute_json(self, query: str, variables: Dict[str, Any]) -> httpx.Response: + return self.http_client.post( + url=self.url, + content=json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + headers={"Content-Type": "application/json"}, + ) + + def _execute_with_telemetry( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + self.root_span_name, context=self.root_context + ) as root_span: + root_span.set_attribute("component", "GraphQL Client") + + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return self._execute_multipart_with_telemetry( + root_span=root_span, + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) + + return self._execute_json_with_telemetry( + root_span=root_span, query=query, variables=processed_variables + ) + + def _execute_multipart_with_telemetry( + self, + root_span: Span, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "multipart request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + serialized_map = json.dumps(files_map, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + span.set_attribute("map", serialized_map) + + return self._execute_multipart( + query=query, variables=variables, files=files, files_map=files_map + ) + + def _execute_json_with_telemetry( + self, root_span: Span, query: str, variables: Dict[str, Any] + ) -> httpx.Response: + with self.tracer.start_as_current_span( # type: ignore + "json request", context=set_span_in_context(root_span) + ) as span: + span.set_attribute("component", "GraphQL Client") + + serialized_variables = json.dumps(variables, default=to_jsonable_python) + + span.set_attribute("query", query) + span.set_attribute("variables", serialized_variables) + + return self._execute_json(query=query, variables=variables) diff --git a/ariadne_codegen/client_generators/package.py b/ariadne_codegen/client_generators/package.py index 4740d3a8..466e5d58 100644 --- a/ariadne_codegen/client_generators/package.py +++ b/ariadne_codegen/client_generators/package.py @@ -17,7 +17,9 @@ BASE_MODEL_FILE_PATH, BASE_MODEL_IMPORT, DEFAULT_ASYNC_BASE_CLIENT_PATH, + DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_PATH, DEFAULT_BASE_CLIENT_PATH, + DEFAULT_BASE_CLIENT_WITH_TELEMETRY_PATH, EXCEPTIONS_FILE_PATH, GRAPHQL_CLIENT_EXCEPTIONS_NAMES, UNSET_IMPORT, @@ -172,6 +174,8 @@ def _include_exceptions(self): if self.base_client_file_path in ( DEFAULT_ASYNC_BASE_CLIENT_PATH, DEFAULT_BASE_CLIENT_PATH, + DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_PATH, + DEFAULT_BASE_CLIENT_WITH_TELEMETRY_PATH, ): self.files_to_include.append(EXCEPTIONS_FILE_PATH) self.init_generator.add_import( diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 024a6149..2c10e1ec 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -7,8 +7,14 @@ from typing import Dict, List from .client_generators.constants import ( + DEFAULT_ASYNC_BASE_CLIENT_NAME, DEFAULT_ASYNC_BASE_CLIENT_PATH, + DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_NAME, + DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_PATH, + DEFAULT_BASE_CLIENT_NAME, DEFAULT_BASE_CLIENT_PATH, + DEFAULT_BASE_CLIENT_WITH_TELEMETRY_NAME, + DEFAULT_BASE_CLIENT_WITH_TELEMETRY_PATH, ) from .client_generators.scalars import ScalarData from .exceptions import InvalidConfiguration @@ -60,6 +66,7 @@ class ClientSettings(BaseSettings): include_comments: CommentsStrategy = field(default=CommentsStrategy.STABLE) convert_to_snake_case: bool = True async_client: bool = True + telemetry_client: bool = False files_to_include: List[str] = field(default_factory=list) scalars: Dict[str, ScalarData] = field(default_factory=dict) @@ -103,13 +110,28 @@ def __post_init__(self): assert_path_is_valid_file(file_path) def _set_default_base_client_data(self): + default_clients_map = { + (True, True): ( + DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_PATH, + DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_NAME, + ), + (True, False): ( + DEFAULT_ASYNC_BASE_CLIENT_PATH, + DEFAULT_ASYNC_BASE_CLIENT_NAME, + ), + (False, True): ( + DEFAULT_BASE_CLIENT_WITH_TELEMETRY_PATH, + DEFAULT_BASE_CLIENT_WITH_TELEMETRY_NAME, + ), + (False, False): ( + DEFAULT_BASE_CLIENT_PATH, + DEFAULT_BASE_CLIENT_NAME, + ), + } if not self.base_client_name and not self.base_client_file_path: - if self.async_client: - self.base_client_file_path = DEFAULT_ASYNC_BASE_CLIENT_PATH.as_posix() - self.base_client_name = "AsyncBaseClient" - else: - self.base_client_file_path = DEFAULT_BASE_CLIENT_PATH.as_posix() - self.base_client_name = "BaseClient" + path, name = default_clients_map[(self.async_client, self.telemetry_client)] + self.base_client_name = name + self.base_client_file_path = path.as_posix() @property def schema_source(self) -> str: diff --git a/tests/client_generators/dependencies/test_async_base_client.py b/tests/client_generators/dependencies/test_async_base_client.py index 277219a0..b7281164 100644 --- a/tests/client_generators/dependencies/test_async_base_client.py +++ b/tests/client_generators/dependencies/test_async_base_client.py @@ -1,7 +1,6 @@ import json from datetime import datetime from typing import Any, Optional -from unittest.mock import ANY import httpx import pytest @@ -474,108 +473,3 @@ async def test_base_client_used_as_context_manager_closes_http_client(mocker): await base_client.execute("") assert fake_client.aclose.called - - -@pytest.fixture -def mocker_get_tracer(mocker): - return mocker.patch( - "ariadne_codegen.client_generators.dependencies.async_base_client.get_tracer" - ) - - -@pytest.fixture -def mocked_start_as_current_span(mocker_get_tracer): - return mocker_get_tracer.return_value.start_as_current_span - - -@pytest.mark.asyncio -async def test_async_base_client_with_given_tracker_str_uses_global_tracker( - mocker_get_tracer, -): - AsyncBaseClient(url="http://base_url", tracer="tracker name") - - assert mocker_get_tracer.call_count == 1 - - -@pytest.mark.asyncio -async def test_execute_creates_root_span(httpx_mock, mocked_start_as_current_span): - httpx_mock.add_response() - client = AsyncBaseClient(url="http://base_url", tracer="tracker") - - await client.execute("query GetHello { hello }") - - mocked_start_as_current_span.assert_any_call("GraphQL Operation", context=ANY) - with mocked_start_as_current_span.return_value as span: - span.set_attribute.assert_any_call("component", "GraphQL Client") - - -@pytest.mark.asyncio -async def test_execute_creates_root_span_with_custom_name( - httpx_mock, mocked_start_as_current_span -): - httpx_mock.add_response() - client = AsyncBaseClient( - url="http://base_url", tracer="tracker", root_span_name="root_span" - ) - - await client.execute("query GetHello { hello }") - - mocked_start_as_current_span.assert_any_call("root_span", context=ANY) - - -@pytest.mark.asyncio -async def test_execute_creates_root_span_with_custom_context( - httpx_mock, mocked_start_as_current_span -): - httpx_mock.add_response() - client = AsyncBaseClient( - url="http://base_url", tracer="tracker", root_context={"abc": 123} - ) - - await client.execute("query GetHello { hello }") - - mocked_start_as_current_span.assert_any_call( - "GraphQL Operation", context={"abc": 123} - ) - - -@pytest.mark.asyncio -async def test_execute_creates_span_for_json_http_request( - httpx_mock, mocked_start_as_current_span -): - httpx_mock.add_response() - client = AsyncBaseClient(url="http://base_url", tracer="tracker") - - await client.execute("query GetHello { hello }", variables={"a": 1, "b": {"bb": 2}}) - - mocked_start_as_current_span.assert_any_call("json request", context=ANY) - with mocked_start_as_current_span.return_value as span: - span.set_attribute.assert_any_call("component", "GraphQL Client") - span.set_attribute.assert_any_call("query", "query GetHello { hello }") - span.set_attribute.assert_any_call( - "variables", json.dumps({"a": 1, "b": {"bb": 2}}) - ) - - -@pytest.mark.asyncio -async def test_execute_creates_span_for_multipart_request( - httpx_mock, txt_file, mocked_start_as_current_span -): - httpx_mock.add_response() - client = AsyncBaseClient(url="http://base_url", tracer="tracker") - - await client.execute( - "query Abc($file: Upload!) { abc(file: $file) }", - {"file": txt_file, "a": 1.0, "b": {"bb": 2}}, - ) - - mocked_start_as_current_span.assert_any_call("multipart request", context=ANY) - with mocked_start_as_current_span.return_value as span: - span.set_attribute.assert_any_call("component", "GraphQL Client") - span.set_attribute.assert_any_call( - "query", "query Abc($file: Upload!) { abc(file: $file) }" - ) - span.set_attribute.assert_any_call( - "variables", json.dumps({"file": None, "a": 1.0, "b": {"bb": 2}}) - ) - span.set_attribute.assert_any_call("map", json.dumps({"0": ["variables.file"]})) diff --git a/tests/client_generators/dependencies/test_async_base_client_with_telemetry.py b/tests/client_generators/dependencies/test_async_base_client_with_telemetry.py new file mode 100644 index 00000000..1f451169 --- /dev/null +++ b/tests/client_generators/dependencies/test_async_base_client_with_telemetry.py @@ -0,0 +1,594 @@ +import json +from datetime import datetime +from typing import Any, Optional +from unittest.mock import ANY + +import httpx +import pytest + +from ariadne_codegen.client_generators.dependencies.async_base_client_with_telemetry import ( # pylint: disable=line-too-long + AsyncBaseClientWithTelemetry, +) +from ariadne_codegen.client_generators.dependencies.base_model import UNSET, BaseModel +from ariadne_codegen.client_generators.dependencies.exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQlClientInvalidResponseError, +) + +from ...utils import decode_multipart_request + + +@pytest.mark.asyncio +async def test_execute_sends_post_to_correct_url_with_correct_payload(httpx_mock): + httpx_mock.add_response() + + client = AsyncBaseClientWithTelemetry(url="http://base_url/endpoint") + query_str = """ + query Abc($v: String!) { + abc(v: $v) { + field1 + } + } + """ + + await client.execute(query_str, {"v": "Xyz"}) + + request = httpx_mock.get_request() + assert request.url == "http://base_url/endpoint" + content = json.loads(request.content) + assert content == {"query": query_str, "variables": {"v": "Xyz"}} + + +@pytest.mark.asyncio +async def test_execute_parses_pydantic_variables_before_sending(httpx_mock): + class TestModel1(BaseModel): + a: int + + class TestModel2(BaseModel): + nested: TestModel1 + + httpx_mock.add_response() + client = AsyncBaseClientWithTelemetry(url="http://base_url") + query_str = """ + query Abc($v1: TestModel1!, $v2: TestModel2) { + abc(v1: $v1, v2: $v2){ + field1 + } + } + """ + + await client.execute( + query_str, {"v1": TestModel1(a=5), "v2": TestModel2(nested=TestModel1(a=10))} + ) + + request = httpx_mock.get_request() + content = json.loads(request.content) + assert content == { + "query": query_str, + "variables": {"v1": {"a": 5}, "v2": {"nested": {"a": 10}}}, + } + + +@pytest.mark.asyncio +async def test_execute_correctly_parses_top_level_list_variables(httpx_mock): + class TestModel1(BaseModel): + a: int + + httpx_mock.add_response() + client = AsyncBaseClientWithTelemetry(url="http://base_url") + query_str = """ + query Abc($v1: [[TestModel1!]!]!) { + abc(v1: $v1){ + field1 + } + } + """ + + await client.execute( + query_str, + { + "v1": [[TestModel1(a=1), TestModel1(a=2)]], + }, + ) + + request = httpx_mock.get_request() + content = json.loads(request.content) + assert content == { + "query": query_str, + "variables": {"v1": [[{"a": 1}, {"a": 2}]]}, + } + assert not any(isinstance(x, BaseModel) for x in content["variables"]["v1"][0]) + + +@pytest.mark.asyncio +async def test_execute_sends_payload_without_unset_arguments(httpx_mock): + httpx_mock.add_response() + client = AsyncBaseClientWithTelemetry(url="http://base_url") + query_str = """ + query Abc($arg1: TestInputA, $arg2: String, $arg3: Float, $arg4: Int!) { + abc(arg1: $arg1, arg2: $arg2, arg3: $arg3, arg4: $arg4){ + field1 + } + } + """ + + await client.execute( + query_str, {"arg1": UNSET, "arg2": UNSET, "arg3": None, "arg4": 2} + ) + + request = httpx_mock.get_request() + content = json.loads(request.content) + assert content == { + "query": query_str, + "variables": {"arg3": None, "arg4": 2}, + } + + +@pytest.mark.asyncio +async def test_execute_sends_payload_without_unset_input_fields(httpx_mock): + class TestInputB(BaseModel): + required_b: str + optional_b: Optional[str] = None + + class TestInputA(BaseModel): + required_a: str + optional_a: Optional[str] = None + input_b1: Optional[TestInputB] = None + input_b2: Optional[TestInputB] = None + input_b3: Optional[TestInputB] = None + + httpx_mock.add_response() + client = AsyncBaseClientWithTelemetry(url="http://base_url") + query_str = """ + query Abc($arg: TestInputB) { + abc(arg: $arg){ + field1 + } + } + """ + + await client.execute( + query_str, + { + "arg": TestInputA( + required_a="a", input_b1=TestInputB(required_b="b"), input_b3=None + ) + }, + ) + + request = httpx_mock.get_request() + content = json.loads(request.content) + assert content == { + "query": query_str, + "variables": { + "arg": { + "required_a": "a", + "input_b1": {"required_b": "b"}, + "input_b3": None, + } + }, + } + + +@pytest.mark.asyncio +async def test_execute_sends_payload_with_serialized_datetime_without_exception( + httpx_mock, +): + httpx_mock.add_response() + client = AsyncBaseClientWithTelemetry(url="http://base_url") + query_str = "query Abc($arg: DATETIME) { abc }" + arg_value = datetime(2023, 12, 31, 10, 15) + + await client.execute(query_str, {"arg": arg_value}) + + request = httpx_mock.get_request() + content = json.loads(request.content) + assert content["variables"]["arg"] == arg_value.isoformat() + + +@pytest.mark.asyncio +async def test_execute_sends_request_with_correct_content_type(httpx_mock): + httpx_mock.add_response() + client = AsyncBaseClientWithTelemetry(url="http://base_url") + + await client.execute("query Abc { abc }", {}) + + request = httpx_mock.get_request() + assert request.headers["Content-Type"] == "application/json" + + +@pytest.mark.asyncio +async def test_execute_sends_request_with_extra_headers_and_correct_content_type( + httpx_mock, +): + httpx_mock.add_response() + client = AsyncBaseClientWithTelemetry( + url="http://base_url", headers={"h_key": "h_value"} + ) + + await client.execute("query Abc { abc }", {}) + + request = httpx_mock.get_request() + assert request.headers["h_key"] == "h_value" + assert request.headers["Content-Type"] == "application/json" + + +@pytest.mark.asyncio +async def test_execute_sends_file_with_multipart_form_data_content_type( + httpx_mock, txt_file +): + httpx_mock.add_response() + + client = AsyncBaseClientWithTelemetry(url="http://base_url") + await client.execute( + "query Abc($file: Upload!) { abc(file: $file) }", {"file": txt_file} + ) + + request = httpx_mock.get_request() + assert "multipart/form-data" in request.headers["Content-Type"] + + +@pytest.mark.asyncio +async def test_execute_sends_file_as_multipart_request(httpx_mock, txt_file): + httpx_mock.add_response() + query_str = "query Abc($file: Upload!) { abc(file: $file) }" + + client = AsyncBaseClientWithTelemetry(url="http://base_url") + await client.execute(query_str, {"file": txt_file}) + + request = httpx_mock.get_request() + request.read() + assert "multipart/form-data" in request.headers["Content-Type"] + sent_parts = decode_multipart_request(request) + + assert sent_parts["operations"] + decoded_operations = json.loads(sent_parts["operations"].content) + assert decoded_operations == {"query": query_str, "variables": {"file": None}} + + assert sent_parts["map"] + decoded_map = json.loads(sent_parts["map"].content) + assert decoded_map == {"0": ["variables.file"]} + + assert sent_parts["0"] + assert sent_parts["0"].headers[b"Content-Type"] == b"text/plain" + assert b"txt_file.txt" in sent_parts["0"].headers[b"Content-Disposition"] + assert sent_parts["0"].content == b"abcdefgh" + + +@pytest.mark.asyncio +async def test_execute_sends_file_from_memory(httpx_mock, in_memory_txt_file): + httpx_mock.add_response() + query_str = "query Abc($file: Upload!) { abc(file: $file) }" + + client = AsyncBaseClientWithTelemetry(url="http://base_url") + await client.execute(query_str, {"file": in_memory_txt_file}) + + request = httpx_mock.get_request() + request.read() + assert "multipart/form-data" in request.headers["Content-Type"] + sent_parts = decode_multipart_request(request) + + assert sent_parts["operations"] + decoded_operations = json.loads(sent_parts["operations"].content) + assert decoded_operations == {"query": query_str, "variables": {"file": None}} + + assert sent_parts["map"] + decoded_map = json.loads(sent_parts["map"].content) + assert decoded_map == {"0": ["variables.file"]} + + assert sent_parts["0"] + assert sent_parts["0"].headers[b"Content-Type"] == b"text/plain" + assert b"in_memory.txt" in sent_parts["0"].headers[b"Content-Disposition"] + assert sent_parts["0"].content == b"123456" + + +@pytest.mark.asyncio +async def test_execute_sends_multiple_files(httpx_mock, txt_file, png_file): + httpx_mock.add_response() + query_str = "query Abc($files: [Upload!]!) { abc(files: $files) }" + + client = AsyncBaseClientWithTelemetry(url="http://base_url") + await client.execute(query_str, {"files": [txt_file, png_file]}) + + request = httpx_mock.get_request() + request.read() + assert "multipart/form-data" in request.headers["Content-Type"] + sent_parts = decode_multipart_request(request) + + assert sent_parts["operations"] + decoded_operations = json.loads(sent_parts["operations"].content) + assert decoded_operations == { + "query": query_str, + "variables": {"files": [None, None]}, + } + + assert sent_parts["map"] + decoded_map = json.loads(sent_parts["map"].content) + assert decoded_map == {"0": ["variables.files.0"], "1": ["variables.files.1"]} + + assert sent_parts["0"] + assert sent_parts["0"].headers[b"Content-Type"] == b"text/plain" + assert b"txt_file.txt" in sent_parts["0"].headers[b"Content-Disposition"] + assert sent_parts["0"].content == b"abcdefgh" + + assert sent_parts["1"] + assert sent_parts["1"].headers[b"Content-Type"] == b"image/png" + assert b"png_file.png" in sent_parts["1"].headers[b"Content-Disposition"] + assert sent_parts["1"].content == b"image_content" + + +@pytest.mark.asyncio +async def test_execute_sends_nested_file(httpx_mock, txt_file): + class InputType(BaseModel): + file_: Any + + httpx_mock.add_response() + query_str = "query Abc($input: InputType!) { abc(input: $input) }" + + client = AsyncBaseClientWithTelemetry(url="http://base_url") + await client.execute(query_str, {"input": InputType(file_=txt_file)}) + + request = httpx_mock.get_request() + request.read() + assert "multipart/form-data" in request.headers["Content-Type"] + sent_parts = decode_multipart_request(request) + + assert sent_parts["operations"] + decoded_operations = json.loads(sent_parts["operations"].content) + assert decoded_operations == { + "query": query_str, + "variables": {"input": {"file_": None}}, + } + + assert sent_parts["map"] + decoded_map = json.loads(sent_parts["map"].content) + assert decoded_map == {"0": ["variables.input.file_"]} + + assert sent_parts["0"] + assert sent_parts["0"].headers[b"Content-Type"] == b"text/plain" + assert b"txt_file.txt" in sent_parts["0"].headers[b"Content-Disposition"] + assert sent_parts["0"].content == b"abcdefgh" + + +@pytest.mark.asyncio +async def test_execute_sends_each_file_only_once(httpx_mock, txt_file): + httpx_mock.add_response() + query_str = "query Abc($files: [Upload!]!) { abc(files: $files) }" + + client = AsyncBaseClientWithTelemetry(url="http://base_url") + await client.execute(query_str, {"files": [txt_file, txt_file]}) + + request = httpx_mock.get_request() + request.read() + assert "multipart/form-data" in request.headers["Content-Type"] + sent_parts = decode_multipart_request(request) + + assert sent_parts["operations"] + decoded_operations = json.loads(sent_parts["operations"].content) + assert decoded_operations == { + "query": query_str, + "variables": {"files": [None, None]}, + } + + assert sent_parts["map"] + decoded_map = json.loads(sent_parts["map"].content) + assert decoded_map == {"0": ["variables.files.0", "variables.files.1"]} + + assert sent_parts["0"] + assert sent_parts["0"].headers[b"Content-Type"] == b"text/plain" + assert b"txt_file.txt" in sent_parts["0"].headers[b"Content-Disposition"] + assert sent_parts["0"].content == b"abcdefgh" + + +@pytest.mark.parametrize( + "status_code, response_content", + [ + (401, {"msg": "Unauthorized"}), + (403, {"msg": "Forbidden"}), + (404, {"msg": "Not Found"}), + (500, {"msg": "Internal Server Error"}), + ], +) +def test_get_data_raises_graphql_client_http_error( + mocker, status_code, response_content +): + client = AsyncBaseClientWithTelemetry( + url="base_url", http_client=mocker.MagicMock() + ) + response = httpx.Response( + status_code=status_code, content=json.dumps(response_content) + ) + + with pytest.raises(GraphQLClientHttpError) as exc: + client.get_data(response) + assert exc.status_code == status_code + assert exc.response == response + + +@pytest.mark.parametrize("response_content", ["invalid_json", {"not_data": ""}, ""]) +def test_get_data_raises_graphql_client_invalid_response_error( + mocker, response_content +): + client = AsyncBaseClientWithTelemetry( + url="base_url", http_client=mocker.MagicMock() + ) + response = httpx.Response(status_code=200, content=json.dumps(response_content)) + + with pytest.raises(GraphQlClientInvalidResponseError) as exc: + client.get_data(response) + assert exc.response == response + + +@pytest.mark.parametrize( + "response_content", + [ + { + "data": {}, + "errors": [ + { + "message": "Error message", + "locations": [{"line": 6, "column": 7}], + "path": ["field1", "field2", 1, "id"], + } + ], + }, + { + "data": {}, + "errors": [ + { + "message": "Error message type A", + "locations": [{"line": 6, "column": 7}], + "path": ["field1", "field2", 1, "id"], + }, + { + "message": "Error message type B", + "locations": [{"line": 6, "column": 7}], + "path": ["field1", "field2", 1, "id"], + }, + ], + }, + ], +) +def test_get_data_raises_graphql_client_graphql_multi_error(mocker, response_content): + client = AsyncBaseClientWithTelemetry( + url="base_url", http_client=mocker.MagicMock() + ) + + with pytest.raises(GraphQLClientGraphQLMultiError): + client.get_data( + httpx.Response(status_code=200, content=json.dumps(response_content)) + ) + + +@pytest.mark.parametrize( + "response_content", + [{"errors": [], "data": {}}, {"errors": None, "data": {}}, {"data": {}}], +) +def test_get_data_doesnt_raise_exception(mocker, response_content): + client = AsyncBaseClientWithTelemetry( + url="base_url", http_client=mocker.MagicMock() + ) + + data = client.get_data( + httpx.Response(status_code=200, content=json.dumps(response_content)) + ) + + assert data == response_content["data"] + + +@pytest.mark.asyncio +async def test_base_client_used_as_context_manager_closes_http_client(mocker): + fake_client = mocker.AsyncMock() + async with AsyncBaseClientWithTelemetry( + url="base_url", http_client=fake_client + ) as base_client: + await base_client.execute("") + + assert fake_client.aclose.called + + +@pytest.fixture +def mocker_get_tracer(mocker): + return mocker.patch( + "ariadne_codegen.client_generators.dependencies." + "async_base_client_with_telemetry.get_tracer" + ) + + +@pytest.fixture +def mocked_start_as_current_span(mocker_get_tracer): + return mocker_get_tracer.return_value.start_as_current_span + + +@pytest.mark.asyncio +async def test_async_base_client_with_given_tracker_str_uses_global_tracker( + mocker_get_tracer, +): + AsyncBaseClientWithTelemetry(url="http://base_url", tracer="tracker name") + + assert mocker_get_tracer.call_count == 1 + + +@pytest.mark.asyncio +async def test_execute_creates_root_span(httpx_mock, mocked_start_as_current_span): + httpx_mock.add_response() + client = AsyncBaseClientWithTelemetry(url="http://base_url", tracer="tracker") + + await client.execute("query GetHello { hello }") + + mocked_start_as_current_span.assert_any_call("GraphQL Operation", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + + +@pytest.mark.asyncio +async def test_execute_creates_root_span_with_custom_name( + httpx_mock, mocked_start_as_current_span +): + httpx_mock.add_response() + client = AsyncBaseClientWithTelemetry( + url="http://base_url", tracer="tracker", root_span_name="root_span" + ) + + await client.execute("query GetHello { hello }") + + mocked_start_as_current_span.assert_any_call("root_span", context=ANY) + + +@pytest.mark.asyncio +async def test_execute_creates_root_span_with_custom_context( + httpx_mock, mocked_start_as_current_span +): + httpx_mock.add_response() + client = AsyncBaseClientWithTelemetry( + url="http://base_url", tracer="tracker", root_context={"abc": 123} + ) + + await client.execute("query GetHello { hello }") + + mocked_start_as_current_span.assert_any_call( + "GraphQL Operation", context={"abc": 123} + ) + + +@pytest.mark.asyncio +async def test_execute_creates_span_for_json_http_request( + httpx_mock, mocked_start_as_current_span +): + httpx_mock.add_response() + client = AsyncBaseClientWithTelemetry(url="http://base_url", tracer="tracker") + + await client.execute("query GetHello { hello }", variables={"a": 1, "b": {"bb": 2}}) + + mocked_start_as_current_span.assert_any_call("json request", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + span.set_attribute.assert_any_call("query", "query GetHello { hello }") + span.set_attribute.assert_any_call( + "variables", json.dumps({"a": 1, "b": {"bb": 2}}) + ) + + +@pytest.mark.asyncio +async def test_execute_creates_span_for_multipart_request( + httpx_mock, txt_file, mocked_start_as_current_span +): + httpx_mock.add_response() + client = AsyncBaseClientWithTelemetry(url="http://base_url", tracer="tracker") + + await client.execute( + "query Abc($file: Upload!) { abc(file: $file) }", + {"file": txt_file, "a": 1.0, "b": {"bb": 2}}, + ) + + mocked_start_as_current_span.assert_any_call("multipart request", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + span.set_attribute.assert_any_call( + "query", "query Abc($file: Upload!) { abc(file: $file) }" + ) + span.set_attribute.assert_any_call( + "variables", json.dumps({"file": None, "a": 1.0, "b": {"bb": 2}}) + ) + span.set_attribute.assert_any_call("map", json.dumps({"0": ["variables.file"]})) diff --git a/tests/client_generators/dependencies/test_base_client.py b/tests/client_generators/dependencies/test_base_client.py index aba53ff7..24fd4087 100644 --- a/tests/client_generators/dependencies/test_base_client.py +++ b/tests/client_generators/dependencies/test_base_client.py @@ -1,7 +1,6 @@ import json from datetime import datetime from typing import Any, Optional -from unittest.mock import ANY import httpx import pytest @@ -446,100 +445,3 @@ def test_base_client_used_as_context_manager_closes_http_client(mocker): base_client.execute("") assert fake_client.close.called - - -@pytest.fixture -def mocker_get_tracer(mocker): - return mocker.patch( - "ariadne_codegen.client_generators.dependencies.base_client.get_tracer" - ) - - -@pytest.fixture -def mocked_start_as_current_span(mocker_get_tracer): - return mocker_get_tracer.return_value.start_as_current_span - - -def test_base_client_with_given_tracker_str_uses_global_tracker(mocker_get_tracer): - BaseClient(url="http://base_url", tracer="tracker name") - - assert mocker_get_tracer.call_count == 1 - - -def test_execute_creates_root_span(httpx_mock, mocked_start_as_current_span): - httpx_mock.add_response() - client = BaseClient(url="http://base_url", tracer="tracker") - - client.execute("query GetHello { hello }") - - mocked_start_as_current_span.assert_any_call("GraphQL Operation", context=ANY) - with mocked_start_as_current_span.return_value as span: - span.set_attribute.assert_any_call("component", "GraphQL Client") - - -def test_execute_creates_root_span_with_custom_name( - httpx_mock, mocked_start_as_current_span -): - httpx_mock.add_response() - client = BaseClient( - url="http://base_url", tracer="tracker", root_span_name="root_span" - ) - - client.execute("query GetHello { hello }") - - mocked_start_as_current_span.assert_any_call("root_span", context=ANY) - - -def test_execute_creates_root_span_with_custom_context( - httpx_mock, mocked_start_as_current_span -): - httpx_mock.add_response() - client = BaseClient( - url="http://base_url", tracer="tracker", root_context={"abc": 123} - ) - - client.execute("query GetHello { hello }") - - mocked_start_as_current_span.assert_any_call( - "GraphQL Operation", context={"abc": 123} - ) - - -def test_execute_creates_span_for_json_http_request( - httpx_mock, mocked_start_as_current_span -): - httpx_mock.add_response() - client = BaseClient(url="http://base_url", tracer="tracker") - - client.execute("query GetHello { hello }", variables={"a": 1, "b": {"bb": 2}}) - - mocked_start_as_current_span.assert_any_call("json request", context=ANY) - with mocked_start_as_current_span.return_value as span: - span.set_attribute.assert_any_call("component", "GraphQL Client") - span.set_attribute.assert_any_call("query", "query GetHello { hello }") - span.set_attribute.assert_any_call( - "variables", json.dumps({"a": 1, "b": {"bb": 2}}) - ) - - -def test_execute_creates_span_for_multipart_request( - httpx_mock, txt_file, mocked_start_as_current_span -): - httpx_mock.add_response() - client = BaseClient(url="http://base_url", tracer="tracker") - - client.execute( - "query Abc($file: Upload!) { abc(file: $file) }", - {"file": txt_file, "a": 1.0, "b": {"bb": 2}}, - ) - - mocked_start_as_current_span.assert_any_call("multipart request", context=ANY) - with mocked_start_as_current_span.return_value as span: - span.set_attribute.assert_any_call("component", "GraphQL Client") - span.set_attribute.assert_any_call( - "query", "query Abc($file: Upload!) { abc(file: $file) }" - ) - span.set_attribute.assert_any_call( - "variables", json.dumps({"file": None, "a": 1.0, "b": {"bb": 2}}) - ) - span.set_attribute.assert_any_call("map", json.dumps({"0": ["variables.file"]})) diff --git a/tests/client_generators/dependencies/test_base_client_with_telemetry.py b/tests/client_generators/dependencies/test_base_client_with_telemetry.py new file mode 100644 index 00000000..5dc1f127 --- /dev/null +++ b/tests/client_generators/dependencies/test_base_client_with_telemetry.py @@ -0,0 +1,552 @@ +import json +from datetime import datetime +from typing import Any, Optional +from unittest.mock import ANY + +import httpx +import pytest + +from ariadne_codegen.client_generators.dependencies.base_client_with_telemetry import ( + BaseClientWithTelemetry, +) +from ariadne_codegen.client_generators.dependencies.base_model import UNSET, BaseModel +from ariadne_codegen.client_generators.dependencies.exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQlClientInvalidResponseError, +) + +from ...utils import decode_multipart_request + + +def test_execute_sends_post_to_correct_url_with_correct_payload(httpx_mock): + httpx_mock.add_response() + client = BaseClientWithTelemetry(url="http://base_url/endpoint") + query_str = """ + query Abc($v: String!) { + abc(v: $v) { + field1 + } + } + """ + + client.execute(query_str, {"v": "Xyz"}) + + request = httpx_mock.get_request() + assert request.url == "http://base_url/endpoint" + content = json.loads(request.content) + assert content == {"query": query_str, "variables": {"v": "Xyz"}} + + +def test_execute_parses_pydantic_variables_before_sending(httpx_mock): + class TestModel1(BaseModel): + a: int + + class TestModel2(BaseModel): + nested: TestModel1 + + httpx_mock.add_response() + client = BaseClientWithTelemetry(url="http://base_url") + query_str = """ + query Abc($v1: TestModel1!, $v2: TestModel2) { + abc(v1: $v1, v2: $v2){ + field1 + } + } + """ + + client.execute( + query_str, {"v1": TestModel1(a=5), "v2": TestModel2(nested=TestModel1(a=10))} + ) + + request = httpx_mock.get_request() + content = json.loads(request.content) + assert content == { + "query": query_str, + "variables": {"v1": {"a": 5}, "v2": {"nested": {"a": 10}}}, + } + + +def test_execute_correctly_parses_top_level_list_variables(httpx_mock): + class TestModel1(BaseModel): + a: int + + httpx_mock.add_response() + client = BaseClientWithTelemetry(url="http://base_url") + query_str = """ + query Abc($v1: [[TestModel1!]!]!) { + abc(v1: $v1){ + field1 + } + } + """ + + client.execute( + query_str, + { + "v1": [[TestModel1(a=1), TestModel1(a=2)]], + }, + ) + + request = httpx_mock.get_request() + content = json.loads(request.content) + assert content == { + "query": query_str, + "variables": {"v1": [[{"a": 1}, {"a": 2}]]}, + } + assert not any(isinstance(x, BaseModel) for x in content["variables"]["v1"][0]) + + +def test_execute_sends_payload_without_unset_arguments(httpx_mock): + httpx_mock.add_response() + client = BaseClientWithTelemetry(url="http://base_url") + query_str = """ + query Abc($arg1: TestInputA, $arg2: String, $arg3: Float, $arg4: Int!) { + abc(arg1: $arg1, arg2: $arg2, arg3: $arg3, arg4: $arg4){ + field1 + } + } + """ + + client.execute(query_str, {"arg1": UNSET, "arg2": UNSET, "arg3": None, "arg4": 2}) + + request = httpx_mock.get_request() + content = json.loads(request.content) + assert content == { + "query": query_str, + "variables": {"arg3": None, "arg4": 2}, + } + + +def test_execute_sends_payload_without_unset_input_fields(httpx_mock): + class TestInputB(BaseModel): + required_b: str + optional_b: Optional[str] = None + + class TestInputA(BaseModel): + required_a: str + optional_a: Optional[str] = None + input_b1: Optional[TestInputB] = None + input_b2: Optional[TestInputB] = None + input_b3: Optional[TestInputB] = None + + httpx_mock.add_response() + client = BaseClientWithTelemetry(url="http://base_url") + query_str = """ + query Abc($arg: TestInputB) { + abc(arg: $arg){ + field1 + } + } + """ + + client.execute( + query_str, + { + "arg": TestInputA( + required_a="a", input_b1=TestInputB(required_b="b"), input_b3=None + ) + }, + ) + + request = httpx_mock.get_request() + content = json.loads(request.content) + assert content == { + "query": query_str, + "variables": { + "arg": { + "required_a": "a", + "input_b1": {"required_b": "b"}, + "input_b3": None, + } + }, + } + + +def test_execute_sends_payload_with_serialized_datetime_without_exception(httpx_mock): + httpx_mock.add_response() + client = BaseClientWithTelemetry(url="http://base_url") + query_str = "query Abc($arg: DATETIME) { abc }" + arg_value = datetime(2023, 12, 31, 10, 15) + + client.execute(query_str, {"arg": arg_value}) + + request = httpx_mock.get_request() + content = json.loads(request.content) + assert content["variables"]["arg"] == arg_value.isoformat() + + +def test_execute_sends_request_with_correct_content_type(httpx_mock): + httpx_mock.add_response() + client = BaseClientWithTelemetry(url="http://base_url") + + client.execute("query Abc { abc }", {}) + + request = httpx_mock.get_request() + assert request.headers["Content-Type"] == "application/json" + + +def test_execute_sends_request_with_extra_headers_and_correct_content_type(httpx_mock): + httpx_mock.add_response() + client = BaseClientWithTelemetry( + url="http://base_url", headers={"h_key": "h_value"} + ) + + client.execute("query Abc { abc }", {}) + + request = httpx_mock.get_request() + assert request.headers["h_key"] == "h_value" + assert request.headers["Content-Type"] == "application/json" + + +def test_execute_sends_file_with_multipart_form_data_content_type(httpx_mock, txt_file): + httpx_mock.add_response() + + client = BaseClientWithTelemetry(url="http://base_url") + client.execute("query Abc($file: Upload!) { abc(file: $file) }", {"file": txt_file}) + + request = httpx_mock.get_request() + assert "multipart/form-data" in request.headers["Content-Type"] + + +def test_execute_sends_file_as_multipart_request(httpx_mock, txt_file): + httpx_mock.add_response() + query_str = "query Abc($file: Upload!) { abc(file: $file) }" + + client = BaseClientWithTelemetry(url="http://base_url") + client.execute(query_str, {"file": txt_file}) + + request = httpx_mock.get_request() + request.read() + assert "multipart/form-data" in request.headers["Content-Type"] + sent_parts = decode_multipart_request(request) + + assert sent_parts["operations"] + decoded_operations = json.loads(sent_parts["operations"].content) + assert decoded_operations == {"query": query_str, "variables": {"file": None}} + + assert sent_parts["map"] + decoded_map = json.loads(sent_parts["map"].content) + assert decoded_map == {"0": ["variables.file"]} + + assert sent_parts["0"] + assert sent_parts["0"].headers[b"Content-Type"] == b"text/plain" + assert b"txt_file.txt" in sent_parts["0"].headers[b"Content-Disposition"] + assert sent_parts["0"].content == b"abcdefgh" + + +def test_execute_sends_file_from_memory(httpx_mock, in_memory_txt_file): + httpx_mock.add_response() + query_str = "query Abc($file: Upload!) { abc(file: $file) }" + + client = BaseClientWithTelemetry(url="http://base_url") + client.execute(query_str, {"file": in_memory_txt_file}) + + request = httpx_mock.get_request() + request.read() + assert "multipart/form-data" in request.headers["Content-Type"] + sent_parts = decode_multipart_request(request) + + assert sent_parts["operations"] + decoded_operations = json.loads(sent_parts["operations"].content) + assert decoded_operations == {"query": query_str, "variables": {"file": None}} + + assert sent_parts["map"] + decoded_map = json.loads(sent_parts["map"].content) + assert decoded_map == {"0": ["variables.file"]} + + assert sent_parts["0"] + assert sent_parts["0"].headers[b"Content-Type"] == b"text/plain" + assert b"in_memory.txt" in sent_parts["0"].headers[b"Content-Disposition"] + assert sent_parts["0"].content == b"123456" + + +def test_execute_sends_multiple_files(httpx_mock, txt_file, png_file): + httpx_mock.add_response() + query_str = "query Abc($files: [Upload!]!) { abc(files: $files) }" + + client = BaseClientWithTelemetry(url="http://base_url") + client.execute(query_str, {"files": [txt_file, png_file]}) + + request = httpx_mock.get_request() + request.read() + assert "multipart/form-data" in request.headers["Content-Type"] + sent_parts = decode_multipart_request(request) + + assert sent_parts["operations"] + decoded_operations = json.loads(sent_parts["operations"].content) + assert decoded_operations == { + "query": query_str, + "variables": {"files": [None, None]}, + } + + assert sent_parts["map"] + decoded_map = json.loads(sent_parts["map"].content) + assert decoded_map == {"0": ["variables.files.0"], "1": ["variables.files.1"]} + + assert sent_parts["0"] + assert sent_parts["0"].headers[b"Content-Type"] == b"text/plain" + assert b"txt_file.txt" in sent_parts["0"].headers[b"Content-Disposition"] + assert sent_parts["0"].content == b"abcdefgh" + + assert sent_parts["1"] + assert sent_parts["1"].headers[b"Content-Type"] == b"image/png" + assert b"png_file.png" in sent_parts["1"].headers[b"Content-Disposition"] + assert sent_parts["1"].content == b"image_content" + + +def test_execute_sends_nested_file(httpx_mock, txt_file): + class InputType(BaseModel): + file_: Any + + httpx_mock.add_response() + query_str = "query Abc($input: InputType!) { abc(input: $input) }" + + client = BaseClientWithTelemetry(url="http://base_url") + client.execute(query_str, {"input": InputType(file_=txt_file)}) + + request = httpx_mock.get_request() + request.read() + assert "multipart/form-data" in request.headers["Content-Type"] + sent_parts = decode_multipart_request(request) + + assert sent_parts["operations"] + decoded_operations = json.loads(sent_parts["operations"].content) + assert decoded_operations == { + "query": query_str, + "variables": {"input": {"file_": None}}, + } + + assert sent_parts["map"] + decoded_map = json.loads(sent_parts["map"].content) + assert decoded_map == {"0": ["variables.input.file_"]} + + assert sent_parts["0"] + assert sent_parts["0"].headers[b"Content-Type"] == b"text/plain" + assert b"txt_file.txt" in sent_parts["0"].headers[b"Content-Disposition"] + assert sent_parts["0"].content == b"abcdefgh" + + +def test_execute_sends_each_file_only_once(httpx_mock, txt_file): + httpx_mock.add_response() + query_str = "query Abc($files: [Upload!]!) { abc(files: $files) }" + + client = BaseClientWithTelemetry(url="http://base_url") + client.execute(query_str, {"files": [txt_file, txt_file]}) + + request = httpx_mock.get_request() + request.read() + assert "multipart/form-data" in request.headers["Content-Type"] + sent_parts = decode_multipart_request(request) + + assert sent_parts["operations"] + decoded_operations = json.loads(sent_parts["operations"].content) + assert decoded_operations == { + "query": query_str, + "variables": {"files": [None, None]}, + } + + assert sent_parts["map"] + decoded_map = json.loads(sent_parts["map"].content) + assert decoded_map == {"0": ["variables.files.0", "variables.files.1"]} + + assert sent_parts["0"] + assert sent_parts["0"].headers[b"Content-Type"] == b"text/plain" + assert b"txt_file.txt" in sent_parts["0"].headers[b"Content-Disposition"] + assert sent_parts["0"].content == b"abcdefgh" + + +@pytest.mark.parametrize( + "status_code, response_content", + [ + (401, {"msg": "Unauthorized"}), + (403, {"msg": "Forbidden"}), + (404, {"msg": "Not Found"}), + (500, {"msg": "Internal Server Error"}), + ], +) +def test_get_data_raises_graphql_client_http_error( + mocker, status_code, response_content +): + client = BaseClientWithTelemetry(url="base_url", http_client=mocker.MagicMock()) + response = httpx.Response( + status_code=status_code, content=json.dumps(response_content) + ) + + with pytest.raises(GraphQLClientHttpError) as exc: + client.get_data(response) + assert exc.status_code == status_code + assert exc.response == response + + +@pytest.mark.parametrize("response_content", ["invalid_json", {"not_data": ""}, ""]) +def test_get_data_raises_graphql_client_invalid_response_error( + mocker, response_content +): + client = BaseClientWithTelemetry(url="base_url", http_client=mocker.MagicMock()) + response = httpx.Response(status_code=200, content=json.dumps(response_content)) + + with pytest.raises(GraphQlClientInvalidResponseError) as exc: + client.get_data(response) + assert exc.response == response + + +@pytest.mark.parametrize( + "response_content", + [ + { + "data": {}, + "errors": [ + { + "message": "Error message", + "locations": [{"line": 6, "column": 7}], + "path": ["field1", "field2", 1, "id"], + } + ], + }, + { + "data": {}, + "errors": [ + { + "message": "Error message type A", + "locations": [{"line": 6, "column": 7}], + "path": ["field1", "field2", 1, "id"], + }, + { + "message": "Error message type B", + "locations": [{"line": 6, "column": 7}], + "path": ["field1", "field2", 1, "id"], + }, + ], + }, + ], +) +def test_get_data_raises_graphql_client_graphql_multi_error(mocker, response_content): + client = BaseClientWithTelemetry(url="base_url", http_client=mocker.MagicMock()) + + with pytest.raises(GraphQLClientGraphQLMultiError): + client.get_data( + httpx.Response(status_code=200, content=json.dumps(response_content)) + ) + + +@pytest.mark.parametrize( + "response_content", + [{"errors": [], "data": {}}, {"errors": None, "data": {}}, {"data": {}}], +) +def test_get_data_doesnt_raise_exception(mocker, response_content): + client = BaseClientWithTelemetry(url="base_url", http_client=mocker.MagicMock()) + + data = client.get_data( + httpx.Response(status_code=200, content=json.dumps(response_content)) + ) + + assert data == response_content["data"] + + +def test_base_client_used_as_context_manager_closes_http_client(mocker): + fake_client = mocker.MagicMock() + with BaseClientWithTelemetry( + url="base_url", http_client=fake_client + ) as base_client: + base_client.execute("") + + assert fake_client.close.called + + +@pytest.fixture +def mocker_get_tracer(mocker): + return mocker.patch( + "ariadne_codegen.client_generators.dependencies." + "base_client_with_telemetry.get_tracer" + ) + + +@pytest.fixture +def mocked_start_as_current_span(mocker_get_tracer): + return mocker_get_tracer.return_value.start_as_current_span + + +def test_base_client_with_given_tracker_str_uses_global_tracker(mocker_get_tracer): + BaseClientWithTelemetry(url="http://base_url", tracer="tracker name") + + assert mocker_get_tracer.call_count == 1 + + +def test_execute_creates_root_span(httpx_mock, mocked_start_as_current_span): + httpx_mock.add_response() + client = BaseClientWithTelemetry(url="http://base_url", tracer="tracker") + + client.execute("query GetHello { hello }") + + mocked_start_as_current_span.assert_any_call("GraphQL Operation", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + + +def test_execute_creates_root_span_with_custom_name( + httpx_mock, mocked_start_as_current_span +): + httpx_mock.add_response() + client = BaseClientWithTelemetry( + url="http://base_url", tracer="tracker", root_span_name="root_span" + ) + + client.execute("query GetHello { hello }") + + mocked_start_as_current_span.assert_any_call("root_span", context=ANY) + + +def test_execute_creates_root_span_with_custom_context( + httpx_mock, mocked_start_as_current_span +): + httpx_mock.add_response() + client = BaseClientWithTelemetry( + url="http://base_url", tracer="tracker", root_context={"abc": 123} + ) + + client.execute("query GetHello { hello }") + + mocked_start_as_current_span.assert_any_call( + "GraphQL Operation", context={"abc": 123} + ) + + +def test_execute_creates_span_for_json_http_request( + httpx_mock, mocked_start_as_current_span +): + httpx_mock.add_response() + client = BaseClientWithTelemetry(url="http://base_url", tracer="tracker") + + client.execute("query GetHello { hello }", variables={"a": 1, "b": {"bb": 2}}) + + mocked_start_as_current_span.assert_any_call("json request", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + span.set_attribute.assert_any_call("query", "query GetHello { hello }") + span.set_attribute.assert_any_call( + "variables", json.dumps({"a": 1, "b": {"bb": 2}}) + ) + + +def test_execute_creates_span_for_multipart_request( + httpx_mock, txt_file, mocked_start_as_current_span +): + httpx_mock.add_response() + client = BaseClientWithTelemetry(url="http://base_url", tracer="tracker") + + client.execute( + "query Abc($file: Upload!) { abc(file: $file) }", + {"file": txt_file, "a": 1.0, "b": {"bb": 2}}, + ) + + mocked_start_as_current_span.assert_any_call("multipart request", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + span.set_attribute.assert_any_call( + "query", "query Abc($file: Upload!) { abc(file: $file) }" + ) + span.set_attribute.assert_any_call( + "variables", json.dumps({"file": None, "a": 1.0, "b": {"bb": 2}}) + ) + span.set_attribute.assert_any_call("map", json.dumps({"0": ["variables.file"]})) diff --git a/tests/client_generators/dependencies/test_websockets.py b/tests/client_generators/dependencies/test_websockets.py index d6e9dd10..995e2a30 100644 --- a/tests/client_generators/dependencies/test_websockets.py +++ b/tests/client_generators/dependencies/test_websockets.py @@ -1,5 +1,4 @@ import json -from unittest.mock import ANY import pytest @@ -214,131 +213,3 @@ async def test_execute_ws_raises_graphql_multi_error_for_message_with_error_type with pytest.raises(GraphQLClientGraphQLMultiError): async for _ in AsyncBaseClient().execute_ws(""): pass - - -@pytest.fixture -def mocked_start_as_current_span(mocker): - mocker_get_tracer = mocker.patch( - "ariadne_codegen.client_generators.dependencies.async_base_client.get_tracer" - ) - return mocker_get_tracer.return_value.start_as_current_span - - -@pytest.mark.asyncio -async def test_execute_ws_creates_root_span( - mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument -): - client = AsyncBaseClient(ws_url="ws://test_url", tracer="tracker") - - async for _ in client.execute_ws(""): - pass - - mocked_start_as_current_span.assert_any_call("GraphQL Subscription", context=ANY) - with mocked_start_as_current_span.return_value as span: - span.set_attribute.assert_any_call("component", "GraphQL Client") - - -@pytest.mark.asyncio -async def test_execute_ws_creates_root_span_with_custom_name( - mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument -): - client = AsyncBaseClient( - ws_url="ws://test_url", tracer="tracker", ws_root_span_name="ws root span" - ) - - async for _ in client.execute_ws(""): - pass - - mocked_start_as_current_span.assert_any_call("ws root span", context=ANY) - with mocked_start_as_current_span.return_value as span: - span.set_attribute.assert_any_call("component", "GraphQL Client") - - -@pytest.mark.asyncio -async def test_execute_ws_creates_root_span_with_custom_context( - mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument -): - client = AsyncBaseClient( - ws_url="ws://test_url", tracer="tracker", ws_root_context={"value": "test"} - ) - - async for _ in client.execute_ws(""): - pass - - mocked_start_as_current_span.assert_any_call( - "GraphQL Subscription", context={"value": "test"} - ) - with mocked_start_as_current_span.return_value as span: - span.set_attribute.assert_any_call("component", "GraphQL Client") - - -@pytest.mark.asyncio -async def test_execute_ws_creates_span_for_init_message( - mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument): -): - client = AsyncBaseClient(ws_url="ws://test_url", tracer="tracker") - - async for _ in client.execute_ws(""): - pass - - mocked_start_as_current_span.assert_any_call("connection init", context=ANY) - with mocked_start_as_current_span.return_value as span: - span.set_attribute.assert_any_call("type", "connection_init") - - -@pytest.mark.asyncio -async def test_execute_ws_creates_span_for_subscribe_message( - mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument): -): - client = AsyncBaseClient(ws_url="ws://test_url", tracer="tracker") - - async for _ in client.execute_ws( - "subscription Abc(a: String, b: InputB) { value }", - variables={"a": "AAA", "b": {"valueB": 21}}, - ): - pass - - mocked_start_as_current_span.assert_any_call("connection init", context=ANY) - with mocked_start_as_current_span.return_value as span: - span.set_attribute.assert_any_call("type", "connection_init") - - -@pytest.mark.parametrize( - "received_message", - [ - {"type": "next", "payload": {"data": "test_data"}}, - {"type": "complete"}, - {"type": "ping"}, - ], -) -@pytest.mark.asyncio -async def test_execute_ws_creates_span_for_received_message( - received_message, mocked_websocket, mocked_start_as_current_span -): - mocked_websocket.__aiter__.return_value.append(json.dumps(received_message)) - client = AsyncBaseClient(ws_url="ws://test_url", tracer="tracker") - - async for _ in client.execute_ws(""): - pass - - mocked_start_as_current_span.assert_any_call("received message", context=ANY) - with mocked_start_as_current_span.return_value as span: - span.set_attribute.assert_any_call("type", received_message["type"]) - - -@pytest.mark.asyncio -async def test_execute_ws_creates_span_for_received_error_message( - mocked_websocket, mocked_start_as_current_span -): - mocked_websocket.__aiter__.return_value.append( - json.dumps({"type": "error", "payload": [{"message": "error_message"}]}) - ) - client = AsyncBaseClient(ws_url="ws://test_url", tracer="tracker") - - with pytest.raises(GraphQLClientGraphQLMultiError): - async for _ in client.execute_ws(""): - pass - - mocked_start_as_current_span.assert_any_call("received message", context=ANY) - with mocked_start_as_current_span.return_value as span: - span.set_attribute.assert_any_call("type", "error") diff --git a/tests/client_generators/dependencies/test_websockets_with_telemetry.py b/tests/client_generators/dependencies/test_websockets_with_telemetry.py new file mode 100644 index 00000000..032544a7 --- /dev/null +++ b/tests/client_generators/dependencies/test_websockets_with_telemetry.py @@ -0,0 +1,348 @@ +import json +from unittest.mock import ANY + +import pytest + +from ariadne_codegen.client_generators.dependencies.async_base_client_with_telemetry import ( # pylint: disable=line-too-long + AsyncBaseClientWithTelemetry, +) +from ariadne_codegen.client_generators.dependencies.exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientInvalidMessageFormat, +) + + +@pytest.fixture +def mocked_ws_connect(mocker): + return mocker.patch( + "ariadne_codegen.client_generators.dependencies." + "async_base_client_with_telemetry.ws_connect" + ) + + +@pytest.fixture +def mocked_websocket(mocked_ws_connect): + websocket = mocked_ws_connect.return_value.__aenter__.return_value + websocket.__aiter__.return_value = [ + json.dumps({"type": "connection_ack"}), + ] + return websocket + + +@pytest.mark.asyncio +async def test_execute_ws_creates_websocket_connection_with_correct_url( + mocked_ws_connect, +): + async for _ in AsyncBaseClientWithTelemetry(ws_url="ws://test_url").execute_ws(""): + pass + + assert mocked_ws_connect.called + assert "ws://test_url" in mocked_ws_connect.call_args.args + + +@pytest.mark.asyncio +async def test_execute_ws_creates_websocket_connection_with_correct_subprotocol( + mocked_ws_connect, +): + async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + pass + + assert mocked_ws_connect.called + assert mocked_ws_connect.call_args.kwargs["subprotocols"] == [ + "graphql-transport-ws" + ] + + +@pytest.mark.asyncio +async def test_execute_ws_creates_websocket_connection_with_correct_origin( + mocked_ws_connect, +): + async for _ in AsyncBaseClientWithTelemetry(ws_origin="test_origin").execute_ws(""): + pass + + assert mocked_ws_connect.called + assert mocked_ws_connect.call_args.kwargs["origin"] == "test_origin" + + +@pytest.mark.asyncio +async def test_execute_ws_creates_websocket_connection_with_correct_headers( + mocked_ws_connect, +): + async for _ in AsyncBaseClientWithTelemetry( + ws_headers={"test_key": "test_value"} + ).execute_ws(""): + pass + + assert mocked_ws_connect.called + assert mocked_ws_connect.call_args.kwargs["extra_headers"] == { + "test_key": "test_value" + } + + +@pytest.mark.asyncio +async def test_execute_ws_sends_correct_init_connection_data(mocked_websocket): + async for _ in AsyncBaseClientWithTelemetry( + ws_connection_init_payload={"test_key": "test_value"} + ).execute_ws(""): + pass + + init_call, _ = mocked_websocket.send.mock_calls + sent_data = json.loads(init_call.args[0]) + assert sent_data["type"] == "connection_init" + assert sent_data["payload"] == {"test_key": "test_value"} + + +@pytest.mark.asyncio +async def test_execute_ws_sends_correct_subscribe_data(mocked_websocket): + query_str = "query testQuery($arg: String!) { test(arg: $arg) }" + variables = {"arg": "test_value"} + + async for _ in AsyncBaseClientWithTelemetry().execute_ws( + query=query_str, variables=variables + ): + pass + + _, subscribe_call = mocked_websocket.send.mock_calls + sent_data = json.loads(subscribe_call.args[0]) + assert sent_data["type"] == "subscribe" + assert sent_data["payload"] == {"query": query_str, "variables": variables} + + +@pytest.mark.asyncio +async def test_execute_ws_yields_data_for_next_message(mocked_websocket): + mocked_websocket.__aiter__.return_value.append( + json.dumps({"type": "next", "payload": {"data": "test_data"}}) + ) + + received_data = [] + async for data in AsyncBaseClientWithTelemetry().execute_ws(""): + received_data.append(data) + + assert received_data == ["test_data"] + + +@pytest.mark.asyncio +async def test_execute_ws_yields_handles_multiple_next_messages(mocked_websocket): + mocked_websocket.__aiter__.return_value.extend( + [ + json.dumps({"type": "next", "payload": {"data": "A"}}), + json.dumps({"type": "next", "payload": {"data": "B"}}), + json.dumps({"type": "next", "payload": {"data": "C"}}), + ] + ) + + received_data = [] + async for data in AsyncBaseClientWithTelemetry().execute_ws(""): + received_data.append(data) + + assert received_data == ["A", "B", "C"] + + +@pytest.mark.asyncio +async def test_execute_ws_closes_websocket_for_complete_message(mocked_websocket): + mocked_websocket.__aiter__.return_value.append(json.dumps({"type": "complete"})) + + async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + pass + + assert mocked_websocket.close.called + + +@pytest.mark.asyncio +async def test_execute_ws_sends_pong_for_ping_message(mocked_websocket): + mocked_websocket.__aiter__.return_value.append(json.dumps({"type": "ping"})) + + async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + pass + + pong_call = mocked_websocket.send.mock_calls[-1] + sent_data = json.loads(pong_call.args[0]) + assert sent_data["type"] == "pong" + + +@pytest.mark.asyncio +async def test_execute_ws_raises_invalid_message_format_for_not_json_message( + mocked_websocket, +): + mocked_websocket.__aiter__.return_value.append("not_valid_json") + + with pytest.raises(GraphQLClientInvalidMessageFormat): + async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + pass + + +@pytest.mark.asyncio +async def test_execute_ws_raises_invalid_message_format_for_message_without_type( + mocked_websocket, +): + mocked_websocket.__aiter__.return_value.append(json.dumps({"payload": {}})) + + with pytest.raises(GraphQLClientInvalidMessageFormat): + async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + pass + + +@pytest.mark.asyncio +async def test_execute_ws_raises_invalid_message_format_for_message_with_invalid_type( + mocked_websocket, +): + mocked_websocket.__aiter__.return_value.append(json.dumps({"type": "invalid_type"})) + + with pytest.raises(GraphQLClientInvalidMessageFormat): + async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + pass + + +@pytest.mark.asyncio +async def test_execute_ws_raises_invalid_message_format_for_next_payload_without_data( + mocked_websocket, +): + mocked_websocket.__aiter__.return_value.append( + json.dumps({"type": "next", "payload": {"not_data": "A"}}) + ) + + with pytest.raises(GraphQLClientInvalidMessageFormat): + async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + pass + + +@pytest.mark.asyncio +async def test_execute_ws_raises_graphql_multi_error_for_message_with_error_type( + mocked_websocket, +): + mocked_websocket.__aiter__.return_value.append( + json.dumps({"type": "error", "payload": [{"message": "error_message"}]}) + ) + + with pytest.raises(GraphQLClientGraphQLMultiError): + async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + pass + + +@pytest.fixture +def mocked_start_as_current_span(mocker): + mocker_get_tracer = mocker.patch( + "ariadne_codegen.client_generators.dependencies." + "async_base_client_with_telemetry.get_tracer" + ) + return mocker_get_tracer.return_value.start_as_current_span + + +@pytest.mark.asyncio +async def test_execute_ws_creates_root_span( + mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument +): + client = AsyncBaseClientWithTelemetry(ws_url="ws://test_url", tracer="tracker") + + async for _ in client.execute_ws(""): + pass + + mocked_start_as_current_span.assert_any_call("GraphQL Subscription", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + + +@pytest.mark.asyncio +async def test_execute_ws_creates_root_span_with_custom_name( + mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument +): + client = AsyncBaseClientWithTelemetry( + ws_url="ws://test_url", tracer="tracker", ws_root_span_name="ws root span" + ) + + async for _ in client.execute_ws(""): + pass + + mocked_start_as_current_span.assert_any_call("ws root span", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + + +@pytest.mark.asyncio +async def test_execute_ws_creates_root_span_with_custom_context( + mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument +): + client = AsyncBaseClientWithTelemetry( + ws_url="ws://test_url", tracer="tracker", ws_root_context={"value": "test"} + ) + + async for _ in client.execute_ws(""): + pass + + mocked_start_as_current_span.assert_any_call( + "GraphQL Subscription", context={"value": "test"} + ) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("component", "GraphQL Client") + + +@pytest.mark.asyncio +async def test_execute_ws_creates_span_for_init_message( + mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument): +): + client = AsyncBaseClientWithTelemetry(ws_url="ws://test_url", tracer="tracker") + + async for _ in client.execute_ws(""): + pass + + mocked_start_as_current_span.assert_any_call("connection init", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("type", "connection_init") + + +@pytest.mark.asyncio +async def test_execute_ws_creates_span_for_subscribe_message( + mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument): +): + client = AsyncBaseClientWithTelemetry(ws_url="ws://test_url", tracer="tracker") + + async for _ in client.execute_ws( + "subscription Abc(a: String, b: InputB) { value }", + variables={"a": "AAA", "b": {"valueB": 21}}, + ): + pass + + mocked_start_as_current_span.assert_any_call("connection init", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("type", "connection_init") + + +@pytest.mark.parametrize( + "received_message", + [ + {"type": "next", "payload": {"data": "test_data"}}, + {"type": "complete"}, + {"type": "ping"}, + ], +) +@pytest.mark.asyncio +async def test_execute_ws_creates_span_for_received_message( + received_message, mocked_websocket, mocked_start_as_current_span +): + mocked_websocket.__aiter__.return_value.append(json.dumps(received_message)) + client = AsyncBaseClientWithTelemetry(ws_url="ws://test_url", tracer="tracker") + + async for _ in client.execute_ws(""): + pass + + mocked_start_as_current_span.assert_any_call("received message", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("type", received_message["type"]) + + +@pytest.mark.asyncio +async def test_execute_ws_creates_span_for_received_error_message( + mocked_websocket, mocked_start_as_current_span +): + mocked_websocket.__aiter__.return_value.append( + json.dumps({"type": "error", "payload": [{"message": "error_message"}]}) + ) + client = AsyncBaseClientWithTelemetry(ws_url="ws://test_url", tracer="tracker") + + with pytest.raises(GraphQLClientGraphQLMultiError): + async for _ in client.execute_ws(""): + pass + + mocked_start_as_current_span.assert_any_call("received message", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("type", "error") diff --git a/tests/main/clients/custom_config_file/expected_client/async_base_client.py b/tests/main/clients/custom_config_file/expected_client/async_base_client.py index 49efb1a2..3d33b247 100644 --- a/tests/main/clients/custom_config_file/expected_client/async_base_client.py +++ b/tests/main/clients/custom_config_file/expected_client/async_base_client.py @@ -1,17 +1,6 @@ import enum import json -from typing import ( - IO, - Any, - AsyncIterator, - Dict, - List, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast from uuid import uuid4 import httpx @@ -45,26 +34,6 @@ def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name raise NotImplementedError("Subscriptions require 'websockets' package.") -try: - from opentelemetry.trace import ( # type: ignore[attr-defined] - Context, - Span, - Tracer, - get_tracer, - set_span_in_context, - ) -except ImportError: - Context = Any # type: ignore - Span = Any # type: ignore - Tracer = Any # type: ignore - - def get_tracer(*args, **kwargs) -> Tracer: # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - def set_span_in_context(*args, **kwargs): # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -91,11 +60,6 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, - tracer: Optional[Union[str, Tracer]] = None, - root_context: Optional[Context] = None, - root_span_name: str = "GraphQL Operation", - ws_root_context: Optional[Context] = None, - ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,14 +72,6 @@ def __init__( self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload - self.tracer: Optional[Tracer] = ( - get_tracer(tracer) if isinstance(tracer, str) else tracer - ) - self.root_context = root_context - self.root_span_name = root_span_name - self.ws_root_context = ws_root_context - self.ws_root_span_name = ws_root_span_name - async def __aenter__(self: Self) -> Self: return self @@ -130,10 +86,17 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - if self.tracer: - return await self._execute_with_telemetry(query=query, variables=variables) + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) - return await self._execute(query=query, variables=variables) + return await self._execute_json(query=query, variables=processed_variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -162,30 +125,25 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - if self.tracer: - generator = self._execute_ws_with_telemetry( - query=query, variables=variables - ) - else: - generator = self._execute_ws(query=query, variables=variables) - - async for message in generator: - yield message - - async def _execute( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart( + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, query=query, - variables=processed_variables, - files=files, - files_map=files_map, + variables=variables, ) - return await self._execute_json(query=query, variables=processed_variables) + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data def _process_variables( self, variables: Optional[Dict[str, Any]] @@ -285,29 +243,6 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) - async def _execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data - async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -365,185 +300,3 @@ async def _handle_ws_message( ) return None - - async def _execute_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - self.root_span_name, context=self.root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart_with_telemetry( - root_span=root_span, - query=query, - variables=processed_variables, - files=files, - files_map=files_map, - ) - - return await self._execute_json_with_telemetry( - root_span=root_span, query=query, variables=processed_variables - ) - - async def _execute_multipart_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - files: Dict[str, Tuple[str, IO[bytes], str]], - files_map: Dict[str, List[str]], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "multipart request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - serialized_map = json.dumps(files_map, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - span.set_attribute("map", serialized_map) - return await self._execute_multipart( - query=query, variables=variables, files=files, files_map=files_map - ) - - async def _execute_json_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "json request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - return await self._execute_json(query=query, variables=variables) - - async def _execute_ws_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - self.ws_root_span_name, context=self.ws_root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init_with_telemetry( - root_span=root_span, - websocket=websocket, - ) - await self._send_subscribe_with_telemetry( - root_span=root_span, - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message_with_telemetry( - root_span=root_span, message=message, websocket=websocket - ) - if data: - yield data - - async def _send_connection_init_with_telemetry( - self, root_span: Span, websocket: WebSocketClientProtocol - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "connection init", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute( - "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value - ) - if self.ws_connection_init_payload: - span.set_attribute( - "payload", json.dumps(self.ws_connection_init_payload) - ) - - await self._send_connection_init(websocket=websocket) - - async def _send_subscribe_with_telemetry( - self, - root_span: Span, - websocket: WebSocketClientProtocol, - operation_id: str, - query: str, - variables: Optional[Dict[str, Any]] = None, - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "subscribe", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute("id", operation_id) - span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) - span.set_attribute("query", query) - if variables: - span.set_attribute( - "variables", - json.dumps(self._convert_dict_to_json_serializable(variables)), - ) - - await self._send_subscribe( - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async def _handle_ws_message_with_telemetry( - self, root_span: Span, message: Data, websocket: WebSocketClientProtocol - ) -> Optional[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - "received message", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - try: - message_dict = json.loads(message) - except json.JSONDecodeError as exc: - raise GraphQLClientInvalidMessageFormat(message=message) from exc - - type_ = message_dict.get("type") - payload = message_dict.get("payload", {}) - - span.set_attribute("type", type_) - - if not type_ or type_ not in { - t.value for t in GraphQLTransportWSMessageType - }: - raise GraphQLClientInvalidMessageFormat(message=message) - - if type_ == GraphQLTransportWSMessageType.NEXT: - if "data" not in payload: - raise GraphQLClientInvalidMessageFormat(message=message) - return cast(Dict[str, Any], payload["data"]) - - if type_ == GraphQLTransportWSMessageType.COMPLETE: - await websocket.close() - elif type_ == GraphQLTransportWSMessageType.PING: - await websocket.send( - json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) - ) - elif type_ == GraphQLTransportWSMessageType.ERROR: - raise GraphQLClientGraphQLMultiError.from_errors_dicts( - errors_dicts=payload, data=message_dict - ) - - return None diff --git a/tests/main/clients/custom_files_names/expected_client/async_base_client.py b/tests/main/clients/custom_files_names/expected_client/async_base_client.py index 49efb1a2..3d33b247 100644 --- a/tests/main/clients/custom_files_names/expected_client/async_base_client.py +++ b/tests/main/clients/custom_files_names/expected_client/async_base_client.py @@ -1,17 +1,6 @@ import enum import json -from typing import ( - IO, - Any, - AsyncIterator, - Dict, - List, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast from uuid import uuid4 import httpx @@ -45,26 +34,6 @@ def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name raise NotImplementedError("Subscriptions require 'websockets' package.") -try: - from opentelemetry.trace import ( # type: ignore[attr-defined] - Context, - Span, - Tracer, - get_tracer, - set_span_in_context, - ) -except ImportError: - Context = Any # type: ignore - Span = Any # type: ignore - Tracer = Any # type: ignore - - def get_tracer(*args, **kwargs) -> Tracer: # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - def set_span_in_context(*args, **kwargs): # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -91,11 +60,6 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, - tracer: Optional[Union[str, Tracer]] = None, - root_context: Optional[Context] = None, - root_span_name: str = "GraphQL Operation", - ws_root_context: Optional[Context] = None, - ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,14 +72,6 @@ def __init__( self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload - self.tracer: Optional[Tracer] = ( - get_tracer(tracer) if isinstance(tracer, str) else tracer - ) - self.root_context = root_context - self.root_span_name = root_span_name - self.ws_root_context = ws_root_context - self.ws_root_span_name = ws_root_span_name - async def __aenter__(self: Self) -> Self: return self @@ -130,10 +86,17 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - if self.tracer: - return await self._execute_with_telemetry(query=query, variables=variables) + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) - return await self._execute(query=query, variables=variables) + return await self._execute_json(query=query, variables=processed_variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -162,30 +125,25 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - if self.tracer: - generator = self._execute_ws_with_telemetry( - query=query, variables=variables - ) - else: - generator = self._execute_ws(query=query, variables=variables) - - async for message in generator: - yield message - - async def _execute( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart( + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, query=query, - variables=processed_variables, - files=files, - files_map=files_map, + variables=variables, ) - return await self._execute_json(query=query, variables=processed_variables) + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data def _process_variables( self, variables: Optional[Dict[str, Any]] @@ -285,29 +243,6 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) - async def _execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data - async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -365,185 +300,3 @@ async def _handle_ws_message( ) return None - - async def _execute_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - self.root_span_name, context=self.root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart_with_telemetry( - root_span=root_span, - query=query, - variables=processed_variables, - files=files, - files_map=files_map, - ) - - return await self._execute_json_with_telemetry( - root_span=root_span, query=query, variables=processed_variables - ) - - async def _execute_multipart_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - files: Dict[str, Tuple[str, IO[bytes], str]], - files_map: Dict[str, List[str]], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "multipart request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - serialized_map = json.dumps(files_map, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - span.set_attribute("map", serialized_map) - return await self._execute_multipart( - query=query, variables=variables, files=files, files_map=files_map - ) - - async def _execute_json_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "json request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - return await self._execute_json(query=query, variables=variables) - - async def _execute_ws_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - self.ws_root_span_name, context=self.ws_root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init_with_telemetry( - root_span=root_span, - websocket=websocket, - ) - await self._send_subscribe_with_telemetry( - root_span=root_span, - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message_with_telemetry( - root_span=root_span, message=message, websocket=websocket - ) - if data: - yield data - - async def _send_connection_init_with_telemetry( - self, root_span: Span, websocket: WebSocketClientProtocol - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "connection init", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute( - "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value - ) - if self.ws_connection_init_payload: - span.set_attribute( - "payload", json.dumps(self.ws_connection_init_payload) - ) - - await self._send_connection_init(websocket=websocket) - - async def _send_subscribe_with_telemetry( - self, - root_span: Span, - websocket: WebSocketClientProtocol, - operation_id: str, - query: str, - variables: Optional[Dict[str, Any]] = None, - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "subscribe", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute("id", operation_id) - span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) - span.set_attribute("query", query) - if variables: - span.set_attribute( - "variables", - json.dumps(self._convert_dict_to_json_serializable(variables)), - ) - - await self._send_subscribe( - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async def _handle_ws_message_with_telemetry( - self, root_span: Span, message: Data, websocket: WebSocketClientProtocol - ) -> Optional[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - "received message", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - try: - message_dict = json.loads(message) - except json.JSONDecodeError as exc: - raise GraphQLClientInvalidMessageFormat(message=message) from exc - - type_ = message_dict.get("type") - payload = message_dict.get("payload", {}) - - span.set_attribute("type", type_) - - if not type_ or type_ not in { - t.value for t in GraphQLTransportWSMessageType - }: - raise GraphQLClientInvalidMessageFormat(message=message) - - if type_ == GraphQLTransportWSMessageType.NEXT: - if "data" not in payload: - raise GraphQLClientInvalidMessageFormat(message=message) - return cast(Dict[str, Any], payload["data"]) - - if type_ == GraphQLTransportWSMessageType.COMPLETE: - await websocket.close() - elif type_ == GraphQLTransportWSMessageType.PING: - await websocket.send( - json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) - ) - elif type_ == GraphQLTransportWSMessageType.ERROR: - raise GraphQLClientGraphQLMultiError.from_errors_dicts( - errors_dicts=payload, data=message_dict - ) - - return None diff --git a/tests/main/clients/custom_scalars/expected_client/async_base_client.py b/tests/main/clients/custom_scalars/expected_client/async_base_client.py index 49efb1a2..3d33b247 100644 --- a/tests/main/clients/custom_scalars/expected_client/async_base_client.py +++ b/tests/main/clients/custom_scalars/expected_client/async_base_client.py @@ -1,17 +1,6 @@ import enum import json -from typing import ( - IO, - Any, - AsyncIterator, - Dict, - List, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast from uuid import uuid4 import httpx @@ -45,26 +34,6 @@ def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name raise NotImplementedError("Subscriptions require 'websockets' package.") -try: - from opentelemetry.trace import ( # type: ignore[attr-defined] - Context, - Span, - Tracer, - get_tracer, - set_span_in_context, - ) -except ImportError: - Context = Any # type: ignore - Span = Any # type: ignore - Tracer = Any # type: ignore - - def get_tracer(*args, **kwargs) -> Tracer: # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - def set_span_in_context(*args, **kwargs): # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -91,11 +60,6 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, - tracer: Optional[Union[str, Tracer]] = None, - root_context: Optional[Context] = None, - root_span_name: str = "GraphQL Operation", - ws_root_context: Optional[Context] = None, - ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,14 +72,6 @@ def __init__( self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload - self.tracer: Optional[Tracer] = ( - get_tracer(tracer) if isinstance(tracer, str) else tracer - ) - self.root_context = root_context - self.root_span_name = root_span_name - self.ws_root_context = ws_root_context - self.ws_root_span_name = ws_root_span_name - async def __aenter__(self: Self) -> Self: return self @@ -130,10 +86,17 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - if self.tracer: - return await self._execute_with_telemetry(query=query, variables=variables) + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) - return await self._execute(query=query, variables=variables) + return await self._execute_json(query=query, variables=processed_variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -162,30 +125,25 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - if self.tracer: - generator = self._execute_ws_with_telemetry( - query=query, variables=variables - ) - else: - generator = self._execute_ws(query=query, variables=variables) - - async for message in generator: - yield message - - async def _execute( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart( + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, query=query, - variables=processed_variables, - files=files, - files_map=files_map, + variables=variables, ) - return await self._execute_json(query=query, variables=processed_variables) + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data def _process_variables( self, variables: Optional[Dict[str, Any]] @@ -285,29 +243,6 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) - async def _execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data - async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -365,185 +300,3 @@ async def _handle_ws_message( ) return None - - async def _execute_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - self.root_span_name, context=self.root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart_with_telemetry( - root_span=root_span, - query=query, - variables=processed_variables, - files=files, - files_map=files_map, - ) - - return await self._execute_json_with_telemetry( - root_span=root_span, query=query, variables=processed_variables - ) - - async def _execute_multipart_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - files: Dict[str, Tuple[str, IO[bytes], str]], - files_map: Dict[str, List[str]], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "multipart request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - serialized_map = json.dumps(files_map, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - span.set_attribute("map", serialized_map) - return await self._execute_multipart( - query=query, variables=variables, files=files, files_map=files_map - ) - - async def _execute_json_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "json request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - return await self._execute_json(query=query, variables=variables) - - async def _execute_ws_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - self.ws_root_span_name, context=self.ws_root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init_with_telemetry( - root_span=root_span, - websocket=websocket, - ) - await self._send_subscribe_with_telemetry( - root_span=root_span, - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message_with_telemetry( - root_span=root_span, message=message, websocket=websocket - ) - if data: - yield data - - async def _send_connection_init_with_telemetry( - self, root_span: Span, websocket: WebSocketClientProtocol - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "connection init", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute( - "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value - ) - if self.ws_connection_init_payload: - span.set_attribute( - "payload", json.dumps(self.ws_connection_init_payload) - ) - - await self._send_connection_init(websocket=websocket) - - async def _send_subscribe_with_telemetry( - self, - root_span: Span, - websocket: WebSocketClientProtocol, - operation_id: str, - query: str, - variables: Optional[Dict[str, Any]] = None, - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "subscribe", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute("id", operation_id) - span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) - span.set_attribute("query", query) - if variables: - span.set_attribute( - "variables", - json.dumps(self._convert_dict_to_json_serializable(variables)), - ) - - await self._send_subscribe( - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async def _handle_ws_message_with_telemetry( - self, root_span: Span, message: Data, websocket: WebSocketClientProtocol - ) -> Optional[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - "received message", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - try: - message_dict = json.loads(message) - except json.JSONDecodeError as exc: - raise GraphQLClientInvalidMessageFormat(message=message) from exc - - type_ = message_dict.get("type") - payload = message_dict.get("payload", {}) - - span.set_attribute("type", type_) - - if not type_ or type_ not in { - t.value for t in GraphQLTransportWSMessageType - }: - raise GraphQLClientInvalidMessageFormat(message=message) - - if type_ == GraphQLTransportWSMessageType.NEXT: - if "data" not in payload: - raise GraphQLClientInvalidMessageFormat(message=message) - return cast(Dict[str, Any], payload["data"]) - - if type_ == GraphQLTransportWSMessageType.COMPLETE: - await websocket.close() - elif type_ == GraphQLTransportWSMessageType.PING: - await websocket.send( - json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) - ) - elif type_ == GraphQLTransportWSMessageType.ERROR: - raise GraphQLClientGraphQLMultiError.from_errors_dicts( - errors_dicts=payload, data=message_dict - ) - - return None diff --git a/tests/main/clients/example/expected_client/async_base_client.py b/tests/main/clients/example/expected_client/async_base_client.py index 49efb1a2..3d33b247 100644 --- a/tests/main/clients/example/expected_client/async_base_client.py +++ b/tests/main/clients/example/expected_client/async_base_client.py @@ -1,17 +1,6 @@ import enum import json -from typing import ( - IO, - Any, - AsyncIterator, - Dict, - List, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast from uuid import uuid4 import httpx @@ -45,26 +34,6 @@ def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name raise NotImplementedError("Subscriptions require 'websockets' package.") -try: - from opentelemetry.trace import ( # type: ignore[attr-defined] - Context, - Span, - Tracer, - get_tracer, - set_span_in_context, - ) -except ImportError: - Context = Any # type: ignore - Span = Any # type: ignore - Tracer = Any # type: ignore - - def get_tracer(*args, **kwargs) -> Tracer: # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - def set_span_in_context(*args, **kwargs): # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -91,11 +60,6 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, - tracer: Optional[Union[str, Tracer]] = None, - root_context: Optional[Context] = None, - root_span_name: str = "GraphQL Operation", - ws_root_context: Optional[Context] = None, - ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,14 +72,6 @@ def __init__( self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload - self.tracer: Optional[Tracer] = ( - get_tracer(tracer) if isinstance(tracer, str) else tracer - ) - self.root_context = root_context - self.root_span_name = root_span_name - self.ws_root_context = ws_root_context - self.ws_root_span_name = ws_root_span_name - async def __aenter__(self: Self) -> Self: return self @@ -130,10 +86,17 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - if self.tracer: - return await self._execute_with_telemetry(query=query, variables=variables) + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) - return await self._execute(query=query, variables=variables) + return await self._execute_json(query=query, variables=processed_variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -162,30 +125,25 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - if self.tracer: - generator = self._execute_ws_with_telemetry( - query=query, variables=variables - ) - else: - generator = self._execute_ws(query=query, variables=variables) - - async for message in generator: - yield message - - async def _execute( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart( + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, query=query, - variables=processed_variables, - files=files, - files_map=files_map, + variables=variables, ) - return await self._execute_json(query=query, variables=processed_variables) + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data def _process_variables( self, variables: Optional[Dict[str, Any]] @@ -285,29 +243,6 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) - async def _execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data - async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -365,185 +300,3 @@ async def _handle_ws_message( ) return None - - async def _execute_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - self.root_span_name, context=self.root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart_with_telemetry( - root_span=root_span, - query=query, - variables=processed_variables, - files=files, - files_map=files_map, - ) - - return await self._execute_json_with_telemetry( - root_span=root_span, query=query, variables=processed_variables - ) - - async def _execute_multipart_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - files: Dict[str, Tuple[str, IO[bytes], str]], - files_map: Dict[str, List[str]], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "multipart request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - serialized_map = json.dumps(files_map, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - span.set_attribute("map", serialized_map) - return await self._execute_multipart( - query=query, variables=variables, files=files, files_map=files_map - ) - - async def _execute_json_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "json request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - return await self._execute_json(query=query, variables=variables) - - async def _execute_ws_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - self.ws_root_span_name, context=self.ws_root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init_with_telemetry( - root_span=root_span, - websocket=websocket, - ) - await self._send_subscribe_with_telemetry( - root_span=root_span, - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message_with_telemetry( - root_span=root_span, message=message, websocket=websocket - ) - if data: - yield data - - async def _send_connection_init_with_telemetry( - self, root_span: Span, websocket: WebSocketClientProtocol - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "connection init", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute( - "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value - ) - if self.ws_connection_init_payload: - span.set_attribute( - "payload", json.dumps(self.ws_connection_init_payload) - ) - - await self._send_connection_init(websocket=websocket) - - async def _send_subscribe_with_telemetry( - self, - root_span: Span, - websocket: WebSocketClientProtocol, - operation_id: str, - query: str, - variables: Optional[Dict[str, Any]] = None, - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "subscribe", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute("id", operation_id) - span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) - span.set_attribute("query", query) - if variables: - span.set_attribute( - "variables", - json.dumps(self._convert_dict_to_json_serializable(variables)), - ) - - await self._send_subscribe( - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async def _handle_ws_message_with_telemetry( - self, root_span: Span, message: Data, websocket: WebSocketClientProtocol - ) -> Optional[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - "received message", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - try: - message_dict = json.loads(message) - except json.JSONDecodeError as exc: - raise GraphQLClientInvalidMessageFormat(message=message) from exc - - type_ = message_dict.get("type") - payload = message_dict.get("payload", {}) - - span.set_attribute("type", type_) - - if not type_ or type_ not in { - t.value for t in GraphQLTransportWSMessageType - }: - raise GraphQLClientInvalidMessageFormat(message=message) - - if type_ == GraphQLTransportWSMessageType.NEXT: - if "data" not in payload: - raise GraphQLClientInvalidMessageFormat(message=message) - return cast(Dict[str, Any], payload["data"]) - - if type_ == GraphQLTransportWSMessageType.COMPLETE: - await websocket.close() - elif type_ == GraphQLTransportWSMessageType.PING: - await websocket.send( - json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) - ) - elif type_ == GraphQLTransportWSMessageType.ERROR: - raise GraphQLClientGraphQLMultiError.from_errors_dicts( - errors_dicts=payload, data=message_dict - ) - - return None diff --git a/tests/main/clients/extended_models/expected_client/async_base_client.py b/tests/main/clients/extended_models/expected_client/async_base_client.py index 49efb1a2..3d33b247 100644 --- a/tests/main/clients/extended_models/expected_client/async_base_client.py +++ b/tests/main/clients/extended_models/expected_client/async_base_client.py @@ -1,17 +1,6 @@ import enum import json -from typing import ( - IO, - Any, - AsyncIterator, - Dict, - List, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast from uuid import uuid4 import httpx @@ -45,26 +34,6 @@ def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name raise NotImplementedError("Subscriptions require 'websockets' package.") -try: - from opentelemetry.trace import ( # type: ignore[attr-defined] - Context, - Span, - Tracer, - get_tracer, - set_span_in_context, - ) -except ImportError: - Context = Any # type: ignore - Span = Any # type: ignore - Tracer = Any # type: ignore - - def get_tracer(*args, **kwargs) -> Tracer: # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - def set_span_in_context(*args, **kwargs): # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -91,11 +60,6 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, - tracer: Optional[Union[str, Tracer]] = None, - root_context: Optional[Context] = None, - root_span_name: str = "GraphQL Operation", - ws_root_context: Optional[Context] = None, - ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,14 +72,6 @@ def __init__( self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload - self.tracer: Optional[Tracer] = ( - get_tracer(tracer) if isinstance(tracer, str) else tracer - ) - self.root_context = root_context - self.root_span_name = root_span_name - self.ws_root_context = ws_root_context - self.ws_root_span_name = ws_root_span_name - async def __aenter__(self: Self) -> Self: return self @@ -130,10 +86,17 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - if self.tracer: - return await self._execute_with_telemetry(query=query, variables=variables) + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) - return await self._execute(query=query, variables=variables) + return await self._execute_json(query=query, variables=processed_variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -162,30 +125,25 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - if self.tracer: - generator = self._execute_ws_with_telemetry( - query=query, variables=variables - ) - else: - generator = self._execute_ws(query=query, variables=variables) - - async for message in generator: - yield message - - async def _execute( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart( + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, query=query, - variables=processed_variables, - files=files, - files_map=files_map, + variables=variables, ) - return await self._execute_json(query=query, variables=processed_variables) + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data def _process_variables( self, variables: Optional[Dict[str, Any]] @@ -285,29 +243,6 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) - async def _execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data - async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -365,185 +300,3 @@ async def _handle_ws_message( ) return None - - async def _execute_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - self.root_span_name, context=self.root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart_with_telemetry( - root_span=root_span, - query=query, - variables=processed_variables, - files=files, - files_map=files_map, - ) - - return await self._execute_json_with_telemetry( - root_span=root_span, query=query, variables=processed_variables - ) - - async def _execute_multipart_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - files: Dict[str, Tuple[str, IO[bytes], str]], - files_map: Dict[str, List[str]], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "multipart request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - serialized_map = json.dumps(files_map, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - span.set_attribute("map", serialized_map) - return await self._execute_multipart( - query=query, variables=variables, files=files, files_map=files_map - ) - - async def _execute_json_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "json request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - return await self._execute_json(query=query, variables=variables) - - async def _execute_ws_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - self.ws_root_span_name, context=self.ws_root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init_with_telemetry( - root_span=root_span, - websocket=websocket, - ) - await self._send_subscribe_with_telemetry( - root_span=root_span, - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message_with_telemetry( - root_span=root_span, message=message, websocket=websocket - ) - if data: - yield data - - async def _send_connection_init_with_telemetry( - self, root_span: Span, websocket: WebSocketClientProtocol - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "connection init", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute( - "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value - ) - if self.ws_connection_init_payload: - span.set_attribute( - "payload", json.dumps(self.ws_connection_init_payload) - ) - - await self._send_connection_init(websocket=websocket) - - async def _send_subscribe_with_telemetry( - self, - root_span: Span, - websocket: WebSocketClientProtocol, - operation_id: str, - query: str, - variables: Optional[Dict[str, Any]] = None, - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "subscribe", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute("id", operation_id) - span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) - span.set_attribute("query", query) - if variables: - span.set_attribute( - "variables", - json.dumps(self._convert_dict_to_json_serializable(variables)), - ) - - await self._send_subscribe( - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async def _handle_ws_message_with_telemetry( - self, root_span: Span, message: Data, websocket: WebSocketClientProtocol - ) -> Optional[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - "received message", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - try: - message_dict = json.loads(message) - except json.JSONDecodeError as exc: - raise GraphQLClientInvalidMessageFormat(message=message) from exc - - type_ = message_dict.get("type") - payload = message_dict.get("payload", {}) - - span.set_attribute("type", type_) - - if not type_ or type_ not in { - t.value for t in GraphQLTransportWSMessageType - }: - raise GraphQLClientInvalidMessageFormat(message=message) - - if type_ == GraphQLTransportWSMessageType.NEXT: - if "data" not in payload: - raise GraphQLClientInvalidMessageFormat(message=message) - return cast(Dict[str, Any], payload["data"]) - - if type_ == GraphQLTransportWSMessageType.COMPLETE: - await websocket.close() - elif type_ == GraphQLTransportWSMessageType.PING: - await websocket.send( - json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) - ) - elif type_ == GraphQLTransportWSMessageType.ERROR: - raise GraphQLClientGraphQLMultiError.from_errors_dicts( - errors_dicts=payload, data=message_dict - ) - - return None diff --git a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py index 49efb1a2..3d33b247 100644 --- a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py +++ b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py @@ -1,17 +1,6 @@ import enum import json -from typing import ( - IO, - Any, - AsyncIterator, - Dict, - List, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast from uuid import uuid4 import httpx @@ -45,26 +34,6 @@ def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name raise NotImplementedError("Subscriptions require 'websockets' package.") -try: - from opentelemetry.trace import ( # type: ignore[attr-defined] - Context, - Span, - Tracer, - get_tracer, - set_span_in_context, - ) -except ImportError: - Context = Any # type: ignore - Span = Any # type: ignore - Tracer = Any # type: ignore - - def get_tracer(*args, **kwargs) -> Tracer: # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - def set_span_in_context(*args, **kwargs): # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -91,11 +60,6 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, - tracer: Optional[Union[str, Tracer]] = None, - root_context: Optional[Context] = None, - root_span_name: str = "GraphQL Operation", - ws_root_context: Optional[Context] = None, - ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,14 +72,6 @@ def __init__( self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload - self.tracer: Optional[Tracer] = ( - get_tracer(tracer) if isinstance(tracer, str) else tracer - ) - self.root_context = root_context - self.root_span_name = root_span_name - self.ws_root_context = ws_root_context - self.ws_root_span_name = ws_root_span_name - async def __aenter__(self: Self) -> Self: return self @@ -130,10 +86,17 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - if self.tracer: - return await self._execute_with_telemetry(query=query, variables=variables) + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) - return await self._execute(query=query, variables=variables) + return await self._execute_json(query=query, variables=processed_variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -162,30 +125,25 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - if self.tracer: - generator = self._execute_ws_with_telemetry( - query=query, variables=variables - ) - else: - generator = self._execute_ws(query=query, variables=variables) - - async for message in generator: - yield message - - async def _execute( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart( + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, query=query, - variables=processed_variables, - files=files, - files_map=files_map, + variables=variables, ) - return await self._execute_json(query=query, variables=processed_variables) + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data def _process_variables( self, variables: Optional[Dict[str, Any]] @@ -285,29 +243,6 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) - async def _execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data - async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -365,185 +300,3 @@ async def _handle_ws_message( ) return None - - async def _execute_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - self.root_span_name, context=self.root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart_with_telemetry( - root_span=root_span, - query=query, - variables=processed_variables, - files=files, - files_map=files_map, - ) - - return await self._execute_json_with_telemetry( - root_span=root_span, query=query, variables=processed_variables - ) - - async def _execute_multipart_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - files: Dict[str, Tuple[str, IO[bytes], str]], - files_map: Dict[str, List[str]], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "multipart request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - serialized_map = json.dumps(files_map, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - span.set_attribute("map", serialized_map) - return await self._execute_multipart( - query=query, variables=variables, files=files, files_map=files_map - ) - - async def _execute_json_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "json request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - return await self._execute_json(query=query, variables=variables) - - async def _execute_ws_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - self.ws_root_span_name, context=self.ws_root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init_with_telemetry( - root_span=root_span, - websocket=websocket, - ) - await self._send_subscribe_with_telemetry( - root_span=root_span, - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message_with_telemetry( - root_span=root_span, message=message, websocket=websocket - ) - if data: - yield data - - async def _send_connection_init_with_telemetry( - self, root_span: Span, websocket: WebSocketClientProtocol - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "connection init", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute( - "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value - ) - if self.ws_connection_init_payload: - span.set_attribute( - "payload", json.dumps(self.ws_connection_init_payload) - ) - - await self._send_connection_init(websocket=websocket) - - async def _send_subscribe_with_telemetry( - self, - root_span: Span, - websocket: WebSocketClientProtocol, - operation_id: str, - query: str, - variables: Optional[Dict[str, Any]] = None, - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "subscribe", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute("id", operation_id) - span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) - span.set_attribute("query", query) - if variables: - span.set_attribute( - "variables", - json.dumps(self._convert_dict_to_json_serializable(variables)), - ) - - await self._send_subscribe( - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async def _handle_ws_message_with_telemetry( - self, root_span: Span, message: Data, websocket: WebSocketClientProtocol - ) -> Optional[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - "received message", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - try: - message_dict = json.loads(message) - except json.JSONDecodeError as exc: - raise GraphQLClientInvalidMessageFormat(message=message) from exc - - type_ = message_dict.get("type") - payload = message_dict.get("payload", {}) - - span.set_attribute("type", type_) - - if not type_ or type_ not in { - t.value for t in GraphQLTransportWSMessageType - }: - raise GraphQLClientInvalidMessageFormat(message=message) - - if type_ == GraphQLTransportWSMessageType.NEXT: - if "data" not in payload: - raise GraphQLClientInvalidMessageFormat(message=message) - return cast(Dict[str, Any], payload["data"]) - - if type_ == GraphQLTransportWSMessageType.COMPLETE: - await websocket.close() - elif type_ == GraphQLTransportWSMessageType.PING: - await websocket.send( - json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) - ) - elif type_ == GraphQLTransportWSMessageType.ERROR: - raise GraphQLClientGraphQLMultiError.from_errors_dicts( - errors_dicts=payload, data=message_dict - ) - - return None diff --git a/tests/main/clients/inline_fragments/expected_client/async_base_client.py b/tests/main/clients/inline_fragments/expected_client/async_base_client.py index 49efb1a2..3d33b247 100644 --- a/tests/main/clients/inline_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/inline_fragments/expected_client/async_base_client.py @@ -1,17 +1,6 @@ import enum import json -from typing import ( - IO, - Any, - AsyncIterator, - Dict, - List, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast from uuid import uuid4 import httpx @@ -45,26 +34,6 @@ def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name raise NotImplementedError("Subscriptions require 'websockets' package.") -try: - from opentelemetry.trace import ( # type: ignore[attr-defined] - Context, - Span, - Tracer, - get_tracer, - set_span_in_context, - ) -except ImportError: - Context = Any # type: ignore - Span = Any # type: ignore - Tracer = Any # type: ignore - - def get_tracer(*args, **kwargs) -> Tracer: # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - def set_span_in_context(*args, **kwargs): # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -91,11 +60,6 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, - tracer: Optional[Union[str, Tracer]] = None, - root_context: Optional[Context] = None, - root_span_name: str = "GraphQL Operation", - ws_root_context: Optional[Context] = None, - ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,14 +72,6 @@ def __init__( self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload - self.tracer: Optional[Tracer] = ( - get_tracer(tracer) if isinstance(tracer, str) else tracer - ) - self.root_context = root_context - self.root_span_name = root_span_name - self.ws_root_context = ws_root_context - self.ws_root_span_name = ws_root_span_name - async def __aenter__(self: Self) -> Self: return self @@ -130,10 +86,17 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - if self.tracer: - return await self._execute_with_telemetry(query=query, variables=variables) + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) - return await self._execute(query=query, variables=variables) + return await self._execute_json(query=query, variables=processed_variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -162,30 +125,25 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - if self.tracer: - generator = self._execute_ws_with_telemetry( - query=query, variables=variables - ) - else: - generator = self._execute_ws(query=query, variables=variables) - - async for message in generator: - yield message - - async def _execute( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart( + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, query=query, - variables=processed_variables, - files=files, - files_map=files_map, + variables=variables, ) - return await self._execute_json(query=query, variables=processed_variables) + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data def _process_variables( self, variables: Optional[Dict[str, Any]] @@ -285,29 +243,6 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) - async def _execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data - async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -365,185 +300,3 @@ async def _handle_ws_message( ) return None - - async def _execute_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - self.root_span_name, context=self.root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart_with_telemetry( - root_span=root_span, - query=query, - variables=processed_variables, - files=files, - files_map=files_map, - ) - - return await self._execute_json_with_telemetry( - root_span=root_span, query=query, variables=processed_variables - ) - - async def _execute_multipart_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - files: Dict[str, Tuple[str, IO[bytes], str]], - files_map: Dict[str, List[str]], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "multipart request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - serialized_map = json.dumps(files_map, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - span.set_attribute("map", serialized_map) - return await self._execute_multipart( - query=query, variables=variables, files=files, files_map=files_map - ) - - async def _execute_json_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "json request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - return await self._execute_json(query=query, variables=variables) - - async def _execute_ws_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - self.ws_root_span_name, context=self.ws_root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init_with_telemetry( - root_span=root_span, - websocket=websocket, - ) - await self._send_subscribe_with_telemetry( - root_span=root_span, - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message_with_telemetry( - root_span=root_span, message=message, websocket=websocket - ) - if data: - yield data - - async def _send_connection_init_with_telemetry( - self, root_span: Span, websocket: WebSocketClientProtocol - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "connection init", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute( - "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value - ) - if self.ws_connection_init_payload: - span.set_attribute( - "payload", json.dumps(self.ws_connection_init_payload) - ) - - await self._send_connection_init(websocket=websocket) - - async def _send_subscribe_with_telemetry( - self, - root_span: Span, - websocket: WebSocketClientProtocol, - operation_id: str, - query: str, - variables: Optional[Dict[str, Any]] = None, - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "subscribe", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute("id", operation_id) - span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) - span.set_attribute("query", query) - if variables: - span.set_attribute( - "variables", - json.dumps(self._convert_dict_to_json_serializable(variables)), - ) - - await self._send_subscribe( - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async def _handle_ws_message_with_telemetry( - self, root_span: Span, message: Data, websocket: WebSocketClientProtocol - ) -> Optional[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - "received message", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - try: - message_dict = json.loads(message) - except json.JSONDecodeError as exc: - raise GraphQLClientInvalidMessageFormat(message=message) from exc - - type_ = message_dict.get("type") - payload = message_dict.get("payload", {}) - - span.set_attribute("type", type_) - - if not type_ or type_ not in { - t.value for t in GraphQLTransportWSMessageType - }: - raise GraphQLClientInvalidMessageFormat(message=message) - - if type_ == GraphQLTransportWSMessageType.NEXT: - if "data" not in payload: - raise GraphQLClientInvalidMessageFormat(message=message) - return cast(Dict[str, Any], payload["data"]) - - if type_ == GraphQLTransportWSMessageType.COMPLETE: - await websocket.close() - elif type_ == GraphQLTransportWSMessageType.PING: - await websocket.send( - json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) - ) - elif type_ == GraphQLTransportWSMessageType.ERROR: - raise GraphQLClientGraphQLMultiError.from_errors_dicts( - errors_dicts=payload, data=message_dict - ) - - return None diff --git a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py index 49efb1a2..3d33b247 100644 --- a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py @@ -1,17 +1,6 @@ import enum import json -from typing import ( - IO, - Any, - AsyncIterator, - Dict, - List, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast from uuid import uuid4 import httpx @@ -45,26 +34,6 @@ def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name raise NotImplementedError("Subscriptions require 'websockets' package.") -try: - from opentelemetry.trace import ( # type: ignore[attr-defined] - Context, - Span, - Tracer, - get_tracer, - set_span_in_context, - ) -except ImportError: - Context = Any # type: ignore - Span = Any # type: ignore - Tracer = Any # type: ignore - - def get_tracer(*args, **kwargs) -> Tracer: # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - def set_span_in_context(*args, **kwargs): # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -91,11 +60,6 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, - tracer: Optional[Union[str, Tracer]] = None, - root_context: Optional[Context] = None, - root_span_name: str = "GraphQL Operation", - ws_root_context: Optional[Context] = None, - ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,14 +72,6 @@ def __init__( self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload - self.tracer: Optional[Tracer] = ( - get_tracer(tracer) if isinstance(tracer, str) else tracer - ) - self.root_context = root_context - self.root_span_name = root_span_name - self.ws_root_context = ws_root_context - self.ws_root_span_name = ws_root_span_name - async def __aenter__(self: Self) -> Self: return self @@ -130,10 +86,17 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - if self.tracer: - return await self._execute_with_telemetry(query=query, variables=variables) + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) - return await self._execute(query=query, variables=variables) + return await self._execute_json(query=query, variables=processed_variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -162,30 +125,25 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - if self.tracer: - generator = self._execute_ws_with_telemetry( - query=query, variables=variables - ) - else: - generator = self._execute_ws(query=query, variables=variables) - - async for message in generator: - yield message - - async def _execute( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart( + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, query=query, - variables=processed_variables, - files=files, - files_map=files_map, + variables=variables, ) - return await self._execute_json(query=query, variables=processed_variables) + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data def _process_variables( self, variables: Optional[Dict[str, Any]] @@ -285,29 +243,6 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) - async def _execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data - async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -365,185 +300,3 @@ async def _handle_ws_message( ) return None - - async def _execute_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - self.root_span_name, context=self.root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart_with_telemetry( - root_span=root_span, - query=query, - variables=processed_variables, - files=files, - files_map=files_map, - ) - - return await self._execute_json_with_telemetry( - root_span=root_span, query=query, variables=processed_variables - ) - - async def _execute_multipart_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - files: Dict[str, Tuple[str, IO[bytes], str]], - files_map: Dict[str, List[str]], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "multipart request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - serialized_map = json.dumps(files_map, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - span.set_attribute("map", serialized_map) - return await self._execute_multipart( - query=query, variables=variables, files=files, files_map=files_map - ) - - async def _execute_json_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "json request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - return await self._execute_json(query=query, variables=variables) - - async def _execute_ws_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - self.ws_root_span_name, context=self.ws_root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init_with_telemetry( - root_span=root_span, - websocket=websocket, - ) - await self._send_subscribe_with_telemetry( - root_span=root_span, - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message_with_telemetry( - root_span=root_span, message=message, websocket=websocket - ) - if data: - yield data - - async def _send_connection_init_with_telemetry( - self, root_span: Span, websocket: WebSocketClientProtocol - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "connection init", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute( - "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value - ) - if self.ws_connection_init_payload: - span.set_attribute( - "payload", json.dumps(self.ws_connection_init_payload) - ) - - await self._send_connection_init(websocket=websocket) - - async def _send_subscribe_with_telemetry( - self, - root_span: Span, - websocket: WebSocketClientProtocol, - operation_id: str, - query: str, - variables: Optional[Dict[str, Any]] = None, - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "subscribe", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute("id", operation_id) - span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) - span.set_attribute("query", query) - if variables: - span.set_attribute( - "variables", - json.dumps(self._convert_dict_to_json_serializable(variables)), - ) - - await self._send_subscribe( - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async def _handle_ws_message_with_telemetry( - self, root_span: Span, message: Data, websocket: WebSocketClientProtocol - ) -> Optional[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - "received message", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - try: - message_dict = json.loads(message) - except json.JSONDecodeError as exc: - raise GraphQLClientInvalidMessageFormat(message=message) from exc - - type_ = message_dict.get("type") - payload = message_dict.get("payload", {}) - - span.set_attribute("type", type_) - - if not type_ or type_ not in { - t.value for t in GraphQLTransportWSMessageType - }: - raise GraphQLClientInvalidMessageFormat(message=message) - - if type_ == GraphQLTransportWSMessageType.NEXT: - if "data" not in payload: - raise GraphQLClientInvalidMessageFormat(message=message) - return cast(Dict[str, Any], payload["data"]) - - if type_ == GraphQLTransportWSMessageType.COMPLETE: - await websocket.close() - elif type_ == GraphQLTransportWSMessageType.PING: - await websocket.send( - json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) - ) - elif type_ == GraphQLTransportWSMessageType.ERROR: - raise GraphQLClientGraphQLMultiError.from_errors_dicts( - errors_dicts=payload, data=message_dict - ) - - return None diff --git a/tests/main/clients/remote_schema/expected_client/async_base_client.py b/tests/main/clients/remote_schema/expected_client/async_base_client.py index 49efb1a2..3d33b247 100644 --- a/tests/main/clients/remote_schema/expected_client/async_base_client.py +++ b/tests/main/clients/remote_schema/expected_client/async_base_client.py @@ -1,17 +1,6 @@ import enum import json -from typing import ( - IO, - Any, - AsyncIterator, - Dict, - List, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast from uuid import uuid4 import httpx @@ -45,26 +34,6 @@ def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name raise NotImplementedError("Subscriptions require 'websockets' package.") -try: - from opentelemetry.trace import ( # type: ignore[attr-defined] - Context, - Span, - Tracer, - get_tracer, - set_span_in_context, - ) -except ImportError: - Context = Any # type: ignore - Span = Any # type: ignore - Tracer = Any # type: ignore - - def get_tracer(*args, **kwargs) -> Tracer: # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - def set_span_in_context(*args, **kwargs): # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -91,11 +60,6 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, - tracer: Optional[Union[str, Tracer]] = None, - root_context: Optional[Context] = None, - root_span_name: str = "GraphQL Operation", - ws_root_context: Optional[Context] = None, - ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,14 +72,6 @@ def __init__( self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload - self.tracer: Optional[Tracer] = ( - get_tracer(tracer) if isinstance(tracer, str) else tracer - ) - self.root_context = root_context - self.root_span_name = root_span_name - self.ws_root_context = ws_root_context - self.ws_root_span_name = ws_root_span_name - async def __aenter__(self: Self) -> Self: return self @@ -130,10 +86,17 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - if self.tracer: - return await self._execute_with_telemetry(query=query, variables=variables) + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) - return await self._execute(query=query, variables=variables) + return await self._execute_json(query=query, variables=processed_variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -162,30 +125,25 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - if self.tracer: - generator = self._execute_ws_with_telemetry( - query=query, variables=variables - ) - else: - generator = self._execute_ws(query=query, variables=variables) - - async for message in generator: - yield message - - async def _execute( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart( + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, query=query, - variables=processed_variables, - files=files, - files_map=files_map, + variables=variables, ) - return await self._execute_json(query=query, variables=processed_variables) + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data def _process_variables( self, variables: Optional[Dict[str, Any]] @@ -285,29 +243,6 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) - async def _execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data - async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -365,185 +300,3 @@ async def _handle_ws_message( ) return None - - async def _execute_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - self.root_span_name, context=self.root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart_with_telemetry( - root_span=root_span, - query=query, - variables=processed_variables, - files=files, - files_map=files_map, - ) - - return await self._execute_json_with_telemetry( - root_span=root_span, query=query, variables=processed_variables - ) - - async def _execute_multipart_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - files: Dict[str, Tuple[str, IO[bytes], str]], - files_map: Dict[str, List[str]], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "multipart request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - serialized_map = json.dumps(files_map, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - span.set_attribute("map", serialized_map) - return await self._execute_multipart( - query=query, variables=variables, files=files, files_map=files_map - ) - - async def _execute_json_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "json request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - return await self._execute_json(query=query, variables=variables) - - async def _execute_ws_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - self.ws_root_span_name, context=self.ws_root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init_with_telemetry( - root_span=root_span, - websocket=websocket, - ) - await self._send_subscribe_with_telemetry( - root_span=root_span, - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message_with_telemetry( - root_span=root_span, message=message, websocket=websocket - ) - if data: - yield data - - async def _send_connection_init_with_telemetry( - self, root_span: Span, websocket: WebSocketClientProtocol - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "connection init", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute( - "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value - ) - if self.ws_connection_init_payload: - span.set_attribute( - "payload", json.dumps(self.ws_connection_init_payload) - ) - - await self._send_connection_init(websocket=websocket) - - async def _send_subscribe_with_telemetry( - self, - root_span: Span, - websocket: WebSocketClientProtocol, - operation_id: str, - query: str, - variables: Optional[Dict[str, Any]] = None, - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "subscribe", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute("id", operation_id) - span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) - span.set_attribute("query", query) - if variables: - span.set_attribute( - "variables", - json.dumps(self._convert_dict_to_json_serializable(variables)), - ) - - await self._send_subscribe( - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async def _handle_ws_message_with_telemetry( - self, root_span: Span, message: Data, websocket: WebSocketClientProtocol - ) -> Optional[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - "received message", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - try: - message_dict = json.loads(message) - except json.JSONDecodeError as exc: - raise GraphQLClientInvalidMessageFormat(message=message) from exc - - type_ = message_dict.get("type") - payload = message_dict.get("payload", {}) - - span.set_attribute("type", type_) - - if not type_ or type_ not in { - t.value for t in GraphQLTransportWSMessageType - }: - raise GraphQLClientInvalidMessageFormat(message=message) - - if type_ == GraphQLTransportWSMessageType.NEXT: - if "data" not in payload: - raise GraphQLClientInvalidMessageFormat(message=message) - return cast(Dict[str, Any], payload["data"]) - - if type_ == GraphQLTransportWSMessageType.COMPLETE: - await websocket.close() - elif type_ == GraphQLTransportWSMessageType.PING: - await websocket.send( - json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) - ) - elif type_ == GraphQLTransportWSMessageType.ERROR: - raise GraphQLClientGraphQLMultiError.from_errors_dicts( - errors_dicts=payload, data=message_dict - ) - - return None diff --git a/tests/main/clients/shorter_results/expected_client/async_base_client.py b/tests/main/clients/shorter_results/expected_client/async_base_client.py index 49efb1a2..3d33b247 100644 --- a/tests/main/clients/shorter_results/expected_client/async_base_client.py +++ b/tests/main/clients/shorter_results/expected_client/async_base_client.py @@ -1,17 +1,6 @@ import enum import json -from typing import ( - IO, - Any, - AsyncIterator, - Dict, - List, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast from uuid import uuid4 import httpx @@ -45,26 +34,6 @@ def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name raise NotImplementedError("Subscriptions require 'websockets' package.") -try: - from opentelemetry.trace import ( # type: ignore[attr-defined] - Context, - Span, - Tracer, - get_tracer, - set_span_in_context, - ) -except ImportError: - Context = Any # type: ignore - Span = Any # type: ignore - Tracer = Any # type: ignore - - def get_tracer(*args, **kwargs) -> Tracer: # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - def set_span_in_context(*args, **kwargs): # type: ignore - raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") - - Self = TypeVar("Self", bound="AsyncBaseClient") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -91,11 +60,6 @@ def __init__( ws_headers: Optional[Dict[str, Any]] = None, ws_origin: Optional[str] = None, ws_connection_init_payload: Optional[Dict[str, Any]] = None, - tracer: Optional[Union[str, Tracer]] = None, - root_context: Optional[Context] = None, - root_span_name: str = "GraphQL Operation", - ws_root_context: Optional[Context] = None, - ws_root_span_name: str = "GraphQL Subscription", ) -> None: self.url = url self.headers = headers @@ -108,14 +72,6 @@ def __init__( self.ws_origin = Origin(ws_origin) if ws_origin else None self.ws_connection_init_payload = ws_connection_init_payload - self.tracer: Optional[Tracer] = ( - get_tracer(tracer) if isinstance(tracer, str) else tracer - ) - self.root_context = root_context - self.root_span_name = root_span_name - self.ws_root_context = ws_root_context - self.ws_root_span_name = ws_root_span_name - async def __aenter__(self: Self) -> Self: return self @@ -130,10 +86,17 @@ async def __aexit__( async def execute( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> httpx.Response: - if self.tracer: - return await self._execute_with_telemetry(query=query, variables=variables) + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + ) - return await self._execute(query=query, variables=variables) + return await self._execute_json(query=query, variables=processed_variables) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -162,30 +125,25 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: async def execute_ws( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> AsyncIterator[Dict[str, Any]]: - if self.tracer: - generator = self._execute_ws_with_telemetry( - query=query, variables=variables - ) - else: - generator = self._execute_ws(query=query, variables=variables) - - async for message in generator: - yield message - - async def _execute( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart( + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + origin=self.ws_origin, + extra_headers=self.ws_headers, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, query=query, - variables=processed_variables, - files=files, - files_map=files_map, + variables=variables, ) - return await self._execute_json(query=query, variables=processed_variables) + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data def _process_variables( self, variables: Optional[Dict[str, Any]] @@ -285,29 +243,6 @@ async def _execute_json( headers={"Content-Type": "application/json"}, ) - async def _execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init(websocket) - await self._send_subscribe( - websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message(message, websocket) - if data: - yield data - async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: payload: Dict[str, Any] = { "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value @@ -365,185 +300,3 @@ async def _handle_ws_message( ) return None - - async def _execute_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - self.root_span_name, context=self.root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - - processed_variables, files, files_map = self._process_variables(variables) - - if files and files_map: - return await self._execute_multipart_with_telemetry( - root_span=root_span, - query=query, - variables=processed_variables, - files=files, - files_map=files_map, - ) - - return await self._execute_json_with_telemetry( - root_span=root_span, query=query, variables=processed_variables - ) - - async def _execute_multipart_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - files: Dict[str, Tuple[str, IO[bytes], str]], - files_map: Dict[str, List[str]], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "multipart request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - serialized_map = json.dumps(files_map, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - span.set_attribute("map", serialized_map) - return await self._execute_multipart( - query=query, variables=variables, files=files, files_map=files_map - ) - - async def _execute_json_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], - ) -> httpx.Response: - with self.tracer.start_as_current_span( # type: ignore - "json request", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - serialized_variables = json.dumps(variables, default=to_jsonable_python) - - span.set_attribute("query", query) - span.set_attribute("variables", serialized_variables) - return await self._execute_json(query=query, variables=variables) - - async def _execute_ws_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> AsyncIterator[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - self.ws_root_span_name, context=self.ws_root_context - ) as root_span: - root_span.set_attribute("component", "GraphQL Client") - operation_id = str(uuid4()) - async with ws_connect( - self.ws_url, - subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, - ) as websocket: - await self._send_connection_init_with_telemetry( - root_span=root_span, - websocket=websocket, - ) - await self._send_subscribe_with_telemetry( - root_span=root_span, - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async for message in websocket: - data = await self._handle_ws_message_with_telemetry( - root_span=root_span, message=message, websocket=websocket - ) - if data: - yield data - - async def _send_connection_init_with_telemetry( - self, root_span: Span, websocket: WebSocketClientProtocol - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "connection init", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute( - "type", GraphQLTransportWSMessageType.CONNECTION_INIT.value - ) - if self.ws_connection_init_payload: - span.set_attribute( - "payload", json.dumps(self.ws_connection_init_payload) - ) - - await self._send_connection_init(websocket=websocket) - - async def _send_subscribe_with_telemetry( - self, - root_span: Span, - websocket: WebSocketClientProtocol, - operation_id: str, - query: str, - variables: Optional[Dict[str, Any]] = None, - ) -> None: - with self.tracer.start_as_current_span( # type: ignore - "subscribe", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - span.set_attribute("id", operation_id) - span.set_attribute("type", GraphQLTransportWSMessageType.SUBSCRIBE.value) - span.set_attribute("query", query) - if variables: - span.set_attribute( - "variables", - json.dumps(self._convert_dict_to_json_serializable(variables)), - ) - - await self._send_subscribe( - websocket=websocket, - operation_id=operation_id, - query=query, - variables=variables, - ) - - async def _handle_ws_message_with_telemetry( - self, root_span: Span, message: Data, websocket: WebSocketClientProtocol - ) -> Optional[Dict[str, Any]]: - with self.tracer.start_as_current_span( # type: ignore - "received message", context=set_span_in_context(root_span) - ) as span: - span.set_attribute("component", "GraphQL Client") - - try: - message_dict = json.loads(message) - except json.JSONDecodeError as exc: - raise GraphQLClientInvalidMessageFormat(message=message) from exc - - type_ = message_dict.get("type") - payload = message_dict.get("payload", {}) - - span.set_attribute("type", type_) - - if not type_ or type_ not in { - t.value for t in GraphQLTransportWSMessageType - }: - raise GraphQLClientInvalidMessageFormat(message=message) - - if type_ == GraphQLTransportWSMessageType.NEXT: - if "data" not in payload: - raise GraphQLClientInvalidMessageFormat(message=message) - return cast(Dict[str, Any], payload["data"]) - - if type_ == GraphQLTransportWSMessageType.COMPLETE: - await websocket.close() - elif type_ == GraphQLTransportWSMessageType.PING: - await websocket.send( - json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) - ) - elif type_ == GraphQLTransportWSMessageType.ERROR: - raise GraphQLClientGraphQLMultiError.from_errors_dicts( - errors_dicts=payload, data=message_dict - ) - - return None diff --git a/tests/test_settings.py b/tests/test_settings.py index 6b024109..d9b8a0cb 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -4,8 +4,12 @@ import pytest -import ariadne_codegen.client_generators.dependencies.async_base_client -import ariadne_codegen.client_generators.dependencies.base_client +from ariadne_codegen.client_generators.dependencies import ( + async_base_client, + async_base_client_with_telemetry, + base_client, + base_client_with_telemetry, +) from ariadne_codegen.config import ClientSettings, GraphQLSchemaSettings from ariadne_codegen.exceptions import InvalidConfiguration @@ -76,22 +80,21 @@ class BaseClient: @pytest.mark.parametrize( - "async_client,expected_name,file_path", + "async_client, telemetry_client, expected_name, expected_path", [ ( True, - "AsyncBaseClient", - ariadne_codegen.client_generators.dependencies.async_base_client.__file__, - ), - ( - False, - "BaseClient", - ariadne_codegen.client_generators.dependencies.base_client.__file__, + True, + "AsyncBaseClientWithTelemetry", + async_base_client_with_telemetry.__file__, ), + (True, False, "AsyncBaseClient", async_base_client.__file__), + (False, True, "BaseClientWithTelemetry", base_client_with_telemetry.__file__), + (False, False, "BaseClient", base_client.__file__), ], ) def test_client_settings_sets_correct_default_values_for_base_client_name_and_path( - tmp_path, async_client, expected_name, file_path + tmp_path, async_client, telemetry_client, expected_name, expected_path ): schema_path = tmp_path / "schema.graphql" schema_path.touch() @@ -99,11 +102,14 @@ def test_client_settings_sets_correct_default_values_for_base_client_name_and_pa queries_path.touch() settings = ClientSettings( - schema_path=schema_path, queries_path=queries_path, async_client=async_client + schema_path=schema_path, + queries_path=queries_path, + async_client=async_client, + telemetry_client=telemetry_client, ) assert settings.base_client_name == expected_name - assert settings.base_client_file_path == Path(file_path).as_posix() + assert settings.base_client_file_path == Path(expected_path).as_posix() def test_client_settings_without_schema_path_with_remote_schema_url_is_valid(tmp_path): From 760b76769bdcdf56811e358bd68659ce6ff06380 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 11 Oct 2023 12:58:08 +0200 Subject: [PATCH 8/9] Change telemetry references to open telemetry. Fix wording --- .github/workflows/tests.yml | 2 +- CHANGELOG.md | 2 +- README.md | 10 ++-- .../client_generators/constants.py | 12 ++-- ...py => async_base_client_open_telemetry.py} | 4 +- ...metry.py => base_client_open_telemetry.py} | 4 +- ariadne_codegen/client_generators/package.py | 8 +-- ariadne_codegen/settings.py | 22 ++++---- pyproject.toml | 2 +- ... test_async_base_client_open_telemetry.py} | 56 +++++++++---------- ....py => test_base_client_open_telemetry.py} | 56 +++++++++---------- ...y.py => test_websockets_open_telemetry.py} | 52 ++++++++--------- tests/test_settings.py | 16 +++--- 13 files changed, 124 insertions(+), 122 deletions(-) rename ariadne_codegen/client_generators/dependencies/{async_base_client_with_telemetry.py => async_base_client_open_telemetry.py} (99%) rename ariadne_codegen/client_generators/dependencies/{base_client_with_telemetry.py => base_client_open_telemetry.py} (99%) rename tests/client_generators/dependencies/{test_async_base_client_with_telemetry.py => test_async_base_client_open_telemetry.py} (91%) rename tests/client_generators/dependencies/{test_base_client_with_telemetry.py => test_base_client_open_telemetry.py} (91%) rename tests/client_generators/dependencies/{test_websockets_with_telemetry.py => test_websockets_open_telemetry.py} (86%) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 37ce6f88..d106f322 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -25,7 +25,7 @@ jobs: run: | python -m pip install --upgrade pip pip install wheel - pip install -e .[subscriptions,telemetry,dev] + pip install -e .[subscriptions,opentelemetry,dev] - name: Pytest run: | pytest diff --git a/CHANGELOG.md b/CHANGELOG.md index a7253bc6..5a67a582 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ - Fixed parsing of unions and interfaces to always add `__typename` to generated result models. - Added escaping of enum values which are Python keywords by appending `_` to them. - Fixed `enums_module_name` option not being passed to generators. -- Added additional base clients with opentelemetry support. Added `telemetry_client` config option. +- Added additional base clients supporting the Open Telemetry tracing. Added `opentelemetry_client` config option. ## 0.9.0 (2023-09-11) diff --git a/README.md b/README.md index b9ff8586..f1417539 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,7 @@ Optional settings: - `include_comments` (defaults to `"stable"`) - option which sets content of comments included at the top of every generated file. Valid choices are: `"none"` (no comments), `"timestamp"` (comment with generation timestamp), `"stable"` (comment contains a message that this is a generated file) - `convert_to_snake_case` (defaults to `true`) - a flag that specifies whether to convert fields and arguments names to snake case - `async_client` (defaults to `true`) - default generated client is `async`, change this to option `false` to generate synchronous client instead -- `telemetry_client` (defaults to `false`) - default base clients doesn't support opentelemetry, change this option to `true` to use base client with opentelemetry +- `opentelemetry_client` (defaults to `false`) - default base clients don't support any performance tracing. Change this option to `true` to use the base client with Open Telemetry support. - `files_to_include` (defaults to `[]`) - list of files which will be copied into generated package - `plugins` (defaults to `[]`) - list of plugins to use during generation @@ -143,16 +143,16 @@ type = "Upload" ``` -### Opentelemetry +### Open Telemetry -When config option `telemetry_client` is set to `true` then default, included base clients support opt-in telemetry options. By default, it's disabled, but when the `opentelemetry-api` package is installed and the `tracer` argument is provided then the client will create spans with data about performed requests. +When config option `opentelemetry_client` is set to `true` then default, included base client is replaced with one that implements the opt-in Open Telemetry support. By default this support does nothing but when the `opentelemetry-api` package is installed and the `tracer` argument is provided then the client will create spans with data about performed requests. -Telemetry arguments handled by `BaseClientWithTelemetry`: +Telemetry arguments handled by `BaseClientOpenTelemetry`: - `tracer`: `Optional[Union[str, Tracer]] = None` - tracer object or name which will be passed to the `get_tracer` method - `root_context`: `Optional[Context] = None` - optional context added to root span - `root_span_name`: `str = "GraphQL Operation"` - name of root span -`AsyncBaseClientWithTelemetry` supports all arguments which `BaseClientWithTelemetry` does, but also exposes additional arguments regarding websockets: +`AsyncBaseClientOpenTelemetry` supports all arguments which `BaseClientOpenTelemetry` does, but also exposes additional arguments regarding websockets: - `ws_root_context`: `Optional[Context] = None` - optional context added to root span for websocket connection - `ws_root_span_name`: `str = "GraphQL Subscription"` - name of root span for websocket connection diff --git a/ariadne_codegen/client_generators/constants.py b/ariadne_codegen/client_generators/constants.py index 324588d5..94e8ef75 100644 --- a/ariadne_codegen/client_generators/constants.py +++ b/ariadne_codegen/client_generators/constants.py @@ -74,18 +74,18 @@ ) DEFAULT_ASYNC_BASE_CLIENT_NAME = "AsyncBaseClient" -DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_PATH = ( - Path(__file__).parent / "dependencies" / "async_base_client_with_telemetry.py" +DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_PATH = ( + Path(__file__).parent / "dependencies" / "async_base_client_open_telemetry.py" ) -DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_NAME = "AsyncBaseClientWithTelemetry" +DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_NAME = "AsyncBaseClientOpenTelemetry" DEFAULT_BASE_CLIENT_PATH = Path(__file__).parent / "dependencies" / "base_client.py" DEFAULT_BASE_CLIENT_NAME = "BaseClient" -DEFAULT_BASE_CLIENT_WITH_TELEMETRY_PATH = ( - Path(__file__).parent / "dependencies" / "base_client_with_telemetry.py" +DEFAULT_BASE_CLIENT_OPEN_TELEMETRY_PATH = ( + Path(__file__).parent / "dependencies" / "base_client_open_telemetry.py" ) -DEFAULT_BASE_CLIENT_WITH_TELEMETRY_NAME = "BaseClientWithTelemetry" +DEFAULT_BASE_CLIENT_OPEN_TELEMETRY_NAME = "BaseClientOpenTelemetry" GRAPHQL_CLIENT_EXCEPTIONS_NAMES = [ diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client_with_telemetry.py b/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py similarity index 99% rename from ariadne_codegen/client_generators/dependencies/async_base_client_with_telemetry.py rename to ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py index ce41d597..40ac419c 100644 --- a/ariadne_codegen/client_generators/dependencies/async_base_client_with_telemetry.py +++ b/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py @@ -65,7 +65,7 @@ def set_span_in_context(*args, **kwargs): # type: ignore raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") -Self = TypeVar("Self", bound="AsyncBaseClientWithTelemetry") +Self = TypeVar("Self", bound="AsyncBaseClientOpenTelemetry") GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" @@ -81,7 +81,7 @@ class GraphQLTransportWSMessageType(str, enum.Enum): COMPLETE = "complete" -class AsyncBaseClientWithTelemetry: +class AsyncBaseClientOpenTelemetry: def __init__( self, url: str = "", diff --git a/ariadne_codegen/client_generators/dependencies/base_client_with_telemetry.py b/ariadne_codegen/client_generators/dependencies/base_client_open_telemetry.py similarity index 99% rename from ariadne_codegen/client_generators/dependencies/base_client_with_telemetry.py rename to ariadne_codegen/client_generators/dependencies/base_client_open_telemetry.py index 996ec374..033629f0 100644 --- a/ariadne_codegen/client_generators/dependencies/base_client_with_telemetry.py +++ b/ariadne_codegen/client_generators/dependencies/base_client_open_telemetry.py @@ -32,10 +32,10 @@ def set_span_in_context(*args, **kwargs): # type: ignore raise NotImplementedError("Telemetry requires 'opentelemetry-api' package.") -Self = TypeVar("Self", bound="BaseClientWithTelemetry") +Self = TypeVar("Self", bound="BaseClientOpenTelemetry") -class BaseClientWithTelemetry: +class BaseClientOpenTelemetry: def __init__( self, url: str = "", diff --git a/ariadne_codegen/client_generators/package.py b/ariadne_codegen/client_generators/package.py index 466e5d58..47c349ec 100644 --- a/ariadne_codegen/client_generators/package.py +++ b/ariadne_codegen/client_generators/package.py @@ -16,10 +16,10 @@ BASE_MODEL_CLASS_NAME, BASE_MODEL_FILE_PATH, BASE_MODEL_IMPORT, + DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_PATH, DEFAULT_ASYNC_BASE_CLIENT_PATH, - DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_PATH, + DEFAULT_BASE_CLIENT_OPEN_TELEMETRY_PATH, DEFAULT_BASE_CLIENT_PATH, - DEFAULT_BASE_CLIENT_WITH_TELEMETRY_PATH, EXCEPTIONS_FILE_PATH, GRAPHQL_CLIENT_EXCEPTIONS_NAMES, UNSET_IMPORT, @@ -174,8 +174,8 @@ def _include_exceptions(self): if self.base_client_file_path in ( DEFAULT_ASYNC_BASE_CLIENT_PATH, DEFAULT_BASE_CLIENT_PATH, - DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_PATH, - DEFAULT_BASE_CLIENT_WITH_TELEMETRY_PATH, + DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_PATH, + DEFAULT_BASE_CLIENT_OPEN_TELEMETRY_PATH, ): self.files_to_include.append(EXCEPTIONS_FILE_PATH) self.init_generator.add_import( diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 2c10e1ec..64fef2c5 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -8,13 +8,13 @@ from .client_generators.constants import ( DEFAULT_ASYNC_BASE_CLIENT_NAME, + DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_NAME, + DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_PATH, DEFAULT_ASYNC_BASE_CLIENT_PATH, - DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_NAME, - DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_PATH, DEFAULT_BASE_CLIENT_NAME, + DEFAULT_BASE_CLIENT_OPEN_TELEMETRY_NAME, + DEFAULT_BASE_CLIENT_OPEN_TELEMETRY_PATH, DEFAULT_BASE_CLIENT_PATH, - DEFAULT_BASE_CLIENT_WITH_TELEMETRY_NAME, - DEFAULT_BASE_CLIENT_WITH_TELEMETRY_PATH, ) from .client_generators.scalars import ScalarData from .exceptions import InvalidConfiguration @@ -66,7 +66,7 @@ class ClientSettings(BaseSettings): include_comments: CommentsStrategy = field(default=CommentsStrategy.STABLE) convert_to_snake_case: bool = True async_client: bool = True - telemetry_client: bool = False + opentelemetry_client: bool = False files_to_include: List[str] = field(default_factory=list) scalars: Dict[str, ScalarData] = field(default_factory=dict) @@ -112,16 +112,16 @@ def __post_init__(self): def _set_default_base_client_data(self): default_clients_map = { (True, True): ( - DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_PATH, - DEFAULT_ASYNC_BASE_CLIENT_WITH_TELEMETRY_NAME, + DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_PATH, + DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_NAME, ), (True, False): ( DEFAULT_ASYNC_BASE_CLIENT_PATH, DEFAULT_ASYNC_BASE_CLIENT_NAME, ), (False, True): ( - DEFAULT_BASE_CLIENT_WITH_TELEMETRY_PATH, - DEFAULT_BASE_CLIENT_WITH_TELEMETRY_NAME, + DEFAULT_BASE_CLIENT_OPEN_TELEMETRY_PATH, + DEFAULT_BASE_CLIENT_OPEN_TELEMETRY_NAME, ), (False, False): ( DEFAULT_BASE_CLIENT_PATH, @@ -129,7 +129,9 @@ def _set_default_base_client_data(self): ), } if not self.base_client_name and not self.base_client_file_path: - path, name = default_clients_map[(self.async_client, self.telemetry_client)] + path, name = default_clients_map[ + (self.async_client, self.opentelemetry_client) + ] self.base_client_name = name self.base_client_file_path = path.as_posix() diff --git a/pyproject.toml b/pyproject.toml index 030c4ac0..d603006b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ dev = [ "requests-toolbelt", ] subscriptions = ["websockets~=11.0"] -telemetry = ["opentelemetry-api"] +opentelemetry = ["opentelemetry-api"] [project.scripts] ariadne-codegen = "ariadne_codegen.main:main" diff --git a/tests/client_generators/dependencies/test_async_base_client_with_telemetry.py b/tests/client_generators/dependencies/test_async_base_client_open_telemetry.py similarity index 91% rename from tests/client_generators/dependencies/test_async_base_client_with_telemetry.py rename to tests/client_generators/dependencies/test_async_base_client_open_telemetry.py index 1f451169..66361525 100644 --- a/tests/client_generators/dependencies/test_async_base_client_with_telemetry.py +++ b/tests/client_generators/dependencies/test_async_base_client_open_telemetry.py @@ -6,8 +6,8 @@ import httpx import pytest -from ariadne_codegen.client_generators.dependencies.async_base_client_with_telemetry import ( # pylint: disable=line-too-long - AsyncBaseClientWithTelemetry, +from ariadne_codegen.client_generators.dependencies.async_base_client_open_telemetry import ( # pylint: disable=line-too-long + AsyncBaseClientOpenTelemetry, ) from ariadne_codegen.client_generators.dependencies.base_model import UNSET, BaseModel from ariadne_codegen.client_generators.dependencies.exceptions import ( @@ -23,7 +23,7 @@ async def test_execute_sends_post_to_correct_url_with_correct_payload(httpx_mock): httpx_mock.add_response() - client = AsyncBaseClientWithTelemetry(url="http://base_url/endpoint") + client = AsyncBaseClientOpenTelemetry(url="http://base_url/endpoint") query_str = """ query Abc($v: String!) { abc(v: $v) { @@ -49,7 +49,7 @@ class TestModel2(BaseModel): nested: TestModel1 httpx_mock.add_response() - client = AsyncBaseClientWithTelemetry(url="http://base_url") + client = AsyncBaseClientOpenTelemetry(url="http://base_url") query_str = """ query Abc($v1: TestModel1!, $v2: TestModel2) { abc(v1: $v1, v2: $v2){ @@ -76,7 +76,7 @@ class TestModel1(BaseModel): a: int httpx_mock.add_response() - client = AsyncBaseClientWithTelemetry(url="http://base_url") + client = AsyncBaseClientOpenTelemetry(url="http://base_url") query_str = """ query Abc($v1: [[TestModel1!]!]!) { abc(v1: $v1){ @@ -104,7 +104,7 @@ class TestModel1(BaseModel): @pytest.mark.asyncio async def test_execute_sends_payload_without_unset_arguments(httpx_mock): httpx_mock.add_response() - client = AsyncBaseClientWithTelemetry(url="http://base_url") + client = AsyncBaseClientOpenTelemetry(url="http://base_url") query_str = """ query Abc($arg1: TestInputA, $arg2: String, $arg3: Float, $arg4: Int!) { abc(arg1: $arg1, arg2: $arg2, arg3: $arg3, arg4: $arg4){ @@ -139,7 +139,7 @@ class TestInputA(BaseModel): input_b3: Optional[TestInputB] = None httpx_mock.add_response() - client = AsyncBaseClientWithTelemetry(url="http://base_url") + client = AsyncBaseClientOpenTelemetry(url="http://base_url") query_str = """ query Abc($arg: TestInputB) { abc(arg: $arg){ @@ -176,7 +176,7 @@ async def test_execute_sends_payload_with_serialized_datetime_without_exception( httpx_mock, ): httpx_mock.add_response() - client = AsyncBaseClientWithTelemetry(url="http://base_url") + client = AsyncBaseClientOpenTelemetry(url="http://base_url") query_str = "query Abc($arg: DATETIME) { abc }" arg_value = datetime(2023, 12, 31, 10, 15) @@ -190,7 +190,7 @@ async def test_execute_sends_payload_with_serialized_datetime_without_exception( @pytest.mark.asyncio async def test_execute_sends_request_with_correct_content_type(httpx_mock): httpx_mock.add_response() - client = AsyncBaseClientWithTelemetry(url="http://base_url") + client = AsyncBaseClientOpenTelemetry(url="http://base_url") await client.execute("query Abc { abc }", {}) @@ -203,7 +203,7 @@ async def test_execute_sends_request_with_extra_headers_and_correct_content_type httpx_mock, ): httpx_mock.add_response() - client = AsyncBaseClientWithTelemetry( + client = AsyncBaseClientOpenTelemetry( url="http://base_url", headers={"h_key": "h_value"} ) @@ -220,7 +220,7 @@ async def test_execute_sends_file_with_multipart_form_data_content_type( ): httpx_mock.add_response() - client = AsyncBaseClientWithTelemetry(url="http://base_url") + client = AsyncBaseClientOpenTelemetry(url="http://base_url") await client.execute( "query Abc($file: Upload!) { abc(file: $file) }", {"file": txt_file} ) @@ -234,7 +234,7 @@ async def test_execute_sends_file_as_multipart_request(httpx_mock, txt_file): httpx_mock.add_response() query_str = "query Abc($file: Upload!) { abc(file: $file) }" - client = AsyncBaseClientWithTelemetry(url="http://base_url") + client = AsyncBaseClientOpenTelemetry(url="http://base_url") await client.execute(query_str, {"file": txt_file}) request = httpx_mock.get_request() @@ -261,7 +261,7 @@ async def test_execute_sends_file_from_memory(httpx_mock, in_memory_txt_file): httpx_mock.add_response() query_str = "query Abc($file: Upload!) { abc(file: $file) }" - client = AsyncBaseClientWithTelemetry(url="http://base_url") + client = AsyncBaseClientOpenTelemetry(url="http://base_url") await client.execute(query_str, {"file": in_memory_txt_file}) request = httpx_mock.get_request() @@ -288,7 +288,7 @@ async def test_execute_sends_multiple_files(httpx_mock, txt_file, png_file): httpx_mock.add_response() query_str = "query Abc($files: [Upload!]!) { abc(files: $files) }" - client = AsyncBaseClientWithTelemetry(url="http://base_url") + client = AsyncBaseClientOpenTelemetry(url="http://base_url") await client.execute(query_str, {"files": [txt_file, png_file]}) request = httpx_mock.get_request() @@ -326,7 +326,7 @@ class InputType(BaseModel): httpx_mock.add_response() query_str = "query Abc($input: InputType!) { abc(input: $input) }" - client = AsyncBaseClientWithTelemetry(url="http://base_url") + client = AsyncBaseClientOpenTelemetry(url="http://base_url") await client.execute(query_str, {"input": InputType(file_=txt_file)}) request = httpx_mock.get_request() @@ -356,7 +356,7 @@ async def test_execute_sends_each_file_only_once(httpx_mock, txt_file): httpx_mock.add_response() query_str = "query Abc($files: [Upload!]!) { abc(files: $files) }" - client = AsyncBaseClientWithTelemetry(url="http://base_url") + client = AsyncBaseClientOpenTelemetry(url="http://base_url") await client.execute(query_str, {"files": [txt_file, txt_file]}) request = httpx_mock.get_request() @@ -393,7 +393,7 @@ async def test_execute_sends_each_file_only_once(httpx_mock, txt_file): def test_get_data_raises_graphql_client_http_error( mocker, status_code, response_content ): - client = AsyncBaseClientWithTelemetry( + client = AsyncBaseClientOpenTelemetry( url="base_url", http_client=mocker.MagicMock() ) response = httpx.Response( @@ -410,7 +410,7 @@ def test_get_data_raises_graphql_client_http_error( def test_get_data_raises_graphql_client_invalid_response_error( mocker, response_content ): - client = AsyncBaseClientWithTelemetry( + client = AsyncBaseClientOpenTelemetry( url="base_url", http_client=mocker.MagicMock() ) response = httpx.Response(status_code=200, content=json.dumps(response_content)) @@ -451,7 +451,7 @@ def test_get_data_raises_graphql_client_invalid_response_error( ], ) def test_get_data_raises_graphql_client_graphql_multi_error(mocker, response_content): - client = AsyncBaseClientWithTelemetry( + client = AsyncBaseClientOpenTelemetry( url="base_url", http_client=mocker.MagicMock() ) @@ -466,7 +466,7 @@ def test_get_data_raises_graphql_client_graphql_multi_error(mocker, response_con [{"errors": [], "data": {}}, {"errors": None, "data": {}}, {"data": {}}], ) def test_get_data_doesnt_raise_exception(mocker, response_content): - client = AsyncBaseClientWithTelemetry( + client = AsyncBaseClientOpenTelemetry( url="base_url", http_client=mocker.MagicMock() ) @@ -480,7 +480,7 @@ def test_get_data_doesnt_raise_exception(mocker, response_content): @pytest.mark.asyncio async def test_base_client_used_as_context_manager_closes_http_client(mocker): fake_client = mocker.AsyncMock() - async with AsyncBaseClientWithTelemetry( + async with AsyncBaseClientOpenTelemetry( url="base_url", http_client=fake_client ) as base_client: await base_client.execute("") @@ -492,7 +492,7 @@ async def test_base_client_used_as_context_manager_closes_http_client(mocker): def mocker_get_tracer(mocker): return mocker.patch( "ariadne_codegen.client_generators.dependencies." - "async_base_client_with_telemetry.get_tracer" + "async_base_client_open_telemetry.get_tracer" ) @@ -505,7 +505,7 @@ def mocked_start_as_current_span(mocker_get_tracer): async def test_async_base_client_with_given_tracker_str_uses_global_tracker( mocker_get_tracer, ): - AsyncBaseClientWithTelemetry(url="http://base_url", tracer="tracker name") + AsyncBaseClientOpenTelemetry(url="http://base_url", tracer="tracker name") assert mocker_get_tracer.call_count == 1 @@ -513,7 +513,7 @@ async def test_async_base_client_with_given_tracker_str_uses_global_tracker( @pytest.mark.asyncio async def test_execute_creates_root_span(httpx_mock, mocked_start_as_current_span): httpx_mock.add_response() - client = AsyncBaseClientWithTelemetry(url="http://base_url", tracer="tracker") + client = AsyncBaseClientOpenTelemetry(url="http://base_url", tracer="tracker") await client.execute("query GetHello { hello }") @@ -527,7 +527,7 @@ async def test_execute_creates_root_span_with_custom_name( httpx_mock, mocked_start_as_current_span ): httpx_mock.add_response() - client = AsyncBaseClientWithTelemetry( + client = AsyncBaseClientOpenTelemetry( url="http://base_url", tracer="tracker", root_span_name="root_span" ) @@ -541,7 +541,7 @@ async def test_execute_creates_root_span_with_custom_context( httpx_mock, mocked_start_as_current_span ): httpx_mock.add_response() - client = AsyncBaseClientWithTelemetry( + client = AsyncBaseClientOpenTelemetry( url="http://base_url", tracer="tracker", root_context={"abc": 123} ) @@ -557,7 +557,7 @@ async def test_execute_creates_span_for_json_http_request( httpx_mock, mocked_start_as_current_span ): httpx_mock.add_response() - client = AsyncBaseClientWithTelemetry(url="http://base_url", tracer="tracker") + client = AsyncBaseClientOpenTelemetry(url="http://base_url", tracer="tracker") await client.execute("query GetHello { hello }", variables={"a": 1, "b": {"bb": 2}}) @@ -575,7 +575,7 @@ async def test_execute_creates_span_for_multipart_request( httpx_mock, txt_file, mocked_start_as_current_span ): httpx_mock.add_response() - client = AsyncBaseClientWithTelemetry(url="http://base_url", tracer="tracker") + client = AsyncBaseClientOpenTelemetry(url="http://base_url", tracer="tracker") await client.execute( "query Abc($file: Upload!) { abc(file: $file) }", diff --git a/tests/client_generators/dependencies/test_base_client_with_telemetry.py b/tests/client_generators/dependencies/test_base_client_open_telemetry.py similarity index 91% rename from tests/client_generators/dependencies/test_base_client_with_telemetry.py rename to tests/client_generators/dependencies/test_base_client_open_telemetry.py index 5dc1f127..2a855055 100644 --- a/tests/client_generators/dependencies/test_base_client_with_telemetry.py +++ b/tests/client_generators/dependencies/test_base_client_open_telemetry.py @@ -6,8 +6,8 @@ import httpx import pytest -from ariadne_codegen.client_generators.dependencies.base_client_with_telemetry import ( - BaseClientWithTelemetry, +from ariadne_codegen.client_generators.dependencies.base_client_open_telemetry import ( + BaseClientOpenTelemetry, ) from ariadne_codegen.client_generators.dependencies.base_model import UNSET, BaseModel from ariadne_codegen.client_generators.dependencies.exceptions import ( @@ -21,7 +21,7 @@ def test_execute_sends_post_to_correct_url_with_correct_payload(httpx_mock): httpx_mock.add_response() - client = BaseClientWithTelemetry(url="http://base_url/endpoint") + client = BaseClientOpenTelemetry(url="http://base_url/endpoint") query_str = """ query Abc($v: String!) { abc(v: $v) { @@ -46,7 +46,7 @@ class TestModel2(BaseModel): nested: TestModel1 httpx_mock.add_response() - client = BaseClientWithTelemetry(url="http://base_url") + client = BaseClientOpenTelemetry(url="http://base_url") query_str = """ query Abc($v1: TestModel1!, $v2: TestModel2) { abc(v1: $v1, v2: $v2){ @@ -72,7 +72,7 @@ class TestModel1(BaseModel): a: int httpx_mock.add_response() - client = BaseClientWithTelemetry(url="http://base_url") + client = BaseClientOpenTelemetry(url="http://base_url") query_str = """ query Abc($v1: [[TestModel1!]!]!) { abc(v1: $v1){ @@ -99,7 +99,7 @@ class TestModel1(BaseModel): def test_execute_sends_payload_without_unset_arguments(httpx_mock): httpx_mock.add_response() - client = BaseClientWithTelemetry(url="http://base_url") + client = BaseClientOpenTelemetry(url="http://base_url") query_str = """ query Abc($arg1: TestInputA, $arg2: String, $arg3: Float, $arg4: Int!) { abc(arg1: $arg1, arg2: $arg2, arg3: $arg3, arg4: $arg4){ @@ -131,7 +131,7 @@ class TestInputA(BaseModel): input_b3: Optional[TestInputB] = None httpx_mock.add_response() - client = BaseClientWithTelemetry(url="http://base_url") + client = BaseClientOpenTelemetry(url="http://base_url") query_str = """ query Abc($arg: TestInputB) { abc(arg: $arg){ @@ -165,7 +165,7 @@ class TestInputA(BaseModel): def test_execute_sends_payload_with_serialized_datetime_without_exception(httpx_mock): httpx_mock.add_response() - client = BaseClientWithTelemetry(url="http://base_url") + client = BaseClientOpenTelemetry(url="http://base_url") query_str = "query Abc($arg: DATETIME) { abc }" arg_value = datetime(2023, 12, 31, 10, 15) @@ -178,7 +178,7 @@ def test_execute_sends_payload_with_serialized_datetime_without_exception(httpx_ def test_execute_sends_request_with_correct_content_type(httpx_mock): httpx_mock.add_response() - client = BaseClientWithTelemetry(url="http://base_url") + client = BaseClientOpenTelemetry(url="http://base_url") client.execute("query Abc { abc }", {}) @@ -188,7 +188,7 @@ def test_execute_sends_request_with_correct_content_type(httpx_mock): def test_execute_sends_request_with_extra_headers_and_correct_content_type(httpx_mock): httpx_mock.add_response() - client = BaseClientWithTelemetry( + client = BaseClientOpenTelemetry( url="http://base_url", headers={"h_key": "h_value"} ) @@ -202,7 +202,7 @@ def test_execute_sends_request_with_extra_headers_and_correct_content_type(httpx def test_execute_sends_file_with_multipart_form_data_content_type(httpx_mock, txt_file): httpx_mock.add_response() - client = BaseClientWithTelemetry(url="http://base_url") + client = BaseClientOpenTelemetry(url="http://base_url") client.execute("query Abc($file: Upload!) { abc(file: $file) }", {"file": txt_file}) request = httpx_mock.get_request() @@ -213,7 +213,7 @@ def test_execute_sends_file_as_multipart_request(httpx_mock, txt_file): httpx_mock.add_response() query_str = "query Abc($file: Upload!) { abc(file: $file) }" - client = BaseClientWithTelemetry(url="http://base_url") + client = BaseClientOpenTelemetry(url="http://base_url") client.execute(query_str, {"file": txt_file}) request = httpx_mock.get_request() @@ -239,7 +239,7 @@ def test_execute_sends_file_from_memory(httpx_mock, in_memory_txt_file): httpx_mock.add_response() query_str = "query Abc($file: Upload!) { abc(file: $file) }" - client = BaseClientWithTelemetry(url="http://base_url") + client = BaseClientOpenTelemetry(url="http://base_url") client.execute(query_str, {"file": in_memory_txt_file}) request = httpx_mock.get_request() @@ -265,7 +265,7 @@ def test_execute_sends_multiple_files(httpx_mock, txt_file, png_file): httpx_mock.add_response() query_str = "query Abc($files: [Upload!]!) { abc(files: $files) }" - client = BaseClientWithTelemetry(url="http://base_url") + client = BaseClientOpenTelemetry(url="http://base_url") client.execute(query_str, {"files": [txt_file, png_file]}) request = httpx_mock.get_request() @@ -302,7 +302,7 @@ class InputType(BaseModel): httpx_mock.add_response() query_str = "query Abc($input: InputType!) { abc(input: $input) }" - client = BaseClientWithTelemetry(url="http://base_url") + client = BaseClientOpenTelemetry(url="http://base_url") client.execute(query_str, {"input": InputType(file_=txt_file)}) request = httpx_mock.get_request() @@ -331,7 +331,7 @@ def test_execute_sends_each_file_only_once(httpx_mock, txt_file): httpx_mock.add_response() query_str = "query Abc($files: [Upload!]!) { abc(files: $files) }" - client = BaseClientWithTelemetry(url="http://base_url") + client = BaseClientOpenTelemetry(url="http://base_url") client.execute(query_str, {"files": [txt_file, txt_file]}) request = httpx_mock.get_request() @@ -368,7 +368,7 @@ def test_execute_sends_each_file_only_once(httpx_mock, txt_file): def test_get_data_raises_graphql_client_http_error( mocker, status_code, response_content ): - client = BaseClientWithTelemetry(url="base_url", http_client=mocker.MagicMock()) + client = BaseClientOpenTelemetry(url="base_url", http_client=mocker.MagicMock()) response = httpx.Response( status_code=status_code, content=json.dumps(response_content) ) @@ -383,7 +383,7 @@ def test_get_data_raises_graphql_client_http_error( def test_get_data_raises_graphql_client_invalid_response_error( mocker, response_content ): - client = BaseClientWithTelemetry(url="base_url", http_client=mocker.MagicMock()) + client = BaseClientOpenTelemetry(url="base_url", http_client=mocker.MagicMock()) response = httpx.Response(status_code=200, content=json.dumps(response_content)) with pytest.raises(GraphQlClientInvalidResponseError) as exc: @@ -422,7 +422,7 @@ def test_get_data_raises_graphql_client_invalid_response_error( ], ) def test_get_data_raises_graphql_client_graphql_multi_error(mocker, response_content): - client = BaseClientWithTelemetry(url="base_url", http_client=mocker.MagicMock()) + client = BaseClientOpenTelemetry(url="base_url", http_client=mocker.MagicMock()) with pytest.raises(GraphQLClientGraphQLMultiError): client.get_data( @@ -435,7 +435,7 @@ def test_get_data_raises_graphql_client_graphql_multi_error(mocker, response_con [{"errors": [], "data": {}}, {"errors": None, "data": {}}, {"data": {}}], ) def test_get_data_doesnt_raise_exception(mocker, response_content): - client = BaseClientWithTelemetry(url="base_url", http_client=mocker.MagicMock()) + client = BaseClientOpenTelemetry(url="base_url", http_client=mocker.MagicMock()) data = client.get_data( httpx.Response(status_code=200, content=json.dumps(response_content)) @@ -446,7 +446,7 @@ def test_get_data_doesnt_raise_exception(mocker, response_content): def test_base_client_used_as_context_manager_closes_http_client(mocker): fake_client = mocker.MagicMock() - with BaseClientWithTelemetry( + with BaseClientOpenTelemetry( url="base_url", http_client=fake_client ) as base_client: base_client.execute("") @@ -458,7 +458,7 @@ def test_base_client_used_as_context_manager_closes_http_client(mocker): def mocker_get_tracer(mocker): return mocker.patch( "ariadne_codegen.client_generators.dependencies." - "base_client_with_telemetry.get_tracer" + "base_client_open_telemetry.get_tracer" ) @@ -468,14 +468,14 @@ def mocked_start_as_current_span(mocker_get_tracer): def test_base_client_with_given_tracker_str_uses_global_tracker(mocker_get_tracer): - BaseClientWithTelemetry(url="http://base_url", tracer="tracker name") + BaseClientOpenTelemetry(url="http://base_url", tracer="tracker name") assert mocker_get_tracer.call_count == 1 def test_execute_creates_root_span(httpx_mock, mocked_start_as_current_span): httpx_mock.add_response() - client = BaseClientWithTelemetry(url="http://base_url", tracer="tracker") + client = BaseClientOpenTelemetry(url="http://base_url", tracer="tracker") client.execute("query GetHello { hello }") @@ -488,7 +488,7 @@ def test_execute_creates_root_span_with_custom_name( httpx_mock, mocked_start_as_current_span ): httpx_mock.add_response() - client = BaseClientWithTelemetry( + client = BaseClientOpenTelemetry( url="http://base_url", tracer="tracker", root_span_name="root_span" ) @@ -501,7 +501,7 @@ def test_execute_creates_root_span_with_custom_context( httpx_mock, mocked_start_as_current_span ): httpx_mock.add_response() - client = BaseClientWithTelemetry( + client = BaseClientOpenTelemetry( url="http://base_url", tracer="tracker", root_context={"abc": 123} ) @@ -516,7 +516,7 @@ def test_execute_creates_span_for_json_http_request( httpx_mock, mocked_start_as_current_span ): httpx_mock.add_response() - client = BaseClientWithTelemetry(url="http://base_url", tracer="tracker") + client = BaseClientOpenTelemetry(url="http://base_url", tracer="tracker") client.execute("query GetHello { hello }", variables={"a": 1, "b": {"bb": 2}}) @@ -533,7 +533,7 @@ def test_execute_creates_span_for_multipart_request( httpx_mock, txt_file, mocked_start_as_current_span ): httpx_mock.add_response() - client = BaseClientWithTelemetry(url="http://base_url", tracer="tracker") + client = BaseClientOpenTelemetry(url="http://base_url", tracer="tracker") client.execute( "query Abc($file: Upload!) { abc(file: $file) }", diff --git a/tests/client_generators/dependencies/test_websockets_with_telemetry.py b/tests/client_generators/dependencies/test_websockets_open_telemetry.py similarity index 86% rename from tests/client_generators/dependencies/test_websockets_with_telemetry.py rename to tests/client_generators/dependencies/test_websockets_open_telemetry.py index 032544a7..09616e5e 100644 --- a/tests/client_generators/dependencies/test_websockets_with_telemetry.py +++ b/tests/client_generators/dependencies/test_websockets_open_telemetry.py @@ -3,8 +3,8 @@ import pytest -from ariadne_codegen.client_generators.dependencies.async_base_client_with_telemetry import ( # pylint: disable=line-too-long - AsyncBaseClientWithTelemetry, +from ariadne_codegen.client_generators.dependencies.async_base_client_open_telemetry import ( # pylint: disable=line-too-long + AsyncBaseClientOpenTelemetry, ) from ariadne_codegen.client_generators.dependencies.exceptions import ( GraphQLClientGraphQLMultiError, @@ -16,7 +16,7 @@ def mocked_ws_connect(mocker): return mocker.patch( "ariadne_codegen.client_generators.dependencies." - "async_base_client_with_telemetry.ws_connect" + "async_base_client_open_telemetry.ws_connect" ) @@ -33,7 +33,7 @@ def mocked_websocket(mocked_ws_connect): async def test_execute_ws_creates_websocket_connection_with_correct_url( mocked_ws_connect, ): - async for _ in AsyncBaseClientWithTelemetry(ws_url="ws://test_url").execute_ws(""): + async for _ in AsyncBaseClientOpenTelemetry(ws_url="ws://test_url").execute_ws(""): pass assert mocked_ws_connect.called @@ -44,7 +44,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_url( async def test_execute_ws_creates_websocket_connection_with_correct_subprotocol( mocked_ws_connect, ): - async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + async for _ in AsyncBaseClientOpenTelemetry().execute_ws(""): pass assert mocked_ws_connect.called @@ -57,7 +57,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_subprotocol( async def test_execute_ws_creates_websocket_connection_with_correct_origin( mocked_ws_connect, ): - async for _ in AsyncBaseClientWithTelemetry(ws_origin="test_origin").execute_ws(""): + async for _ in AsyncBaseClientOpenTelemetry(ws_origin="test_origin").execute_ws(""): pass assert mocked_ws_connect.called @@ -68,7 +68,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_origin( async def test_execute_ws_creates_websocket_connection_with_correct_headers( mocked_ws_connect, ): - async for _ in AsyncBaseClientWithTelemetry( + async for _ in AsyncBaseClientOpenTelemetry( ws_headers={"test_key": "test_value"} ).execute_ws(""): pass @@ -81,7 +81,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_headers( @pytest.mark.asyncio async def test_execute_ws_sends_correct_init_connection_data(mocked_websocket): - async for _ in AsyncBaseClientWithTelemetry( + async for _ in AsyncBaseClientOpenTelemetry( ws_connection_init_payload={"test_key": "test_value"} ).execute_ws(""): pass @@ -97,7 +97,7 @@ async def test_execute_ws_sends_correct_subscribe_data(mocked_websocket): query_str = "query testQuery($arg: String!) { test(arg: $arg) }" variables = {"arg": "test_value"} - async for _ in AsyncBaseClientWithTelemetry().execute_ws( + async for _ in AsyncBaseClientOpenTelemetry().execute_ws( query=query_str, variables=variables ): pass @@ -115,7 +115,7 @@ async def test_execute_ws_yields_data_for_next_message(mocked_websocket): ) received_data = [] - async for data in AsyncBaseClientWithTelemetry().execute_ws(""): + async for data in AsyncBaseClientOpenTelemetry().execute_ws(""): received_data.append(data) assert received_data == ["test_data"] @@ -132,7 +132,7 @@ async def test_execute_ws_yields_handles_multiple_next_messages(mocked_websocket ) received_data = [] - async for data in AsyncBaseClientWithTelemetry().execute_ws(""): + async for data in AsyncBaseClientOpenTelemetry().execute_ws(""): received_data.append(data) assert received_data == ["A", "B", "C"] @@ -142,7 +142,7 @@ async def test_execute_ws_yields_handles_multiple_next_messages(mocked_websocket async def test_execute_ws_closes_websocket_for_complete_message(mocked_websocket): mocked_websocket.__aiter__.return_value.append(json.dumps({"type": "complete"})) - async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + async for _ in AsyncBaseClientOpenTelemetry().execute_ws(""): pass assert mocked_websocket.close.called @@ -152,7 +152,7 @@ async def test_execute_ws_closes_websocket_for_complete_message(mocked_websocket async def test_execute_ws_sends_pong_for_ping_message(mocked_websocket): mocked_websocket.__aiter__.return_value.append(json.dumps({"type": "ping"})) - async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + async for _ in AsyncBaseClientOpenTelemetry().execute_ws(""): pass pong_call = mocked_websocket.send.mock_calls[-1] @@ -167,7 +167,7 @@ async def test_execute_ws_raises_invalid_message_format_for_not_json_message( mocked_websocket.__aiter__.return_value.append("not_valid_json") with pytest.raises(GraphQLClientInvalidMessageFormat): - async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + async for _ in AsyncBaseClientOpenTelemetry().execute_ws(""): pass @@ -178,7 +178,7 @@ async def test_execute_ws_raises_invalid_message_format_for_message_without_type mocked_websocket.__aiter__.return_value.append(json.dumps({"payload": {}})) with pytest.raises(GraphQLClientInvalidMessageFormat): - async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + async for _ in AsyncBaseClientOpenTelemetry().execute_ws(""): pass @@ -189,7 +189,7 @@ async def test_execute_ws_raises_invalid_message_format_for_message_with_invalid mocked_websocket.__aiter__.return_value.append(json.dumps({"type": "invalid_type"})) with pytest.raises(GraphQLClientInvalidMessageFormat): - async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + async for _ in AsyncBaseClientOpenTelemetry().execute_ws(""): pass @@ -202,7 +202,7 @@ async def test_execute_ws_raises_invalid_message_format_for_next_payload_without ) with pytest.raises(GraphQLClientInvalidMessageFormat): - async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + async for _ in AsyncBaseClientOpenTelemetry().execute_ws(""): pass @@ -215,7 +215,7 @@ async def test_execute_ws_raises_graphql_multi_error_for_message_with_error_type ) with pytest.raises(GraphQLClientGraphQLMultiError): - async for _ in AsyncBaseClientWithTelemetry().execute_ws(""): + async for _ in AsyncBaseClientOpenTelemetry().execute_ws(""): pass @@ -223,7 +223,7 @@ async def test_execute_ws_raises_graphql_multi_error_for_message_with_error_type def mocked_start_as_current_span(mocker): mocker_get_tracer = mocker.patch( "ariadne_codegen.client_generators.dependencies." - "async_base_client_with_telemetry.get_tracer" + "async_base_client_open_telemetry.get_tracer" ) return mocker_get_tracer.return_value.start_as_current_span @@ -232,7 +232,7 @@ def mocked_start_as_current_span(mocker): async def test_execute_ws_creates_root_span( mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument ): - client = AsyncBaseClientWithTelemetry(ws_url="ws://test_url", tracer="tracker") + client = AsyncBaseClientOpenTelemetry(ws_url="ws://test_url", tracer="tracker") async for _ in client.execute_ws(""): pass @@ -246,7 +246,7 @@ async def test_execute_ws_creates_root_span( async def test_execute_ws_creates_root_span_with_custom_name( mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument ): - client = AsyncBaseClientWithTelemetry( + client = AsyncBaseClientOpenTelemetry( ws_url="ws://test_url", tracer="tracker", ws_root_span_name="ws root span" ) @@ -262,7 +262,7 @@ async def test_execute_ws_creates_root_span_with_custom_name( async def test_execute_ws_creates_root_span_with_custom_context( mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument ): - client = AsyncBaseClientWithTelemetry( + client = AsyncBaseClientOpenTelemetry( ws_url="ws://test_url", tracer="tracker", ws_root_context={"value": "test"} ) @@ -280,7 +280,7 @@ async def test_execute_ws_creates_root_span_with_custom_context( async def test_execute_ws_creates_span_for_init_message( mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument): ): - client = AsyncBaseClientWithTelemetry(ws_url="ws://test_url", tracer="tracker") + client = AsyncBaseClientOpenTelemetry(ws_url="ws://test_url", tracer="tracker") async for _ in client.execute_ws(""): pass @@ -294,7 +294,7 @@ async def test_execute_ws_creates_span_for_init_message( async def test_execute_ws_creates_span_for_subscribe_message( mocked_start_as_current_span, mocked_websocket # pylint: disable=unused-argument): ): - client = AsyncBaseClientWithTelemetry(ws_url="ws://test_url", tracer="tracker") + client = AsyncBaseClientOpenTelemetry(ws_url="ws://test_url", tracer="tracker") async for _ in client.execute_ws( "subscription Abc(a: String, b: InputB) { value }", @@ -320,7 +320,7 @@ async def test_execute_ws_creates_span_for_received_message( received_message, mocked_websocket, mocked_start_as_current_span ): mocked_websocket.__aiter__.return_value.append(json.dumps(received_message)) - client = AsyncBaseClientWithTelemetry(ws_url="ws://test_url", tracer="tracker") + client = AsyncBaseClientOpenTelemetry(ws_url="ws://test_url", tracer="tracker") async for _ in client.execute_ws(""): pass @@ -337,7 +337,7 @@ async def test_execute_ws_creates_span_for_received_error_message( mocked_websocket.__aiter__.return_value.append( json.dumps({"type": "error", "payload": [{"message": "error_message"}]}) ) - client = AsyncBaseClientWithTelemetry(ws_url="ws://test_url", tracer="tracker") + client = AsyncBaseClientOpenTelemetry(ws_url="ws://test_url", tracer="tracker") with pytest.raises(GraphQLClientGraphQLMultiError): async for _ in client.execute_ws(""): diff --git a/tests/test_settings.py b/tests/test_settings.py index d9b8a0cb..3a629287 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -6,9 +6,9 @@ from ariadne_codegen.client_generators.dependencies import ( async_base_client, - async_base_client_with_telemetry, + async_base_client_open_telemetry, base_client, - base_client_with_telemetry, + base_client_open_telemetry, ) from ariadne_codegen.config import ClientSettings, GraphQLSchemaSettings from ariadne_codegen.exceptions import InvalidConfiguration @@ -80,21 +80,21 @@ class BaseClient: @pytest.mark.parametrize( - "async_client, telemetry_client, expected_name, expected_path", + "async_client, opentelemetry_client, expected_name, expected_path", [ ( True, True, - "AsyncBaseClientWithTelemetry", - async_base_client_with_telemetry.__file__, + "AsyncBaseClientOpenTelemetry", + async_base_client_open_telemetry.__file__, ), (True, False, "AsyncBaseClient", async_base_client.__file__), - (False, True, "BaseClientWithTelemetry", base_client_with_telemetry.__file__), + (False, True, "BaseClientOpenTelemetry", base_client_open_telemetry.__file__), (False, False, "BaseClient", base_client.__file__), ], ) def test_client_settings_sets_correct_default_values_for_base_client_name_and_path( - tmp_path, async_client, telemetry_client, expected_name, expected_path + tmp_path, async_client, opentelemetry_client, expected_name, expected_path ): schema_path = tmp_path / "schema.graphql" schema_path.touch() @@ -105,7 +105,7 @@ def test_client_settings_sets_correct_default_values_for_base_client_name_and_pa schema_path=schema_path, queries_path=queries_path, async_client=async_client, - telemetry_client=telemetry_client, + opentelemetry_client=opentelemetry_client, ) assert settings.base_client_name == expected_name From 416c2270bd9a66000f71bcbcc40f853511736725 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 11 Oct 2023 14:30:01 +0200 Subject: [PATCH 9/9] Update README.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Rafał Pitoń --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f1417539..2f6074c5 100644 --- a/README.md +++ b/README.md @@ -147,7 +147,7 @@ type = "Upload" When config option `opentelemetry_client` is set to `true` then default, included base client is replaced with one that implements the opt-in Open Telemetry support. By default this support does nothing but when the `opentelemetry-api` package is installed and the `tracer` argument is provided then the client will create spans with data about performed requests. -Telemetry arguments handled by `BaseClientOpenTelemetry`: +Tracing arguments handled by `BaseClientOpenTelemetry`: - `tracer`: `Optional[Union[str, Tracer]] = None` - tracer object or name which will be passed to the `get_tracer` method - `root_context`: `Optional[Context] = None` - optional context added to root span - `root_span_name`: `str = "GraphQL Operation"` - name of root span