Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move TRS import into WorkflowContentManager #19070

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/galaxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ def __init__(self, configure_logging=True, use_converters=True, use_display_appl
InstalledRepositoryManager, InstalledRepositoryManager(self)
)
self.dynamic_tool_manager = self._register_singleton(DynamicToolManager)
self.trs_proxy = self._register_singleton(TrsProxy, TrsProxy(self.config))
self._configure_datatypes_registry(
use_converters=use_converters,
use_display_applications=use_display_applications,
Expand Down Expand Up @@ -843,7 +844,6 @@ def __init__(self, **kwargs) -> None:
# Must be initialized after job_config.
self.workflow_scheduling_manager = scheduling_manager.WorkflowSchedulingManager(self)

self.trs_proxy = self._register_singleton(TrsProxy, TrsProxy(self.config))
# We need InteractiveToolManager before the job handler starts
self.interactivetool_manager = InteractiveToolManager(self)
# Start the job manager
Expand Down
1 change: 1 addition & 0 deletions lib/galaxy/app_unittest_utils/galaxy_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __init__(self, config=None, **kwargs) -> None:
self.interactivetool_manager = Bunch(create_interactivetool=lambda *args, **kwargs: None)
self.is_job_handler = False
self.biotools_metadata_source = None
self.trs_proxy = Bunch()
set_thread_app(self)

def url_for(*args, **kwds):
Expand Down
24 changes: 4 additions & 20 deletions lib/galaxy/managers/landing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
)
from uuid import uuid4

import yaml
from pydantic import UUID4
from sqlalchemy import select

Expand All @@ -15,10 +14,7 @@
ObjectNotFound,
RequestParameterMissingException,
)
from galaxy.managers.workflows import (
WorkflowContentsManager,
WorkflowCreateOptions,
)
from galaxy.managers.workflows import WorkflowContentsManager
from galaxy.model import (
ToolLandingRequest as ToolLandingRequestModel,
WorkflowLandingRequest as WorkflowLandingRequestModel,
Expand Down Expand Up @@ -106,22 +102,10 @@ def _ensure_workflow(self, trans: ProvidesUserContext, request: WorkflowLandingR
if request.workflow_source_type == "trs_url" and isinstance(trans.app, StructuredApp):
# trans is always structured app except for unit test
assert request.workflow_source
trs_id, trs_version = request.workflow_source.rsplit("/", 1)
_, trs_id, trs_version = trans.app.trs_proxy.get_trs_id_and_version_from_trs_url(request.workflow_source)
workflow = self.workflow_contents_manager.get_workflow_by_trs_id_and_version(
self.sa_session, trs_id=trs_id, trs_version=trs_version, user_id=trans.user and trans.user.id
workflow = self.workflow_contents_manager.get_or_create_workflow_from_trs(
trans, trs_url=request.workflow_source
)
if not workflow:
data = trans.app.trs_proxy.get_version_from_trs_url(request.workflow_source)
as_dict = yaml.safe_load(data)
raw_workflow_description = self.workflow_contents_manager.normalize_workflow_format(trans, as_dict)
created_workflow = self.workflow_contents_manager.build_workflow_from_raw_description(
trans,
raw_workflow_description,
WorkflowCreateOptions(),
)
workflow = created_workflow.workflow
request.workflow_id = workflow.id
request.workflow_id = workflow.latest_workflow_id

def get_tool_landing_request(self, trans: ProvidesUserContext, uuid: UUID4) -> ToolLandingRequest:
request = self._get_claimed_tool_landing_request(trans, uuid)
Expand Down
96 changes: 71 additions & 25 deletions lib/galaxy/managers/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@
text_column_filter,
)
from galaxy.model.item_attrs import UsesAnnotations
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.schema.invocation import InvocationCancellationUserRequest
from galaxy.schema.schema import WorkflowIndexQueryPayload
from galaxy.structured_app import MinimalManagerApp
Expand Down Expand Up @@ -139,6 +138,7 @@
attach_ordered_steps,
has_cycles,
)
from galaxy.workflow.trs_proxy import TrsProxy

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -598,8 +598,10 @@ def add_serializers(self):


class WorkflowContentsManager(UsesAnnotations):
def __init__(self, app: MinimalManagerApp):

def __init__(self, app: MinimalManagerApp, trs_proxy: TrsProxy):
self.app = app
self.trs_proxy = trs_proxy
self._resource_mapper_function = get_resource_mapper_function(app)

def ensure_raw_description(self, dict_or_raw_description):
Expand Down Expand Up @@ -814,18 +816,17 @@ def _workflow_from_raw_description(
workflow.license = data.get("license")
workflow.creator_metadata = data.get("creator")

if hasattr(workflow_state_resolution_options, "archive_source"):
if workflow_state_resolution_options.archive_source:
source_metadata = {}
if workflow_state_resolution_options.archive_source == "trs_tool":
source_metadata["trs_tool_id"] = workflow_state_resolution_options.trs_tool_id
source_metadata["trs_version_id"] = workflow_state_resolution_options.trs_version_id
source_metadata["trs_server"] = workflow_state_resolution_options.trs_server
source_metadata["trs_url"] = workflow_state_resolution_options.trs_url
elif not workflow_state_resolution_options.archive_source.startswith("file://"): # URL import
source_metadata["url"] = workflow_state_resolution_options.archive_source
workflow_state_resolution_options.archive_source = None # so trs_id is not set for subworkflows
workflow.source_metadata = source_metadata # type:ignore[assignment]
if getattr(workflow_state_resolution_options, "archive_source", None):
source_metadata = {}
if workflow_state_resolution_options.archive_source in ("trs_tool", "trs_url"):
source_metadata["trs_tool_id"] = workflow_state_resolution_options.trs_tool_id
source_metadata["trs_version_id"] = workflow_state_resolution_options.trs_version_id
source_metadata["trs_server"] = workflow_state_resolution_options.trs_server
source_metadata["trs_url"] = workflow_state_resolution_options.trs_url
elif not workflow_state_resolution_options.archive_source.startswith("file://"): # URL import
source_metadata["url"] = workflow_state_resolution_options.archive_source
workflow_state_resolution_options.archive_source = None # so trs_id is not set for subworkflows
workflow.source_metadata = source_metadata

# Assume no errors until we find a step that has some
workflow.has_errors = False
Expand Down Expand Up @@ -2016,9 +2017,54 @@ def get_all_tools(self, workflow):
tools.extend(self.get_all_tools(step.subworkflow))
return tools

def get_or_create_workflow_from_trs(
self,
trans: ProvidesUserContext,
trs_url: Optional[str],
trs_id: Optional[str] = None,
trs_version: Optional[str] = None,
trs_server: Optional[str] = None,
):
user_id = trans.user and trans.user.id
assert user_id, "Cannot create workflow for anonymous user"
if not trs_url:
assert trs_server and trs_id and trs_version, "trs_url or trs_server, trs_version and trs_id must be passed"
server = self.trs_proxy.get_server(trs_server)
trs_url = server.get_trs_url(trs_id, trs_version)
else:
_, trs_id, trs_version = self.trs_proxy.get_trs_id_and_version_from_trs_url(trs_url=trs_url)
assert trs_id and trs_version and trs_url

workflow = self.get_workflow_by_trs_id_and_version(trs_id=trs_id, trs_version=trs_version, user_id=user_id)
if not workflow:
workflow = self.create_workflow_from_trs_url(trans, trs_url, trs_server)
return workflow

def create_workflow_from_trs_url(
self, trans: ProvidesUserContext, trs_url: str, trs_server: Optional[str] = None
) -> StoredWorkflow:
_, trs_tool_id, trs_version_id = self.trs_proxy.get_trs_id_and_version_from_trs_url(trs_url=trs_url)
data = self.trs_proxy.get_version_from_trs_url(trs_url)
as_dict = yaml.safe_load(data)
raw_workflow_description = self.normalize_workflow_format(trans, as_dict)
created_workflow = self.build_workflow_from_raw_description(
trans,
raw_workflow_description,
WorkflowCreateOptions(
trs_tool_id=trs_tool_id,
trs_version_id=trs_version_id,
trs_url=trs_url,
trs_server=trs_server,
archive_source="trs_url",
),
)
return created_workflow.stored_workflow

def get_workflow_by_trs_id_and_version(
self, sa_session: galaxy_scoped_session, trs_id: str, trs_version: str, user_id: Optional[int] = None
) -> Optional[model.Workflow]:
self, trs_id: str, trs_version: str, user_id: Optional[int] = None
) -> Optional[model.StoredWorkflow]:
sa_session = self.app.model.session

def to_json(column, keys: List[str]):
assert sa_session.bind
if sa_session.bind.dialect.name == "postgresql":
Expand All @@ -2028,12 +2074,12 @@ def to_json(column, keys: List[str]):
return cast.astext
else:
for key in keys:
column = column.__getitem__(key)
column = func.json_extract(column, f"$.{key}")
return column

stmnt = (
select(model.Workflow)
.join(model.StoredWorkflow, model.Workflow.stored_workflow_id == model.StoredWorkflow.id)
select(model.StoredWorkflow)
.join(model.Workflow, model.Workflow.id == model.StoredWorkflow.latest_workflow_id)
.filter(
and_(
to_json(model.Workflow.source_metadata, ["trs_tool_id"]) == trs_id,
Expand All @@ -2043,7 +2089,7 @@ def to_json(column, keys: List[str]):
)
if user_id:
stmnt = stmnt.filter(
model.StoredWorkflow.user_id == user_id, model.StoredWorkflow.latest_workflow_id == model.Workflow.id
model.StoredWorkflow.user_id == user_id,
)
else:
stmnt = stmnt.filter(model.StoredWorkflow.importable == true())
Expand Down Expand Up @@ -2104,11 +2150,11 @@ class WorkflowCreateOptions(WorkflowStateResolutionOptions):
shed_tool_conf: Optional[str] = None

# for workflows imported by archive source
archive_source: Optional[str] = ""
trs_tool_id: str = ""
trs_version_id: str = ""
trs_server: str = ""
trs_url: str = ""
archive_source: Optional[str] = None
trs_tool_id: Optional[str] = None
trs_version_id: Optional[str] = None
trs_server: Optional[str] = None
trs_url: Optional[str] = None

@property
def is_importable(self):
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7817,7 +7817,7 @@ class Workflow(Base, Dictifiable, RepresentById):
reports_config: Mapped[Optional[bytes]] = mapped_column(JSONType)
creator_metadata: Mapped[Optional[bytes]] = mapped_column(JSONType)
license: Mapped[Optional[str]] = mapped_column(TEXT)
source_metadata: Mapped[Optional[bytes]] = mapped_column(JSONType)
source_metadata: Mapped[Optional[Dict[str, str]]] = mapped_column(JSONType)
uuid: Mapped[Optional[Union[UUID, str]]] = mapped_column(UUIDType)

steps = relationship(
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/webapps/base/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,7 @@ def _workflow_to_dict(self, trans, stored):
"""
Converts a workflow to a dict of attributes suitable for exporting.
"""
workflow_contents_manager = workflows.WorkflowContentsManager(self.app)
workflow_contents_manager = workflows.WorkflowContentsManager(self.app, self.app.trs_proxy)
return workflow_contents_manager.workflow_to_dict(
trans,
stored,
Expand Down
39 changes: 13 additions & 26 deletions lib/galaxy/webapps/galaxy/api/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,29 +245,14 @@ def create(self, trans: GalaxyWebTransaction, payload=None, **kwd):
payload["workflow"] = workflow_src
return self.__api_import_new_workflow(trans, payload, **kwd)
elif archive_source == "trs_tool":
server = None
trs_tool_id = None
trs_version_id = None
import_source = None
if "trs_url" in payload:
parts = self.app.trs_proxy.match_url(
payload["trs_url"], trans.app.config.fetch_url_allowlist_ips
)
if parts:
server = self.app.trs_proxy.server_from_url(parts["trs_base_url"])
trs_tool_id = parts["tool_id"]
trs_version_id = parts["version_id"]
payload["trs_tool_id"] = trs_tool_id
payload["trs_version_id"] = trs_version_id
else:
raise exceptions.RequestParameterInvalidException(f"Invalid TRS URL {payload['trs_url']}.")
else:
trs_server = payload.get("trs_server")
server = self.app.trs_proxy.get_server(trs_server)
trs_tool_id = payload.get("trs_tool_id")
trs_version_id = payload.get("trs_version_id")

archive_data = server.get_version_descriptor(trs_tool_id, trs_version_id)
workflow = self.workflow_contents_manager.get_or_create_workflow_from_trs(
trans,
trs_url=payload.get("trs_url"),
trs_id=payload.get("trs_tool_id"),
trs_version=payload.get("trs_version_id"),
trs_server=payload.get("trs_server"),
)
return self.__api_import_response(workflow)
else:
try:
archive_data = stream_url_to_str(
Expand Down Expand Up @@ -602,13 +587,15 @@ def __api_import_from_archive(self, trans: GalaxyWebTransaction, archive_data, s
workflow, missing_tool_tups = self._workflow_from_dict(
trans, raw_workflow_description, workflow_create_options, source=source
)
workflow_id = workflow.id
workflow = workflow.latest_workflow
return self.__api_import_response(workflow)

def __api_import_response(self, stored_workflow: model.StoredWorkflow):
workflow = stored_workflow.latest_workflow
assert workflow
response = {
"message": f"Workflow '{workflow.name}' imported successfully.",
"status": "success",
"id": trans.security.encode_id(workflow_id),
"id": self.app.security.encode_id(stored_workflow.id),
}
if workflow.has_errors:
response["message"] = "Imported, but some steps in this workflow have validation errors."
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/workflow/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def get_all_outputs(self, data_only=False):
if hasattr(self.subworkflow, "workflow_outputs"):
from galaxy.managers.workflows import WorkflowContentsManager

workflow_contents_manager = WorkflowContentsManager(self.trans.app)
workflow_contents_manager = WorkflowContentsManager(self.trans.app, self.trans.app.trs_proxy)
subworkflow_dict = workflow_contents_manager._workflow_to_dict_editor(
trans=self.trans,
stored=self.subworkflow.stored_workflow,
Expand Down
3 changes: 3 additions & 0 deletions lib/galaxy/workflow/trs_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def get_version_descriptor(self, tool_id, version_id, **kwd):
)
return self._get(trs_api_url)["content"]

def get_trs_url(self, tool_id: str, version_id: str):
return f"{self._get_tool_api_endpoint(tool_id)}/versions/{version_id}"

def _quote(self, tool_id, **kwd):
if asbool(kwd.get("tool_id_b64_encoded", False)):
import base64
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy_test/api/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,7 +1234,7 @@ def test_trs_import_from_dockstore_trs_url(self):
== "#workflow/github.com/jmchilton/galaxy-workflow-dockstore-example-1/mycoolworkflow"
)
assert original_workflow.get("source_metadata").get("trs_version_id") == "master"
assert original_workflow.get("source_metadata").get("trs_server") == ""
assert not original_workflow.get("source_metadata").get("trs_server")
assert original_workflow.get("source_metadata").get("trs_url") == (
"https://dockstore.org/api/ga4gh/trs/v2/tools/"
"%23workflow%2Fgithub.com%2Fjmchilton%2Fgalaxy-workflow-dockstore-example-1%2Fmycoolworkflow/"
Expand Down Expand Up @@ -1264,7 +1264,7 @@ def test_trs_import_from_workflowhub_trs_url(self):
assert "COVID-19: variation analysis reporting" in original_workflow["name"]
assert original_workflow.get("source_metadata").get("trs_tool_id") == "109"
assert original_workflow.get("source_metadata").get("trs_version_id") == "5"
assert original_workflow.get("source_metadata").get("trs_server") == ""
assert not original_workflow.get("source_metadata").get("trs_server")
assert (
original_workflow.get("source_metadata").get("trs_url")
== "https://workflowhub.eu/ga4gh/trs/v2/tools/109/versions/5"
Expand Down
2 changes: 1 addition & 1 deletion test/unit/app/managers/test_landing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class TestLanding(BaseTestCase):

def setUp(self):
super().setUp()
self.workflow_contents_manager = WorkflowContentsManager(self.app)
self.workflow_contents_manager = WorkflowContentsManager(self.app, self.app.trs_proxy)
self.landing_manager = LandingRequestManager(
self.trans.sa_session, self.app.security, self.workflow_contents_manager
)
Expand Down
4 changes: 3 additions & 1 deletion test/unit/workflows/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,9 @@ def _output_step(step_input_def, step_output_def) -> Dict[str, Any]:
@pytest.mark.parametrize("test_case", _construct_steps_for_map_over())
def test_subworkflow_map_over_type(test_case):
trans = MockTrans()
new_steps = WorkflowContentsManager(app=trans.app)._resolve_collection_type(test_case.steps)
new_steps = WorkflowContentsManager(app=trans.app, trs_proxy=trans.app.trs_proxy)._resolve_collection_type(
test_case.steps
)
assert (
new_steps[1]["outputs"][0].get("collection_type") == test_case.expected_collection_type
), "Expected collection_type '{}' for a '{}' input module, a '{}' input and a '{}' output, got collection_type '{}' instead".format(
Expand Down
2 changes: 2 additions & 0 deletions test/unit/workflows/workflow_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
from galaxy.app_unittest_utils import galaxy_mock
from galaxy.managers.workflows import WorkflowsManager
from galaxy.model.base import transaction
from galaxy.util.bunch import Bunch
from galaxy.workflow.modules import module_factory


class MockTrans:
def __init__(self):
self.app = MockApp()
self.app.trs_proxy = Bunch()
self.sa_session = self.app.model.context
self._user = None

Expand Down
Loading