From f302c552c0a428cf3ff1aa2e0f1960421f438121 Mon Sep 17 00:00:00 2001 From: Casey Jao Date: Wed, 3 Jan 2024 21:06:32 -0500 Subject: [PATCH] Improve APIClient and handling of sublattice dispatches --- covalent/_api/apiclient.py | 10 ++--- covalent/_dispatcher_plugins/local.py | 11 ++++-- covalent/_workflow/electron.py | 3 +- tests/covalent_tests/api/apiclient_test.py | 44 ++++++++++++++++++++++ 4 files changed, 58 insertions(+), 10 deletions(-) create mode 100644 tests/covalent_tests/api/apiclient_test.py diff --git a/covalent/_api/apiclient.py b/covalent/_api/apiclient.py index c4c2a5492..d3be6bd4a 100644 --- a/covalent/_api/apiclient.py +++ b/covalent/_api/apiclient.py @@ -33,7 +33,7 @@ def __init__(self, dispatcher_addr: str, adapter: HTTPAdapter = None, auto_raise self.adapter = adapter self.auto_raise = auto_raise - def prepare_headers(self, **kwargs): + def prepare_headers(self, kwargs): extra_headers = CovalentAPIClient.get_extra_headers() headers = kwargs.get("headers", {}) if headers: @@ -42,7 +42,7 @@ def prepare_headers(self, **kwargs): return headers def get(self, endpoint: str, **kwargs): - headers = self.prepare_headers(**kwargs) + headers = self.prepare_headers(kwargs) url = self.dispatcher_addr + endpoint try: with requests.Session() as session: @@ -62,7 +62,7 @@ def get(self, endpoint: str, **kwargs): return r def put(self, endpoint: str, **kwargs): - headers = self.prepare_headers() + headers = self.prepare_headers(kwargs) url = self.dispatcher_addr + endpoint try: with requests.Session() as session: @@ -81,7 +81,7 @@ def put(self, endpoint: str, **kwargs): return r def post(self, endpoint: str, **kwargs): - headers = self.prepare_headers() + headers = self.prepare_headers(kwargs) url = self.dispatcher_addr + endpoint try: with requests.Session() as session: @@ -100,7 +100,7 @@ def post(self, endpoint: str, **kwargs): return r def delete(self, endpoint: str, **kwargs): - headers = self.prepare_headers() + headers = self.prepare_headers(kwargs) url = self.dispatcher_addr + endpoint try: with requests.Session() as session: diff --git a/covalent/_dispatcher_plugins/local.py b/covalent/_dispatcher_plugins/local.py index 8760cec96..f5ad26bdd 100644 --- a/covalent/_dispatcher_plugins/local.py +++ b/covalent/_dispatcher_plugins/local.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import tempfile from copy import deepcopy from functools import wraps @@ -616,14 +617,16 @@ def _upload_asset(local_uri, remote_uri): local_path = local_uri with open(local_path, "rb") as reader: - app_log.debug(f"uploading to {remote_uri}") + content_length = os.path.getsize(local_path) f = furl(remote_uri) scheme = f.scheme host = f.host port = f.port dispatcher_addr = f"{scheme}://{host}:{port}" - endpoint = str(f.path) + endpoint = f"{str(f.path)}?{str(f.query)}" api_client = APIClient(dispatcher_addr) - - r = api_client.put(endpoint, data=reader) + if content_length == 0: + r = api_client.put(endpoint, headers={"Content-Length": "0"}, data=reader.read()) + else: + r = api_client.put(endpoint, data=reader) r.raise_for_status() diff --git a/covalent/_workflow/electron.py b/covalent/_workflow/electron.py index 12f18cbf5..b38ea3d5e 100644 --- a/covalent/_workflow/electron.py +++ b/covalent/_workflow/electron.py @@ -915,6 +915,7 @@ def _build_sublattice_graph(sub: Lattice, json_parent_metadata: str, *args, **kw return recv_manifest.model_dump_json() except Exception as ex: + if os.environ.get("COVALENT_DISABLE_LEGACY_SUBLATTICES") == "1": + raise # Fall back to legacy sublattice handling - print("Falling back to legacy sublattice handling") return sub.serialize_to_json() diff --git a/tests/covalent_tests/api/apiclient_test.py b/tests/covalent_tests/api/apiclient_test.py new file mode 100644 index 000000000..20d22cb5a --- /dev/null +++ b/tests/covalent_tests/api/apiclient_test.py @@ -0,0 +1,44 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Unit tests for the API client""" + +import json +from unittest.mock import MagicMock + +import pytest + +from covalent._api.apiclient import CovalentAPIClient + + +@pytest.fixture +def mock_session(): + sess = MagicMock() + + +def test_header_injection(mocker): + extra_headers = {"x-custom-header": "value"} + headers = {"Content-Length": "128"} + expected_headers = headers.copy() + expected_headers.update(extra_headers) + mock_session = MagicMock() + environ = {"COVALENT_EXTRA_HEADERS": json.dumps(extra_headers)} + mocker.patch("os.environ", environ) + mocker.patch("requests.Session.__enter__", return_value=mock_session) + + CovalentAPIClient("http://localhost").post("/docs", headers=headers) + mock_session.post.assert_called_with("http://localhost/docs", headers=expected_headers)