From 48ed0682dc0451430634e28d149812c7ce5671af Mon Sep 17 00:00:00 2001 From: mvdbeek Date: Tue, 22 Oct 2024 13:16:22 -0400 Subject: [PATCH] Move TRS import into WorkflowContentManager which fixes missing source metadata when importing a workflow through a landing page. --- lib/galaxy/app.py | 2 +- lib/galaxy/app_unittest_utils/galaxy_mock.py | 1 + lib/galaxy/managers/landing.py | 24 +---- lib/galaxy/managers/workflows.py | 96 +++++++++++++++----- lib/galaxy/model/__init__.py | 2 +- lib/galaxy/webapps/base/controller.py | 2 +- lib/galaxy/webapps/galaxy/api/workflows.py | 39 +++----- lib/galaxy/workflow/modules.py | 2 +- lib/galaxy/workflow/trs_proxy.py | 3 + lib/galaxy_test/api/test_workflows.py | 4 +- test/unit/app/managers/test_landing.py | 2 +- test/unit/workflows/test_modules.py | 4 +- test/unit/workflows/workflow_support.py | 2 + 13 files changed, 104 insertions(+), 79 deletions(-) diff --git a/lib/galaxy/app.py b/lib/galaxy/app.py index fcbd1a1b7cb0..285e2eed8a1c 100644 --- a/lib/galaxy/app.py +++ b/lib/galaxy/app.py @@ -666,6 +666,7 @@ def __init__(self, configure_logging=True, use_converters=True, use_display_appl InstalledRepositoryManager, InstalledRepositoryManager(self) ) self.dynamic_tool_manager = self._register_singleton(DynamicToolManager) + self.trs_proxy = self._register_singleton(TrsProxy, TrsProxy(self.config)) self._configure_datatypes_registry( use_converters=use_converters, use_display_applications=use_display_applications, @@ -843,7 +844,6 @@ def __init__(self, **kwargs) -> None: # Must be initialized after job_config. self.workflow_scheduling_manager = scheduling_manager.WorkflowSchedulingManager(self) - self.trs_proxy = self._register_singleton(TrsProxy, TrsProxy(self.config)) # We need InteractiveToolManager before the job handler starts self.interactivetool_manager = InteractiveToolManager(self) # Start the job manager diff --git a/lib/galaxy/app_unittest_utils/galaxy_mock.py b/lib/galaxy/app_unittest_utils/galaxy_mock.py index f709f213ead4..ba89f2453b3e 100644 --- a/lib/galaxy/app_unittest_utils/galaxy_mock.py +++ b/lib/galaxy/app_unittest_utils/galaxy_mock.py @@ -156,6 +156,7 @@ def __init__(self, config=None, **kwargs) -> None: self.interactivetool_manager = Bunch(create_interactivetool=lambda *args, **kwargs: None) self.is_job_handler = False self.biotools_metadata_source = None + self.trs_proxy = Bunch() set_thread_app(self) def url_for(*args, **kwds): diff --git a/lib/galaxy/managers/landing.py b/lib/galaxy/managers/landing.py index 9facbac703c6..5012d2280134 100644 --- a/lib/galaxy/managers/landing.py +++ b/lib/galaxy/managers/landing.py @@ -4,7 +4,6 @@ ) from uuid import uuid4 -import yaml from pydantic import UUID4 from sqlalchemy import select @@ -15,10 +14,7 @@ ObjectNotFound, RequestParameterMissingException, ) -from galaxy.managers.workflows import ( - WorkflowContentsManager, - WorkflowCreateOptions, -) +from galaxy.managers.workflows import WorkflowContentsManager from galaxy.model import ( ToolLandingRequest as ToolLandingRequestModel, WorkflowLandingRequest as WorkflowLandingRequestModel, @@ -106,22 +102,10 @@ def _ensure_workflow(self, trans: ProvidesUserContext, request: WorkflowLandingR if request.workflow_source_type == "trs_url" and isinstance(trans.app, StructuredApp): # trans is always structured app except for unit test assert request.workflow_source - trs_id, trs_version = request.workflow_source.rsplit("/", 1) - _, trs_id, trs_version = trans.app.trs_proxy.get_trs_id_and_version_from_trs_url(request.workflow_source) - workflow = self.workflow_contents_manager.get_workflow_by_trs_id_and_version( - self.sa_session, trs_id=trs_id, trs_version=trs_version, user_id=trans.user and trans.user.id + workflow = self.workflow_contents_manager.get_or_create_workflow_from_trs( + trans, trs_url=request.workflow_source ) - if not workflow: - data = trans.app.trs_proxy.get_version_from_trs_url(request.workflow_source) - as_dict = yaml.safe_load(data) - raw_workflow_description = self.workflow_contents_manager.normalize_workflow_format(trans, as_dict) - created_workflow = self.workflow_contents_manager.build_workflow_from_raw_description( - trans, - raw_workflow_description, - WorkflowCreateOptions(), - ) - workflow = created_workflow.workflow - request.workflow_id = workflow.id + request.workflow_id = workflow.latest_workflow_id def get_tool_landing_request(self, trans: ProvidesUserContext, uuid: UUID4) -> ToolLandingRequest: request = self._get_claimed_tool_landing_request(trans, uuid) diff --git a/lib/galaxy/managers/workflows.py b/lib/galaxy/managers/workflows.py index d69e8afcab70..48be30372d67 100644 --- a/lib/galaxy/managers/workflows.py +++ b/lib/galaxy/managers/workflows.py @@ -89,7 +89,6 @@ text_column_filter, ) from galaxy.model.item_attrs import UsesAnnotations -from galaxy.model.scoped_session import galaxy_scoped_session from galaxy.schema.invocation import InvocationCancellationUserRequest from galaxy.schema.schema import WorkflowIndexQueryPayload from galaxy.structured_app import MinimalManagerApp @@ -139,6 +138,7 @@ attach_ordered_steps, has_cycles, ) +from galaxy.workflow.trs_proxy import TrsProxy log = logging.getLogger(__name__) @@ -598,8 +598,10 @@ def add_serializers(self): class WorkflowContentsManager(UsesAnnotations): - def __init__(self, app: MinimalManagerApp): + + def __init__(self, app: MinimalManagerApp, trs_proxy: TrsProxy): self.app = app + self.trs_proxy = trs_proxy self._resource_mapper_function = get_resource_mapper_function(app) def ensure_raw_description(self, dict_or_raw_description): @@ -814,18 +816,17 @@ def _workflow_from_raw_description( workflow.license = data.get("license") workflow.creator_metadata = data.get("creator") - if hasattr(workflow_state_resolution_options, "archive_source"): - if workflow_state_resolution_options.archive_source: - source_metadata = {} - if workflow_state_resolution_options.archive_source == "trs_tool": - source_metadata["trs_tool_id"] = workflow_state_resolution_options.trs_tool_id - source_metadata["trs_version_id"] = workflow_state_resolution_options.trs_version_id - source_metadata["trs_server"] = workflow_state_resolution_options.trs_server - source_metadata["trs_url"] = workflow_state_resolution_options.trs_url - elif not workflow_state_resolution_options.archive_source.startswith("file://"): # URL import - source_metadata["url"] = workflow_state_resolution_options.archive_source - workflow_state_resolution_options.archive_source = None # so trs_id is not set for subworkflows - workflow.source_metadata = source_metadata # type:ignore[assignment] + if getattr(workflow_state_resolution_options, "archive_source", None): + source_metadata = {} + if workflow_state_resolution_options.archive_source in ("trs_tool", "trs_url"): + source_metadata["trs_tool_id"] = workflow_state_resolution_options.trs_tool_id + source_metadata["trs_version_id"] = workflow_state_resolution_options.trs_version_id + source_metadata["trs_server"] = workflow_state_resolution_options.trs_server + source_metadata["trs_url"] = workflow_state_resolution_options.trs_url + elif not workflow_state_resolution_options.archive_source.startswith("file://"): # URL import + source_metadata["url"] = workflow_state_resolution_options.archive_source + workflow_state_resolution_options.archive_source = None # so trs_id is not set for subworkflows + workflow.source_metadata = source_metadata # Assume no errors until we find a step that has some workflow.has_errors = False @@ -2016,9 +2017,54 @@ def get_all_tools(self, workflow): tools.extend(self.get_all_tools(step.subworkflow)) return tools + def get_or_create_workflow_from_trs( + self, + trans: ProvidesUserContext, + trs_url: Optional[str], + trs_id: Optional[str] = None, + trs_version: Optional[str] = None, + trs_server: Optional[str] = None, + ): + user_id = trans.user and trans.user.id + assert user_id, "Cannot create workflow for anonymous user" + if not trs_url: + assert trs_server and trs_id and trs_version, "trs_url or trs_server, trs_version and trs_id must be passed" + server = self.trs_proxy.get_server(trs_server) + trs_url = server.get_trs_url(trs_id, trs_version) + else: + _, trs_id, trs_version = self.trs_proxy.get_trs_id_and_version_from_trs_url(trs_url=trs_url) + assert trs_id and trs_version and trs_url + + workflow = self.get_workflow_by_trs_id_and_version(trs_id=trs_id, trs_version=trs_version, user_id=user_id) + if not workflow: + workflow = self.create_workflow_from_trs_url(trans, trs_url, trs_server) + return workflow + + def create_workflow_from_trs_url( + self, trans: ProvidesUserContext, trs_url: str, trs_server: Optional[str] = None + ) -> StoredWorkflow: + _, trs_tool_id, trs_version_id = self.trs_proxy.get_trs_id_and_version_from_trs_url(trs_url=trs_url) + data = self.trs_proxy.get_version_from_trs_url(trs_url) + as_dict = yaml.safe_load(data) + raw_workflow_description = self.normalize_workflow_format(trans, as_dict) + created_workflow = self.build_workflow_from_raw_description( + trans, + raw_workflow_description, + WorkflowCreateOptions( + trs_tool_id=trs_tool_id, + trs_version_id=trs_version_id, + trs_url=trs_url, + trs_server=trs_server, + archive_source="trs_url", + ), + ) + return created_workflow.stored_workflow + def get_workflow_by_trs_id_and_version( - self, sa_session: galaxy_scoped_session, trs_id: str, trs_version: str, user_id: Optional[int] = None - ) -> Optional[model.Workflow]: + self, trs_id: str, trs_version: str, user_id: Optional[int] = None + ) -> Optional[model.StoredWorkflow]: + sa_session = self.app.model.session + def to_json(column, keys: List[str]): assert sa_session.bind if sa_session.bind.dialect.name == "postgresql": @@ -2028,12 +2074,12 @@ def to_json(column, keys: List[str]): return cast.astext else: for key in keys: - column = column.__getitem__(key) + column = func.json_extract(column, f"$.{key}") return column stmnt = ( - select(model.Workflow) - .join(model.StoredWorkflow, model.Workflow.stored_workflow_id == model.StoredWorkflow.id) + select(model.StoredWorkflow) + .join(model.Workflow, model.Workflow.id == model.StoredWorkflow.latest_workflow_id) .filter( and_( to_json(model.Workflow.source_metadata, ["trs_tool_id"]) == trs_id, @@ -2043,7 +2089,7 @@ def to_json(column, keys: List[str]): ) if user_id: stmnt = stmnt.filter( - model.StoredWorkflow.user_id == user_id, model.StoredWorkflow.latest_workflow_id == model.Workflow.id + model.StoredWorkflow.user_id == user_id, ) else: stmnt = stmnt.filter(model.StoredWorkflow.importable == true()) @@ -2104,11 +2150,11 @@ class WorkflowCreateOptions(WorkflowStateResolutionOptions): shed_tool_conf: Optional[str] = None # for workflows imported by archive source - archive_source: Optional[str] = "" - trs_tool_id: str = "" - trs_version_id: str = "" - trs_server: str = "" - trs_url: str = "" + archive_source: Optional[str] = None + trs_tool_id: Optional[str] = None + trs_version_id: Optional[str] = None + trs_server: Optional[str] = None + trs_url: Optional[str] = None @property def is_importable(self): diff --git a/lib/galaxy/model/__init__.py b/lib/galaxy/model/__init__.py index e97cf0438c2a..ffa93ecc7225 100644 --- a/lib/galaxy/model/__init__.py +++ b/lib/galaxy/model/__init__.py @@ -7817,7 +7817,7 @@ class Workflow(Base, Dictifiable, RepresentById): reports_config: Mapped[Optional[bytes]] = mapped_column(JSONType) creator_metadata: Mapped[Optional[bytes]] = mapped_column(JSONType) license: Mapped[Optional[str]] = mapped_column(TEXT) - source_metadata: Mapped[Optional[bytes]] = mapped_column(JSONType) + source_metadata: Mapped[Optional[Dict[str, str]]] = mapped_column(JSONType) uuid: Mapped[Optional[Union[UUID, str]]] = mapped_column(UUIDType) steps = relationship( diff --git a/lib/galaxy/webapps/base/controller.py b/lib/galaxy/webapps/base/controller.py index 59bb189d1952..9fdfa1079f6e 100644 --- a/lib/galaxy/webapps/base/controller.py +++ b/lib/galaxy/webapps/base/controller.py @@ -1003,7 +1003,7 @@ def _workflow_to_dict(self, trans, stored): """ Converts a workflow to a dict of attributes suitable for exporting. """ - workflow_contents_manager = workflows.WorkflowContentsManager(self.app) + workflow_contents_manager = workflows.WorkflowContentsManager(self.app, self.app.trs_proxy) return workflow_contents_manager.workflow_to_dict( trans, stored, diff --git a/lib/galaxy/webapps/galaxy/api/workflows.py b/lib/galaxy/webapps/galaxy/api/workflows.py index 6f94550e2134..1a8bafd38f90 100644 --- a/lib/galaxy/webapps/galaxy/api/workflows.py +++ b/lib/galaxy/webapps/galaxy/api/workflows.py @@ -245,29 +245,14 @@ def create(self, trans: GalaxyWebTransaction, payload=None, **kwd): payload["workflow"] = workflow_src return self.__api_import_new_workflow(trans, payload, **kwd) elif archive_source == "trs_tool": - server = None - trs_tool_id = None - trs_version_id = None - import_source = None - if "trs_url" in payload: - parts = self.app.trs_proxy.match_url( - payload["trs_url"], trans.app.config.fetch_url_allowlist_ips - ) - if parts: - server = self.app.trs_proxy.server_from_url(parts["trs_base_url"]) - trs_tool_id = parts["tool_id"] - trs_version_id = parts["version_id"] - payload["trs_tool_id"] = trs_tool_id - payload["trs_version_id"] = trs_version_id - else: - raise exceptions.RequestParameterInvalidException(f"Invalid TRS URL {payload['trs_url']}.") - else: - trs_server = payload.get("trs_server") - server = self.app.trs_proxy.get_server(trs_server) - trs_tool_id = payload.get("trs_tool_id") - trs_version_id = payload.get("trs_version_id") - - archive_data = server.get_version_descriptor(trs_tool_id, trs_version_id) + workflow = self.workflow_contents_manager.get_or_create_workflow_from_trs( + trans, + trs_url=payload.get("trs_url"), + trs_id=payload.get("trs_tool_id"), + trs_version=payload.get("trs_version_id"), + trs_server=payload.get("trs_server"), + ) + return self.__api_import_response(workflow) else: try: archive_data = stream_url_to_str( @@ -602,13 +587,15 @@ def __api_import_from_archive(self, trans: GalaxyWebTransaction, archive_data, s workflow, missing_tool_tups = self._workflow_from_dict( trans, raw_workflow_description, workflow_create_options, source=source ) - workflow_id = workflow.id - workflow = workflow.latest_workflow + return self.__api_import_response(workflow) + def __api_import_response(self, stored_workflow: model.StoredWorkflow): + workflow = stored_workflow.latest_workflow + assert workflow response = { "message": f"Workflow '{workflow.name}' imported successfully.", "status": "success", - "id": trans.security.encode_id(workflow_id), + "id": self.app.security.encode_id(stored_workflow.id), } if workflow.has_errors: response["message"] = "Imported, but some steps in this workflow have validation errors." diff --git a/lib/galaxy/workflow/modules.py b/lib/galaxy/workflow/modules.py index 448d07ecbe90..f210755d0482 100644 --- a/lib/galaxy/workflow/modules.py +++ b/lib/galaxy/workflow/modules.py @@ -715,7 +715,7 @@ def get_all_outputs(self, data_only=False): if hasattr(self.subworkflow, "workflow_outputs"): from galaxy.managers.workflows import WorkflowContentsManager - workflow_contents_manager = WorkflowContentsManager(self.trans.app) + workflow_contents_manager = WorkflowContentsManager(self.trans.app, self.trans.app.trs_proxy) subworkflow_dict = workflow_contents_manager._workflow_to_dict_editor( trans=self.trans, stored=self.subworkflow.stored_workflow, diff --git a/lib/galaxy/workflow/trs_proxy.py b/lib/galaxy/workflow/trs_proxy.py index 550238ffa229..bc4875b95f4b 100644 --- a/lib/galaxy/workflow/trs_proxy.py +++ b/lib/galaxy/workflow/trs_proxy.py @@ -140,6 +140,9 @@ def get_version_descriptor(self, tool_id, version_id, **kwd): ) return self._get(trs_api_url)["content"] + def get_trs_url(self, tool_id: str, version_id: str): + return f"{self._get_tool_api_endpoint(tool_id)}/versions/{version_id}" + def _quote(self, tool_id, **kwd): if asbool(kwd.get("tool_id_b64_encoded", False)): import base64 diff --git a/lib/galaxy_test/api/test_workflows.py b/lib/galaxy_test/api/test_workflows.py index 8408837df347..ec44b1e5a5ff 100644 --- a/lib/galaxy_test/api/test_workflows.py +++ b/lib/galaxy_test/api/test_workflows.py @@ -1234,7 +1234,7 @@ def test_trs_import_from_dockstore_trs_url(self): == "#workflow/github.com/jmchilton/galaxy-workflow-dockstore-example-1/mycoolworkflow" ) assert original_workflow.get("source_metadata").get("trs_version_id") == "master" - assert original_workflow.get("source_metadata").get("trs_server") == "" + assert not original_workflow.get("source_metadata").get("trs_server") assert original_workflow.get("source_metadata").get("trs_url") == ( "https://dockstore.org/api/ga4gh/trs/v2/tools/" "%23workflow%2Fgithub.com%2Fjmchilton%2Fgalaxy-workflow-dockstore-example-1%2Fmycoolworkflow/" @@ -1264,7 +1264,7 @@ def test_trs_import_from_workflowhub_trs_url(self): assert "COVID-19: variation analysis reporting" in original_workflow["name"] assert original_workflow.get("source_metadata").get("trs_tool_id") == "109" assert original_workflow.get("source_metadata").get("trs_version_id") == "5" - assert original_workflow.get("source_metadata").get("trs_server") == "" + assert not original_workflow.get("source_metadata").get("trs_server") assert ( original_workflow.get("source_metadata").get("trs_url") == "https://workflowhub.eu/ga4gh/trs/v2/tools/109/versions/5" diff --git a/test/unit/app/managers/test_landing.py b/test/unit/app/managers/test_landing.py index 46762c110a6d..f2ccb059b4bf 100644 --- a/test/unit/app/managers/test_landing.py +++ b/test/unit/app/managers/test_landing.py @@ -41,7 +41,7 @@ class TestLanding(BaseTestCase): def setUp(self): super().setUp() - self.workflow_contents_manager = WorkflowContentsManager(self.app) + self.workflow_contents_manager = WorkflowContentsManager(self.app, self.app.trs_proxy) self.landing_manager = LandingRequestManager( self.trans.sa_session, self.app.security, self.workflow_contents_manager ) diff --git a/test/unit/workflows/test_modules.py b/test/unit/workflows/test_modules.py index 6a1b5b653ce7..8052c6b11950 100644 --- a/test/unit/workflows/test_modules.py +++ b/test/unit/workflows/test_modules.py @@ -428,7 +428,9 @@ def _output_step(step_input_def, step_output_def) -> Dict[str, Any]: @pytest.mark.parametrize("test_case", _construct_steps_for_map_over()) def test_subworkflow_map_over_type(test_case): trans = MockTrans() - new_steps = WorkflowContentsManager(app=trans.app)._resolve_collection_type(test_case.steps) + new_steps = WorkflowContentsManager(app=trans.app, trs_proxy=trans.app.trs_proxy)._resolve_collection_type( + test_case.steps + ) assert ( new_steps[1]["outputs"][0].get("collection_type") == test_case.expected_collection_type ), "Expected collection_type '{}' for a '{}' input module, a '{}' input and a '{}' output, got collection_type '{}' instead".format( diff --git a/test/unit/workflows/workflow_support.py b/test/unit/workflows/workflow_support.py index 05064e722fac..a3b8374637df 100644 --- a/test/unit/workflows/workflow_support.py +++ b/test/unit/workflows/workflow_support.py @@ -6,12 +6,14 @@ from galaxy.app_unittest_utils import galaxy_mock from galaxy.managers.workflows import WorkflowsManager from galaxy.model.base import transaction +from galaxy.util.bunch import Bunch from galaxy.workflow.modules import module_factory class MockTrans: def __init__(self): self.app = MockApp() + self.app.trs_proxy = Bunch() self.sa_session = self.app.model.context self._user = None