Skip to content

Commit

Permalink
Merge pull request galaxyproject#18755 from nsoranzo/fix_B039
Browse files Browse the repository at this point in the history
Fix new flake8-bugbear B039 and mypy type-var errors
  • Loading branch information
jmchilton authored Aug 31, 2024
2 parents d1fc599 + be29cca commit 9f3ed27
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 22 deletions.
10 changes: 5 additions & 5 deletions lib/galaxy/dependencies/pinned-lint-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
attrs==23.2.0
flake8==7.1.0
flake8-bugbear==24.4.26
attrs==24.2.0
flake8==7.1.1
flake8-bugbear==24.8.19
mccabe==0.7.0
pycodestyle==2.12.0
pycodestyle==2.12.1
pyflakes==3.2.0
ruff==0.6.1
ruff==0.6.3
7 changes: 3 additions & 4 deletions lib/galaxy/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
# of a request (which run within a threadpool) to see changes to the ContextVar
# state. See https://github.com/tiangolo/fastapi/issues/953#issuecomment-586006249
# for details
_request_state: Dict[str, str] = {}
REQUEST_ID = ContextVar("request_id", default=_request_state.copy())
REQUEST_ID: ContextVar[Union[Dict[str, str], None]] = ContextVar("request_id", default=None)


@contextlib.contextmanager
Expand Down Expand Up @@ -112,12 +111,12 @@ def new_session(self):

def request_scopefunc(self):
"""
Return a value that is used as dictionary key for sqlalchemy's ScopedRegistry.
Return a value that is used as dictionary key for SQLAlchemy's ScopedRegistry.
This ensures that threads or request contexts will receive a single identical session
from the ScopedRegistry.
"""
return REQUEST_ID.get().get("request") or threading.get_ident()
return REQUEST_ID.get({}).get("request") or threading.get_ident()

@staticmethod
def set_request_id(request_id):
Expand Down
14 changes: 7 additions & 7 deletions lib/galaxy/tool_util/verify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,8 @@ def files_contains(file1, file2, attributes=None):


def _singleobject_intersection_over_union(
mask1: "numpy.typing.NDArray",
mask2: "numpy.typing.NDArray",
mask1: "numpy.typing.NDArray[numpy.bool_]",
mask2: "numpy.typing.NDArray[numpy.bool_]",
) -> "numpy.floating":
return numpy.logical_and(mask1, mask2).sum() / numpy.logical_or(mask1, mask2).sum()

Expand All @@ -483,7 +483,7 @@ def _multiobject_intersection_over_union(
pin_labels: Optional[List[int]] = None,
repeat_reverse: bool = True,
) -> List["numpy.floating"]:
iou_list = []
iou_list: List[numpy.floating] = []
for label1 in numpy.unique(mask1):
cc1 = mask1 == label1

Expand All @@ -494,13 +494,13 @@ def _multiobject_intersection_over_union(

# Otherwise, use the object with the largest IoU value, excluding the pinned labels.
else:
cc1_iou_list = []
cc1_iou_list: List[numpy.floating] = []
for label2 in numpy.unique(mask2[cc1]):
if pin_labels is not None and label2 in pin_labels:
continue
cc2 = mask2 == label2
cc1_iou_list.append(_singleobject_intersection_over_union(cc1, cc2))
iou_list.append(max(cc1_iou_list))
iou_list.append(max(cc1_iou_list)) # type: ignore[type-var, unused-ignore] # https://github.com/python/typeshed/issues/12562

if repeat_reverse:
iou_list.extend(_multiobject_intersection_over_union(mask2, mask1, pin_labels, repeat_reverse=False))
Expand All @@ -511,7 +511,7 @@ def _multiobject_intersection_over_union(
def intersection_over_union(
mask1: "numpy.typing.NDArray", mask2: "numpy.typing.NDArray", pin_labels: Optional[List[int]] = None
) -> "numpy.floating":
"""Compute the intersection over union (IoU) for the objects in two masks containing lables.
"""Compute the intersection over union (IoU) for the objects in two masks containing labels.
The IoU is computed for each uniquely labeled image region (object), and the overall minimum value is returned (i.e. the worst value).
To compute the IoU for each object, the corresponding object in the other mask needs to be determined.
Expand All @@ -529,7 +529,7 @@ def intersection_over_union(
count = sum(label in mask for mask in (mask1, mask2))
count_str = {1: "one", 2: "both"}
assert count == 2, f"Label {label} is pinned but missing in {count_str[2 - count]} of the images."
return min(_multiobject_intersection_over_union(mask1, mask2, pin_labels))
return min(_multiobject_intersection_over_union(mask1, mask2, pin_labels)) # type: ignore[type-var, unused-ignore] # https://github.com/python/typeshed/issues/12562


def _parse_label_list(label_list_str: Optional[str]) -> List[int]:
Expand Down
16 changes: 10 additions & 6 deletions test/unit/webapps/test_request_scoped_sqlalchemy_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@
import pytest
from fastapi import FastAPI
from fastapi.param_functions import Depends
from httpx import AsyncClient
from httpx import (
ASGITransport,
AsyncClient,
)
from starlette_context import context as request_context

from galaxy.app_unittest_utils.galaxy_mock import MockApp
from galaxy.webapps.base.api import add_request_id_middleware

app = FastAPI()
add_request_id_middleware(app)
transport = ASGITransport(app=app)
GX_APP = None


Expand Down Expand Up @@ -96,7 +100,7 @@ def assert_scoped_session_is_thread_local(gx_app):

@pytest.mark.asyncio
async def test_request_scoped_sa_session_single_request():
async with AsyncClient(app=app, base_url="http://test") as client:
async with AsyncClient(base_url="http://test", transport=transport) as client:
response = await client.get("/")
assert response.status_code == 200
assert response.json() == {"msg": "Hello World"}
Expand All @@ -106,7 +110,7 @@ async def test_request_scoped_sa_session_single_request():

@pytest.mark.asyncio
async def test_request_scoped_sa_session_exception():
async with AsyncClient(app=app, base_url="http://test") as client:
async with AsyncClient(base_url="http://test", transport=transport) as client:
with pytest.raises(UnexpectedException):
await client.get("/internal_server_error")
assert GX_APP
Expand All @@ -115,7 +119,7 @@ async def test_request_scoped_sa_session_exception():

@pytest.mark.asyncio
async def test_request_scoped_sa_session_concurrent_requests_sync():
async with AsyncClient(app=app, base_url="http://test") as client:
async with AsyncClient(base_url="http://test", transport=transport) as client:
awaitables = (client.get("/sync_wait") for _ in range(10))
result = await asyncio.gather(*awaitables)
uuids = []
Expand All @@ -129,7 +133,7 @@ async def test_request_scoped_sa_session_concurrent_requests_sync():

@pytest.mark.asyncio
async def test_request_scoped_sa_session_concurrent_requests_async():
async with AsyncClient(app=app, base_url="http://test") as client:
async with AsyncClient(base_url="http://test", transport=transport) as client:
awaitables = (client.get("/async_wait") for _ in range(10))
result = await asyncio.gather(*awaitables)
uuids = []
Expand All @@ -147,7 +151,7 @@ async def test_request_scoped_sa_session_concurrent_requests_and_background_thre
target = functools.partial(assert_scoped_session_is_thread_local, GX_APP)
with concurrent.futures.ThreadPoolExecutor() as pool:
background_pool = loop.run_in_executor(pool, target)
async with AsyncClient(app=app, base_url="http://test") as client:
async with AsyncClient(base_url="http://test", transport=transport) as client:
awaitables = (client.get("/async_wait") for _ in range(10))
result = await asyncio.gather(*awaitables)
uuids = []
Expand Down

0 comments on commit 9f3ed27

Please sign in to comment.