Skip to content

Commit

Permalink
Improved tool API.
Browse files Browse the repository at this point in the history
  • Loading branch information
jmchilton committed Jul 10, 2024
1 parent ad8d26a commit ec8824f
Show file tree
Hide file tree
Showing 10 changed files with 210 additions and 84 deletions.
6 changes: 3 additions & 3 deletions lib/galaxy/config/schemas/tool_shed_config_schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ mapping:
the repositories and tools within the Tool Shed given that you specify
the following two config options.
tool_state_cache_dir:
model_cache_dir:
type: str
default: database/tool_state_cache
default: database/model_cache
required: false
desc: |
Cache directory for tool state.
Cache directory for Pydantic model objects.
repo_name_boost:
type: float
Expand Down
64 changes: 64 additions & 0 deletions lib/tool_shed/managers/model_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import json
import os
from typing import (
Any,
Dict,
Optional,
Type,
TypeVar,
)

from pydantic import BaseModel

from galaxy.util.hash_util import md5_hash_str

RAW_CACHED_JSON = Dict[str, Any]


def hash_model(model_class: Type[BaseModel]) -> str:
return md5_hash_str(json.dumps(model_class.model_json_schema()))


MODEL_HASHES: Dict[Type[BaseModel], str] = {}


M = TypeVar("M", bound=BaseModel)


def ensure_model_has_hash(model_class: Type[BaseModel]) -> None:
if model_class not in MODEL_HASHES:
MODEL_HASHES[model_class] = hash_model(model_class)


class ModelCache:
_cache_directory: str

def __init__(self, cache_directory: str):
if not os.path.exists(cache_directory):
os.makedirs(cache_directory)
self._cache_directory = cache_directory

def _cache_target(self, model_class: Type[M], tool_id: str, tool_version: str) -> str:
ensure_model_has_hash(model_class)
# consider breaking this into multiple directories...
cache_target = os.path.join(self._cache_directory, MODEL_HASHES[model_class], tool_id, tool_version)
return cache_target

def get_cache_entry_for(self, model_class: Type[M], tool_id: str, tool_version: str) -> Optional[M]:
cache_target = self._cache_target(model_class, tool_id, tool_version)
if not os.path.exists(cache_target):
return None
with open(cache_target) as f:
return model_class.model_validate(json.load(f))

def has_cached_entry_for(self, model_class: Type[M], tool_id: str, tool_version: str) -> bool:
cache_target = self._cache_target(model_class, tool_id, tool_version)
return os.path.exists(cache_target)

def insert_cache_entry_for(self, model_object: M, tool_id: str, tool_version: str) -> None:
cache_target = self._cache_target(model_object.__class__, tool_id, tool_version)
parent_directory = os.path.dirname(cache_target)
if not os.path.exists(parent_directory):
os.makedirs(parent_directory)
with open(cache_target, "w") as f:
json.dump(model_object.dict(), f)
42 changes: 0 additions & 42 deletions lib/tool_shed/managers/tool_state_cache.py

This file was deleted.

80 changes: 66 additions & 14 deletions lib/tool_shed/managers/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
Tuple,
)

from pydantic import BaseModel

from galaxy import exceptions
from galaxy.exceptions import (
InternalServerError,
Expand All @@ -21,13 +23,16 @@
)
from galaxy.tool_util.parameters import (
input_models_for_tool_source,
tool_parameter_bundle_from_json,
ToolParameterBundleModel,
ToolParameterT,
)
from galaxy.tool_util.parser import (
get_tool_source,
ToolSource,
)
from galaxy.tool_util.parser.interface import (
Citation,
XrefDict,
)
from galaxy.tools.stock import stock_tool_sources
from tool_shed.context import (
ProvidesRepositoriesContext,
Expand All @@ -41,6 +46,53 @@
STOCK_TOOL_SOURCES: Optional[Dict[str, Dict[str, ToolSource]]] = None


# parse the tool source with galaxy.util abstractions to provide a bit richer
# information about the tool than older tool shed abstractions.
class ParsedTool(BaseModel):
id: str
version: Optional[str]
name: str
description: Optional[str]
inputs: List[ToolParameterT]
citations: List[Citation]
license: Optional[str]
profile: Optional[str]
edam_operations: List[str]
edam_topics: List[str]
xrefs: List[XrefDict]
help: Optional[str]


def _parse_tool(tool_source: ToolSource) -> ParsedTool:
id = tool_source.parse_id()
version = tool_source.parse_version()
name = tool_source.parse_name()
description = tool_source.parse_description()
inputs = input_models_for_tool_source(tool_source).input_models
citations = tool_source.parse_citations()
license = tool_source.parse_license()
profile = tool_source.parse_profile()
edam_operations = tool_source.parse_edam_operations()
edam_topics = tool_source.parse_edam_topics()
xrefs = tool_source.parse_xrefs()
help = tool_source.parse_help()

return ParsedTool(
id=id,
version=version,
name=name,
description=description,
profile=profile,
inputs=inputs,
license=license,
citations=citations,
edam_operations=edam_operations,
edam_topics=edam_topics,
xrefs=xrefs,
help=help,
)


def search(trans: SessionRequestContext, q: str, page: int = 1, page_size: int = 10) -> dict:
"""
Perform the search over TS tools index.
Expand Down Expand Up @@ -97,23 +149,23 @@ def get_repository_metadata_tool_dict(
raise ObjectNotFound()


def tool_input_models_cached_for(
def parsed_tool_model_cached_for(
trans: ProvidesRepositoriesContext, trs_tool_id: str, tool_version: str, repository_clone_url: Optional[str] = None
) -> ToolParameterBundleModel:
tool_state_cache = trans.app.tool_state_cache
raw_json = tool_state_cache.get_cache_entry_for(trs_tool_id, tool_version)
if raw_json is not None:
return tool_parameter_bundle_from_json(raw_json)
bundle = tool_input_models_for(trans, trs_tool_id, tool_version, repository_clone_url=repository_clone_url)
tool_state_cache.insert_cache_entry_for(trs_tool_id, tool_version, bundle.dict())
return bundle
) -> ParsedTool:
model_cache = trans.app.model_cache
parsed_tool = model_cache.get_cache_entry_for(ParsedTool, trs_tool_id, tool_version)
if parsed_tool is not None:
return parsed_tool
parsed_tool = parsed_tool_model_for(trans, trs_tool_id, tool_version, repository_clone_url=repository_clone_url)
model_cache.insert_cache_entry_for(parsed_tool, trs_tool_id, tool_version)
return parsed_tool


def tool_input_models_for(
def parsed_tool_model_for(
trans: ProvidesRepositoriesContext, trs_tool_id: str, tool_version: str, repository_clone_url: Optional[str] = None
) -> ToolParameterBundleModel:
) -> ParsedTool:
tool_source = tool_source_for(trans, trs_tool_id, tool_version, repository_clone_url=repository_clone_url)
return input_models_for_tool_source(tool_source)
return _parse_tool(tool_source)


def tool_source_for(
Expand Down
4 changes: 2 additions & 2 deletions lib/tool_shed/structured_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from galaxy.structured_app import BasicSharedApp

if TYPE_CHECKING:
from tool_shed.managers.tool_state_cache import ToolStateCache
from tool_shed.managers.model_cache import ModelCache
from tool_shed.repository_registry import Registry as RepositoryRegistry
from tool_shed.repository_types.registry import Registry as RepositoryTypesRegistry
from tool_shed.util.hgweb_config import HgWebConfigManager
Expand All @@ -17,4 +17,4 @@ class ToolShedApp(BasicSharedApp):
repository_registry: "RepositoryRegistry"
hgweb_config_manager: "HgWebConfigManager"
security_agent: "CommunityRBACAgent"
tool_state_cache: "ToolStateCache"
model_cache: "ModelCache"
17 changes: 8 additions & 9 deletions lib/tool_shed/webapp/api2/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from galaxy.tool_util.parameters import (
RequestToolState,
to_json_schema_string,
ToolParameterBundleModel,
)
from tool_shed.context import SessionRequestContext
from tool_shed.managers.tools import (
parsed_tool_model_cached_for,
ParsedTool,
search,
tool_input_models_cached_for,
)
from tool_shed.managers.trs import (
get_tool,
Expand Down Expand Up @@ -144,17 +144,17 @@ def trs_get_versions(
return get_tool(trans, tool_id).versions

@router.get(
"/api/tools/{tool_id}/versions/{tool_version}/parameter_model",
"/api/tools/{tool_id}/versions/{tool_version}",
operation_id="tools__parameter_model",
summary="Return Galaxy's meta model description of the tool's inputs",
)
def tool_parameters_meta_model(
def show_tool(
self,
trans: SessionRequestContext = DependsOnTrans,
tool_id: str = TOOL_ID_PATH_PARAM,
tool_version: str = TOOL_VERSION_PATH_PARAM,
) -> ToolParameterBundleModel:
return tool_input_models_cached_for(trans, tool_id, tool_version)
) -> ParsedTool:
return parsed_tool_model_cached_for(trans, tool_id, tool_version)

@router.get(
"/api/tools/{tool_id}/versions/{tool_version}/parameter_request_schema",
Expand All @@ -168,6 +168,5 @@ def tool_state(
tool_id: str = TOOL_ID_PATH_PARAM,
tool_version: str = TOOL_VERSION_PATH_PARAM,
) -> Response:
return json_schema_response(
RequestToolState.parameter_model_for(tool_input_models_cached_for(trans, tool_id, tool_version))
)
parsed_tool = parsed_tool_model_cached_for(trans, tool_id, tool_version)
return json_schema_response(RequestToolState.parameter_model_for(parsed_tool.inputs))
4 changes: 2 additions & 2 deletions lib/tool_shed/webapp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from galaxy.structured_app import BasicSharedApp
from galaxy.web_stack import application_stack_instance
from tool_shed.grids.repository_grid_filter_manager import RepositoryGridFilterManager
from tool_shed.managers.tool_state_cache import ToolStateCache
from tool_shed.managers.model_cache import ModelCache
from tool_shed.structured_app import ToolShedApp
from tool_shed.util.hgweb_config import hgweb_config_manager
from tool_shed.webapp.model.migrations import verify_database
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(self, **kwd) -> None:
self._register_singleton(SharedModelMapping, model)
self._register_singleton(mapping.ToolShedModelMapping, model)
self._register_singleton(scoped_session, self.model.context)
self.tool_state_cache = ToolStateCache(self.config.tool_state_cache_dir)
self.model_cache = ModelCache(self.config.model_cache_dir)
self.user_manager = self._register_singleton(UserManager, UserManager(self, app_type="tool_shed"))
self.api_keys_manager = self._register_singleton(ApiKeyManager)
# initialize the Tool Shed tag handler.
Expand Down
4 changes: 2 additions & 2 deletions test/unit/tool_shed/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from galaxy.security.idencoding import IdEncodingHelper
from galaxy.util import safe_makedirs
from tool_shed.context import ProvidesRepositoriesContext
from tool_shed.managers.model_cache import ModelCache
from tool_shed.managers.repositories import upload_tar_and_set_metadata
from tool_shed.managers.tool_state_cache import ToolStateCache
from tool_shed.managers.users import create_user
from tool_shed.repository_types import util as rt_util
from tool_shed.repository_types.registry import Registry as RepositoryTypesRegistry
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(self, temp_directory=None):
self.config = TestToolShedConfig(temp_directory)
self.security = IdEncodingHelper(id_secret=self.config.id_secret)
self.repository_registry = tool_shed.repository_registry.Registry(self)
self.tool_state_cache = ToolStateCache(os.path.join(temp_directory, "tool_state_cache"))
self.model_cache = ModelCache(os.path.join(temp_directory, "model_cache"))

@property
def security_agent(self):
Expand Down
53 changes: 53 additions & 0 deletions test/unit/tool_shed/test_model_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from pydantic import (
BaseModel,
ConfigDict,
)

from tool_shed.managers.model_cache import (
hash_model,
ModelCache,
)


class Moo(BaseModel):
foo: int


class MooLike(BaseModel):
model_config = ConfigDict(title="Moo")
foo: int


class NewMoo(BaseModel):
model_config = ConfigDict(title="Moo")
foo: int
new_prop: str


def test_hash():
hash_moo_1 = hash_model(Moo)
hash_moo_2 = hash_model(Moo)
assert hash_moo_1 == hash_moo_2


def test_hash_by_value():
hash_moo_1 = hash_model(Moo)
hash_moo_like = hash_model(MooLike)
assert hash_moo_1 == hash_moo_like


def test_hash_different_on_updates():
hash_moo_1 = hash_model(Moo)
hash_moo_new = hash_model(NewMoo)
assert hash_moo_1 != hash_moo_new


def cache_dict(tmp_path):
model_cache = ModelCache(tmp_path)
assert not model_cache.has_cached_entry_for(Moo, "moo", "1.0")
assert None is model_cache.get_cache_entry_for(Moo, "moo", "1.0")
model_cache.insert_cache_entry_for(Moo(foo=4), "moo", "1.0")
moo = model_cache.get_cache_entry_for(Moo, "moo", "1.0")
assert moo
assert moo.foo == 4
assert model_cache.has_cached_entry_for(Moo, "moo", "1.0")
Loading

0 comments on commit ec8824f

Please sign in to comment.