Skip to content

Commit

Permalink
Merge pull request #19070 from mvdbeek/reuse_trs_import_in_workflow_c…
Browse files Browse the repository at this point in the history
…ontents_manager

Move TRS import into WorkflowContentManager
  • Loading branch information
mvdbeek authored Oct 30, 2024
2 parents 03307f6 + 48ed068 commit 339724e
Show file tree
Hide file tree
Showing 13 changed files with 104 additions and 79 deletions.
2 changes: 1 addition & 1 deletion lib/galaxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ def __init__(self, configure_logging=True, use_converters=True, use_display_appl
InstalledRepositoryManager, InstalledRepositoryManager(self)
)
self.dynamic_tool_manager = self._register_singleton(DynamicToolManager)
self.trs_proxy = self._register_singleton(TrsProxy, TrsProxy(self.config))
self._configure_datatypes_registry(
use_converters=use_converters,
use_display_applications=use_display_applications,
Expand Down Expand Up @@ -843,7 +844,6 @@ def __init__(self, **kwargs) -> None:
# Must be initialized after job_config.
self.workflow_scheduling_manager = scheduling_manager.WorkflowSchedulingManager(self)

self.trs_proxy = self._register_singleton(TrsProxy, TrsProxy(self.config))
# We need InteractiveToolManager before the job handler starts
self.interactivetool_manager = InteractiveToolManager(self)
# Start the job manager
Expand Down
1 change: 1 addition & 0 deletions lib/galaxy/app_unittest_utils/galaxy_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __init__(self, config=None, **kwargs) -> None:
self.interactivetool_manager = Bunch(create_interactivetool=lambda *args, **kwargs: None)
self.is_job_handler = False
self.biotools_metadata_source = None
self.trs_proxy = Bunch()
set_thread_app(self)

def url_for(*args, **kwds):
Expand Down
24 changes: 4 additions & 20 deletions lib/galaxy/managers/landing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
)
from uuid import uuid4

import yaml
from pydantic import UUID4
from sqlalchemy import select

Expand All @@ -15,10 +14,7 @@
ObjectNotFound,
RequestParameterMissingException,
)
from galaxy.managers.workflows import (
WorkflowContentsManager,
WorkflowCreateOptions,
)
from galaxy.managers.workflows import WorkflowContentsManager
from galaxy.model import (
ToolLandingRequest as ToolLandingRequestModel,
WorkflowLandingRequest as WorkflowLandingRequestModel,
Expand Down Expand Up @@ -106,22 +102,10 @@ def _ensure_workflow(self, trans: ProvidesUserContext, request: WorkflowLandingR
if request.workflow_source_type == "trs_url" and isinstance(trans.app, StructuredApp):
# trans is always structured app except for unit test
assert request.workflow_source
trs_id, trs_version = request.workflow_source.rsplit("/", 1)
_, trs_id, trs_version = trans.app.trs_proxy.get_trs_id_and_version_from_trs_url(request.workflow_source)
workflow = self.workflow_contents_manager.get_workflow_by_trs_id_and_version(
self.sa_session, trs_id=trs_id, trs_version=trs_version, user_id=trans.user and trans.user.id
workflow = self.workflow_contents_manager.get_or_create_workflow_from_trs(
trans, trs_url=request.workflow_source
)
if not workflow:
data = trans.app.trs_proxy.get_version_from_trs_url(request.workflow_source)
as_dict = yaml.safe_load(data)
raw_workflow_description = self.workflow_contents_manager.normalize_workflow_format(trans, as_dict)
created_workflow = self.workflow_contents_manager.build_workflow_from_raw_description(
trans,
raw_workflow_description,
WorkflowCreateOptions(),
)
workflow = created_workflow.workflow
request.workflow_id = workflow.id
request.workflow_id = workflow.latest_workflow_id

def get_tool_landing_request(self, trans: ProvidesUserContext, uuid: UUID4) -> ToolLandingRequest:
request = self._get_claimed_tool_landing_request(trans, uuid)
Expand Down
96 changes: 71 additions & 25 deletions lib/galaxy/managers/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@
text_column_filter,
)
from galaxy.model.item_attrs import UsesAnnotations
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.schema.invocation import InvocationCancellationUserRequest
from galaxy.schema.schema import WorkflowIndexQueryPayload
from galaxy.structured_app import MinimalManagerApp
Expand Down Expand Up @@ -139,6 +138,7 @@
attach_ordered_steps,
has_cycles,
)
from galaxy.workflow.trs_proxy import TrsProxy

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -598,8 +598,10 @@ def add_serializers(self):


class WorkflowContentsManager(UsesAnnotations):
def __init__(self, app: MinimalManagerApp):

def __init__(self, app: MinimalManagerApp, trs_proxy: TrsProxy):
self.app = app
self.trs_proxy = trs_proxy
self._resource_mapper_function = get_resource_mapper_function(app)

def ensure_raw_description(self, dict_or_raw_description):
Expand Down Expand Up @@ -814,18 +816,17 @@ def _workflow_from_raw_description(
workflow.license = data.get("license")
workflow.creator_metadata = data.get("creator")

if hasattr(workflow_state_resolution_options, "archive_source"):
if workflow_state_resolution_options.archive_source:
source_metadata = {}
if workflow_state_resolution_options.archive_source == "trs_tool":
source_metadata["trs_tool_id"] = workflow_state_resolution_options.trs_tool_id
source_metadata["trs_version_id"] = workflow_state_resolution_options.trs_version_id
source_metadata["trs_server"] = workflow_state_resolution_options.trs_server
source_metadata["trs_url"] = workflow_state_resolution_options.trs_url
elif not workflow_state_resolution_options.archive_source.startswith("file://"): # URL import
source_metadata["url"] = workflow_state_resolution_options.archive_source
workflow_state_resolution_options.archive_source = None # so trs_id is not set for subworkflows
workflow.source_metadata = source_metadata # type:ignore[assignment]
if getattr(workflow_state_resolution_options, "archive_source", None):
source_metadata = {}
if workflow_state_resolution_options.archive_source in ("trs_tool", "trs_url"):
source_metadata["trs_tool_id"] = workflow_state_resolution_options.trs_tool_id
source_metadata["trs_version_id"] = workflow_state_resolution_options.trs_version_id
source_metadata["trs_server"] = workflow_state_resolution_options.trs_server
source_metadata["trs_url"] = workflow_state_resolution_options.trs_url
elif not workflow_state_resolution_options.archive_source.startswith("file://"): # URL import
source_metadata["url"] = workflow_state_resolution_options.archive_source
workflow_state_resolution_options.archive_source = None # so trs_id is not set for subworkflows
workflow.source_metadata = source_metadata

# Assume no errors until we find a step that has some
workflow.has_errors = False
Expand Down Expand Up @@ -2016,9 +2017,54 @@ def get_all_tools(self, workflow):
tools.extend(self.get_all_tools(step.subworkflow))
return tools

def get_or_create_workflow_from_trs(
self,
trans: ProvidesUserContext,
trs_url: Optional[str],
trs_id: Optional[str] = None,
trs_version: Optional[str] = None,
trs_server: Optional[str] = None,
):
user_id = trans.user and trans.user.id
assert user_id, "Cannot create workflow for anonymous user"
if not trs_url:
assert trs_server and trs_id and trs_version, "trs_url or trs_server, trs_version and trs_id must be passed"
server = self.trs_proxy.get_server(trs_server)
trs_url = server.get_trs_url(trs_id, trs_version)
else:
_, trs_id, trs_version = self.trs_proxy.get_trs_id_and_version_from_trs_url(trs_url=trs_url)
assert trs_id and trs_version and trs_url

workflow = self.get_workflow_by_trs_id_and_version(trs_id=trs_id, trs_version=trs_version, user_id=user_id)
if not workflow:
workflow = self.create_workflow_from_trs_url(trans, trs_url, trs_server)
return workflow

def create_workflow_from_trs_url(
self, trans: ProvidesUserContext, trs_url: str, trs_server: Optional[str] = None
) -> StoredWorkflow:
_, trs_tool_id, trs_version_id = self.trs_proxy.get_trs_id_and_version_from_trs_url(trs_url=trs_url)
data = self.trs_proxy.get_version_from_trs_url(trs_url)
as_dict = yaml.safe_load(data)
raw_workflow_description = self.normalize_workflow_format(trans, as_dict)
created_workflow = self.build_workflow_from_raw_description(
trans,
raw_workflow_description,
WorkflowCreateOptions(
trs_tool_id=trs_tool_id,
trs_version_id=trs_version_id,
trs_url=trs_url,
trs_server=trs_server,
archive_source="trs_url",
),
)
return created_workflow.stored_workflow

def get_workflow_by_trs_id_and_version(
self, sa_session: galaxy_scoped_session, trs_id: str, trs_version: str, user_id: Optional[int] = None
) -> Optional[model.Workflow]:
self, trs_id: str, trs_version: str, user_id: Optional[int] = None
) -> Optional[model.StoredWorkflow]:
sa_session = self.app.model.session

def to_json(column, keys: List[str]):
assert sa_session.bind
if sa_session.bind.dialect.name == "postgresql":
Expand All @@ -2028,12 +2074,12 @@ def to_json(column, keys: List[str]):
return cast.astext
else:
for key in keys:
column = column.__getitem__(key)
column = func.json_extract(column, f"$.{key}")
return column

stmnt = (
select(model.Workflow)
.join(model.StoredWorkflow, model.Workflow.stored_workflow_id == model.StoredWorkflow.id)
select(model.StoredWorkflow)
.join(model.Workflow, model.Workflow.id == model.StoredWorkflow.latest_workflow_id)
.filter(
and_(
to_json(model.Workflow.source_metadata, ["trs_tool_id"]) == trs_id,
Expand All @@ -2043,7 +2089,7 @@ def to_json(column, keys: List[str]):
)
if user_id:
stmnt = stmnt.filter(
model.StoredWorkflow.user_id == user_id, model.StoredWorkflow.latest_workflow_id == model.Workflow.id
model.StoredWorkflow.user_id == user_id,
)
else:
stmnt = stmnt.filter(model.StoredWorkflow.importable == true())
Expand Down Expand Up @@ -2104,11 +2150,11 @@ class WorkflowCreateOptions(WorkflowStateResolutionOptions):
shed_tool_conf: Optional[str] = None

# for workflows imported by archive source
archive_source: Optional[str] = ""
trs_tool_id: str = ""
trs_version_id: str = ""
trs_server: str = ""
trs_url: str = ""
archive_source: Optional[str] = None
trs_tool_id: Optional[str] = None
trs_version_id: Optional[str] = None
trs_server: Optional[str] = None
trs_url: Optional[str] = None

@property
def is_importable(self):
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7827,7 +7827,7 @@ class Workflow(Base, Dictifiable, RepresentById):
reports_config: Mapped[Optional[bytes]] = mapped_column(JSONType)
creator_metadata: Mapped[Optional[bytes]] = mapped_column(JSONType)
license: Mapped[Optional[str]] = mapped_column(TEXT)
source_metadata: Mapped[Optional[bytes]] = mapped_column(JSONType)
source_metadata: Mapped[Optional[Dict[str, str]]] = mapped_column(JSONType)
uuid: Mapped[Optional[Union[UUID, str]]] = mapped_column(UUIDType)

steps = relationship(
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/webapps/base/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,7 @@ def _workflow_to_dict(self, trans, stored):
"""
Converts a workflow to a dict of attributes suitable for exporting.
"""
workflow_contents_manager = workflows.WorkflowContentsManager(self.app)
workflow_contents_manager = workflows.WorkflowContentsManager(self.app, self.app.trs_proxy)
return workflow_contents_manager.workflow_to_dict(
trans,
stored,
Expand Down
39 changes: 13 additions & 26 deletions lib/galaxy/webapps/galaxy/api/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,29 +246,14 @@ def create(self, trans: GalaxyWebTransaction, payload=None, **kwd):
payload["workflow"] = workflow_src
return self.__api_import_new_workflow(trans, payload, **kwd)
elif archive_source == "trs_tool":
server = None
trs_tool_id = None
trs_version_id = None
import_source = None
if "trs_url" in payload:
parts = self.app.trs_proxy.match_url(
payload["trs_url"], trans.app.config.fetch_url_allowlist_ips
)
if parts:
server = self.app.trs_proxy.server_from_url(parts["trs_base_url"])
trs_tool_id = parts["tool_id"]
trs_version_id = parts["version_id"]
payload["trs_tool_id"] = trs_tool_id
payload["trs_version_id"] = trs_version_id
else:
raise exceptions.RequestParameterInvalidException(f"Invalid TRS URL {payload['trs_url']}.")
else:
trs_server = payload.get("trs_server")
server = self.app.trs_proxy.get_server(trs_server)
trs_tool_id = payload.get("trs_tool_id")
trs_version_id = payload.get("trs_version_id")

archive_data = server.get_version_descriptor(trs_tool_id, trs_version_id)
workflow = self.workflow_contents_manager.get_or_create_workflow_from_trs(
trans,
trs_url=payload.get("trs_url"),
trs_id=payload.get("trs_tool_id"),
trs_version=payload.get("trs_version_id"),
trs_server=payload.get("trs_server"),
)
return self.__api_import_response(workflow)
else:
try:
archive_data = stream_url_to_str(
Expand Down Expand Up @@ -603,13 +588,15 @@ def __api_import_from_archive(self, trans: GalaxyWebTransaction, archive_data, s
workflow, missing_tool_tups = self._workflow_from_dict(
trans, raw_workflow_description, workflow_create_options, source=source
)
workflow_id = workflow.id
workflow = workflow.latest_workflow
return self.__api_import_response(workflow)

def __api_import_response(self, stored_workflow: model.StoredWorkflow):
workflow = stored_workflow.latest_workflow
assert workflow
response = {
"message": f"Workflow '{workflow.name}' imported successfully.",
"status": "success",
"id": trans.security.encode_id(workflow_id),
"id": self.app.security.encode_id(stored_workflow.id),
}
if workflow.has_errors:
response["message"] = "Imported, but some steps in this workflow have validation errors."
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/workflow/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def get_all_outputs(self, data_only=False):
if hasattr(self.subworkflow, "workflow_outputs"):
from galaxy.managers.workflows import WorkflowContentsManager

workflow_contents_manager = WorkflowContentsManager(self.trans.app)
workflow_contents_manager = WorkflowContentsManager(self.trans.app, self.trans.app.trs_proxy)
subworkflow_dict = workflow_contents_manager._workflow_to_dict_editor(
trans=self.trans,
stored=self.subworkflow.stored_workflow,
Expand Down
3 changes: 3 additions & 0 deletions lib/galaxy/workflow/trs_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def get_version_descriptor(self, tool_id, version_id, **kwd):
)
return self._get(trs_api_url)["content"]

def get_trs_url(self, tool_id: str, version_id: str):
return f"{self._get_tool_api_endpoint(tool_id)}/versions/{version_id}"

def _quote(self, tool_id, **kwd):
if asbool(kwd.get("tool_id_b64_encoded", False)):
import base64
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy_test/api/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,7 +1234,7 @@ def test_trs_import_from_dockstore_trs_url(self):
== "#workflow/github.com/jmchilton/galaxy-workflow-dockstore-example-1/mycoolworkflow"
)
assert original_workflow.get("source_metadata").get("trs_version_id") == "master"
assert original_workflow.get("source_metadata").get("trs_server") == ""
assert not original_workflow.get("source_metadata").get("trs_server")
assert original_workflow.get("source_metadata").get("trs_url") == (
"https://dockstore.org/api/ga4gh/trs/v2/tools/"
"%23workflow%2Fgithub.com%2Fjmchilton%2Fgalaxy-workflow-dockstore-example-1%2Fmycoolworkflow/"
Expand Down Expand Up @@ -1264,7 +1264,7 @@ def test_trs_import_from_workflowhub_trs_url(self):
assert "COVID-19: variation analysis reporting" in original_workflow["name"]
assert original_workflow.get("source_metadata").get("trs_tool_id") == "109"
assert original_workflow.get("source_metadata").get("trs_version_id") == "5"
assert original_workflow.get("source_metadata").get("trs_server") == ""
assert not original_workflow.get("source_metadata").get("trs_server")
assert (
original_workflow.get("source_metadata").get("trs_url")
== "https://workflowhub.eu/ga4gh/trs/v2/tools/109/versions/5"
Expand Down
2 changes: 1 addition & 1 deletion test/unit/app/managers/test_landing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class TestLanding(BaseTestCase):

def setUp(self):
super().setUp()
self.workflow_contents_manager = WorkflowContentsManager(self.app)
self.workflow_contents_manager = WorkflowContentsManager(self.app, self.app.trs_proxy)
self.landing_manager = LandingRequestManager(
self.trans.sa_session, self.app.security, self.workflow_contents_manager
)
Expand Down
4 changes: 3 additions & 1 deletion test/unit/workflows/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,9 @@ def _output_step(step_input_def, step_output_def) -> Dict[str, Any]:
@pytest.mark.parametrize("test_case", _construct_steps_for_map_over())
def test_subworkflow_map_over_type(test_case):
trans = MockTrans()
new_steps = WorkflowContentsManager(app=trans.app)._resolve_collection_type(test_case.steps)
new_steps = WorkflowContentsManager(app=trans.app, trs_proxy=trans.app.trs_proxy)._resolve_collection_type(
test_case.steps
)
assert (
new_steps[1]["outputs"][0].get("collection_type") == test_case.expected_collection_type
), "Expected collection_type '{}' for a '{}' input module, a '{}' input and a '{}' output, got collection_type '{}' instead".format(
Expand Down
2 changes: 2 additions & 0 deletions test/unit/workflows/workflow_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
from galaxy.app_unittest_utils import galaxy_mock
from galaxy.managers.workflows import WorkflowsManager
from galaxy.model.base import transaction
from galaxy.util.bunch import Bunch
from galaxy.workflow.modules import module_factory


class MockTrans:
def __init__(self):
self.app = MockApp()
self.app.trs_proxy = Bunch()
self.sa_session = self.app.model.context
self._user = None

Expand Down

0 comments on commit 339724e

Please sign in to comment.