Skip to content

Commit

Permalink
Improve APIClient and handling of sublattice dispatches
Browse files Browse the repository at this point in the history
  • Loading branch information
cjao committed Jan 5, 2024
1 parent c4b1319 commit 9af83a8
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 10 deletions.
10 changes: 5 additions & 5 deletions covalent/_api/apiclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, dispatcher_addr: str, adapter: HTTPAdapter = None, auto_raise
self.adapter = adapter
self.auto_raise = auto_raise

def prepare_headers(self, **kwargs):
def prepare_headers(self, kwargs):
extra_headers = CovalentAPIClient.get_extra_headers()
headers = kwargs.get("headers", {})
if headers:
Expand All @@ -42,7 +42,7 @@ def prepare_headers(self, **kwargs):
return headers

def get(self, endpoint: str, **kwargs):
headers = self.prepare_headers(**kwargs)
headers = self.prepare_headers(kwargs)
url = self.dispatcher_addr + endpoint
try:
with requests.Session() as session:
Expand All @@ -62,7 +62,7 @@ def get(self, endpoint: str, **kwargs):
return r

def put(self, endpoint: str, **kwargs):
headers = self.prepare_headers()
headers = self.prepare_headers(kwargs)
url = self.dispatcher_addr + endpoint
try:
with requests.Session() as session:
Expand All @@ -81,7 +81,7 @@ def put(self, endpoint: str, **kwargs):
return r

def post(self, endpoint: str, **kwargs):
headers = self.prepare_headers()
headers = self.prepare_headers(kwargs)
url = self.dispatcher_addr + endpoint
try:
with requests.Session() as session:
Expand All @@ -100,7 +100,7 @@ def post(self, endpoint: str, **kwargs):
return r

def delete(self, endpoint: str, **kwargs):
headers = self.prepare_headers()
headers = self.prepare_headers(kwargs)

Check warning on line 103 in covalent/_api/apiclient.py

View check run for this annotation

Codecov / codecov/patch

covalent/_api/apiclient.py#L103

Added line #L103 was not covered by tests
url = self.dispatcher_addr + endpoint
try:
with requests.Session() as session:
Expand Down
11 changes: 7 additions & 4 deletions covalent/_dispatcher_plugins/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
from copy import deepcopy
from functools import wraps
Expand Down Expand Up @@ -616,14 +617,16 @@ def _upload_asset(local_uri, remote_uri):
local_path = local_uri

with open(local_path, "rb") as reader:
app_log.debug(f"uploading to {remote_uri}")
content_length = os.path.getsize(local_path)
f = furl(remote_uri)
scheme = f.scheme
host = f.host
port = f.port
dispatcher_addr = f"{scheme}://{host}:{port}"
endpoint = str(f.path)
endpoint = f"{str(f.path)}?{str(f.query)}"
api_client = APIClient(dispatcher_addr)

r = api_client.put(endpoint, data=reader)
if content_length == 0:
r = api_client.put(endpoint, headers={"Content-Length": "0"}, data=reader.read())
else:
r = api_client.put(endpoint, data=reader)
r.raise_for_status()
3 changes: 2 additions & 1 deletion covalent/_workflow/electron.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,7 @@ def _build_sublattice_graph(sub: Lattice, json_parent_metadata: str, *args, **kw
return recv_manifest.model_dump_json()

except Exception as ex:
if os.environ.get("COVALENT_DISABLE_LEGACY_SUBLATTICES") == "1":
raise
# Fall back to legacy sublattice handling
print("Falling back to legacy sublattice handling")
return sub.serialize_to_json()
44 changes: 44 additions & 0 deletions tests/covalent_tests/api/apiclient_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2023 Agnostiq Inc.
#
# This file is part of Covalent.
#
# Licensed under the Apache License 2.0 (the "License"). A copy of the
# License may be obtained with this software package or at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Use of this file is prohibited except in compliance with the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Unit tests for the API client"""

import json
from unittest.mock import MagicMock

import pytest

from covalent._api.apiclient import CovalentAPIClient


@pytest.fixture
def mock_session():
sess = MagicMock()


def test_header_injection(mocker):
extra_headers = {"x-custom-header": "value"}
headers = {"Content-Length": "128"}
expected_headers = headers.copy()
expected_headers.update(extra_headers)
mock_session = MagicMock()
environ = {"COVALENT_EXTRA_HEADERS": json.dumps(extra_headers)}
mocker.patch("os.environ", environ)
mocker.patch("requests.Session.__enter__", return_value=mock_session)

CovalentAPIClient("http://localhost").post("/docs", headers=headers)
mock_session.post.assert_called_with("http://localhost/docs", headers=expected_headers)
46 changes: 46 additions & 0 deletions tests/covalent_tests/workflow/electron_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,52 @@ def workflow(x):
assert parent_metadata[k] == lattice.metadata[k]


def test_build_sublattice_graph_prevent_fallback(mocker):
"""
Test preventing falling back to monolithic sublattice dispatch.
"""
dispatch_id = "test_build_sublattice_graph_prevent_fallback"

@ct.electron
def task(x):
return x

@ct.lattice
def workflow(x):
return task(x)

parent_metadata = {
"executor": "parent_executor",
"executor_data": {},
"workflow_executor": "my_postprocessor",
"workflow_executor_data": {},
"hooks": {
"deps": {"bash": None, "pip": None},
"call_before": [],
"call_after": [],
},
"triggers": "mock-trigger",
"qelectron_data_exists": False,
"results_dir": None,
}

# Omit the required environment variables
mock_environ = {"COVALENT_DISABLE_LEGACY_SUBLATTICES": "1"}

mock_reg = mocker.patch(
"covalent._dispatcher_plugins.local.LocalDispatcher.register_manifest",
)

mock_upload_assets = mocker.patch(
"covalent._dispatcher_plugins.local.LocalDispatcher.upload_assets",
)

mocker.patch("os.environ", mock_environ)

with pytest.raises(Exception):
json_lattice = _build_sublattice_graph(workflow, json.dumps(parent_metadata), 1)


def test_wait_for_building():
"""Test to check whether the graph is built correctly with `wait_for`."""

Expand Down

0 comments on commit 9af83a8

Please sign in to comment.