diff --git a/pyproject.toml b/pyproject.toml index dcb5694..b1b0766 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "geneweaver-db" -version = "0.5.0a15" +version = "0.5.0a16" description = "Database Interaction Services for GeneWeaver" authors = ["Jax Computational Sciences "] readme = "README.md" diff --git a/src/geneweaver/db/aio/geneset.py b/src/geneweaver/db/aio/geneset.py index 20ed262..3e4a301 100644 --- a/src/geneweaver/db/aio/geneset.py +++ b/src/geneweaver/db/aio/geneset.py @@ -1,5 +1,6 @@ """Async database interaction code relating to Genesets.""" +from datetime import date from typing import List, Optional from geneweaver.core.enum import GeneIdentifier, GenesetTier, ScoreType, Species @@ -25,12 +26,18 @@ async def get( pubmed_id: Optional[int] = None, gene_id_type: Optional[GeneIdentifier] = None, search_text: Optional[str] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, is_readable_by: Optional[int] = None, with_publication_info: bool = True, ontology_term: Optional[str] = None, score_type: Optional[ScoreType] = None, + lte_count: Optional[int] = None, + gte_count: Optional[int] = None, + created_after: Optional[date] = None, + created_before: Optional[date] = None, + updated_after: Optional[date] = None, + updated_before: Optional[date] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, ) -> List[Row]: """Get genesets from the database. @@ -46,12 +53,18 @@ async def get( :param gene_id_type: Show only results with this gene ID type. :param search_text: Return genesets that match this search text (using PostgreSQL full-text search). - :param limit: Limit the number of results. - :param offset: Offset the results. :param is_readable_by: A user ID to check if the user can read the results. :param with_publication_info: Include publication info in the return. :param ontology_term: Show only results associated with this ontology term. 
:param score_type: Show only results with given score type. + :param lte_count: Show only results with a count less than or equal to this value. + :param gte_count: Show only results with a count greater than or equal to this value. + :param updated_before: Show only results updated before this date. + :param updated_after: Show only results updated after this date. + :param created_before: Show only results created before this date. + :param created_after: Show only results created after this date. + :param limit: Limit the number of results. + :param offset: Offset the results. :return: list of results using `.fetchall()` """ @@ -73,6 +86,12 @@ async def get( with_publication_info=with_publication_info, ontology_term=ontology_term, score_type=score_type, + lte_count=lte_count, + gte_count=gte_count, + created_after=created_after, + created_before=created_before, + updated_after=updated_after, + updated_before=updated_before, ) ) diff --git a/src/geneweaver/db/geneset.py b/src/geneweaver/db/geneset.py index 0d8d96b..35a3e76 100644 --- a/src/geneweaver/db/geneset.py +++ b/src/geneweaver/db/geneset.py @@ -1,5 +1,6 @@ """Geneset database functions.""" +from datetime import date from typing import List, Optional from geneweaver.core.enum import GeneIdentifier, GenesetTier, ScoreType, Species @@ -31,6 +32,12 @@ def get( with_publication_info: bool = True, ontology_term: Optional[str] = None, score_type: Optional[ScoreType] = None, + lte_count: Optional[int] = None, + gte_count: Optional[int] = None, + created_after: Optional[date] = None, + created_before: Optional[date] = None, + updated_after: Optional[date] = None, + updated_before: Optional[date] = None, ) -> List[Row]: """Get genesets from the database. @@ -47,11 +54,17 @@ def get( :param gene_id_type: Show only results with this gene ID type. :param search_text: Return genesets that match this search text (using PostgreSQL full-text search). - :param limit: Limit the number of results. - :param offset: Offset the results. 
:param with_publication_info: Include publication info in the return. :param ontology_term: Show only results associated with this ontology term. :param score_type: Show only results with given score type. + :param lte_count: Show only results with a count less than or equal to this value. + :param gte_count: Show only results with a count greater than or equal to this value. + :param updated_before: Show only results updated before this date. + :param updated_after: Show only results updated after this date. + :param created_before: Show only results created before this date. + :param created_after: Show only results created after this date. + :param limit: Limit the number of results. + :param offset: Offset the results. :return: list of results using `.fetchall()` """ @@ -73,6 +86,12 @@ def get( with_publication_info=with_publication_info, ontology_term=ontology_term, score_type=score_type, + lte_count=lte_count, + gte_count=gte_count, + created_after=created_after, + created_before=created_before, + updated_after=updated_after, + updated_before=updated_before, ) ) diff --git a/src/geneweaver/db/query/geneset/read.py b/src/geneweaver/db/query/geneset/read.py index adff785..c43935c 100644 --- a/src/geneweaver/db/query/geneset/read.py +++ b/src/geneweaver/db/query/geneset/read.py @@ -1,5 +1,6 @@ """SQL query generation code for reading genesets.""" +from datetime import date from typing import Optional, Tuple from geneweaver.core.enum import GeneIdentifier, ScoreType, Species @@ -11,7 +12,7 @@ restrict_tier, search, ) -from geneweaver.db.query.utils import construct_filters +from geneweaver.db.query.utils import construct_filters, construct_op_filters from geneweaver.db.utils import GenesetTierOrTiers, limit_and_offset from psycopg.sql import SQL, Composed @@ -34,6 +35,12 @@ def get( with_publication_info: bool = True, ontology_term: Optional[str] = None, score_type: Optional[ScoreType] = None, + lte_count: Optional[int] = None, + gte_count: Optional[int] = None, + created_after: Optional[date] = None, + created_before: Optional[date] = 
None, + updated_after: Optional[date] = None, + updated_before: Optional[date] = None, ) -> Tuple[Composed, dict]: """Get genesets. @@ -55,6 +62,12 @@ def get( :param with_publication_info: Include publication info in the return. :param ontology_term: Show only results associated with this ontology term. :param score_type: Show only results with given score type. + :param lte_count: Show only results with a count less than or equal to this value. + :param gte_count: Show only results with a count greater than or equal to this value. + :param updated_before: Show only results updated before this date. + :param updated_after: Show only results updated after this date. + :param created_before: Show only results created before this date. + :param created_after: Show only results created after this date. """ params = {} filtering = [] @@ -93,6 +106,49 @@ def get( }, ) + filtering, params = construct_op_filters( + filters=filtering, + params=params, + filter_items=[ + { + "field": "gs_count", + "value": lte_count, + "op": "<=", + "place_holder": "count_less_than", + }, + { + "field": "gs_count", + "value": gte_count, + "op": ">=", + "place_holder": "count_greater_than", + }, + { + "field": "gs_created", + "value": created_before, + "op": "<=", + "place_holder": "created_before", + }, + { + "field": "gs_created", + "value": created_after, + "op": ">=", + "place_holder": "created_after", + }, + { + "field": "gs_updated", + "value": updated_before, + "op": "<=", + "place_holder": "updated_before", + }, + { + "field": "gs_updated", + "value": updated_after, + "op": ">=", + "place_holder": "updated_after", + }, + ], + ) + if len(filtering) > 0: query += SQL("WHERE") + SQL("AND").join(filtering) diff --git a/src/geneweaver/db/query/utils.py b/src/geneweaver/db/query/utils.py index a0b198f..3c6c36f 100644 --- a/src/geneweaver/db/query/utils.py +++ b/src/geneweaver/db/query/utils.py @@ -8,6 +8,7 @@ SQLList = List[Union[Composed, SQL]] ParamDict = Dict[str, Union[str, int, list]] OptionalParamDict = Dict[str, Optional[Union[str, int]]] +OptionalParamTuple 
def construct_op_filters(
    filters: SQLList,
    params: ParamDict,
    filter_items: List[dict],
) -> Tuple[SQLList, ParamDict]:
    """Construct multiple simple filters with operators for a query.

    Calls the `construct_filter` function once per filter item, threading the
    accumulated filters and parameters through each call.

    Each filter item is a dict read with the following keys (a missing key is
    passed on as ``None``):

    - ``"field"``: the column name to filter on.
    - ``"value"``: the value to compare against.
    - ``"op"``: the SQL comparison operator, e.g. ``"<="`` or ``">="``.
    - ``"place_holder"``: the parameter name to bind the value under. Use
      distinct names when the same field is filtered more than once.

    NOTE: ``"op"`` is concatenated into the SQL text by `construct_filter`
    rather than bound as a parameter, so it must always be a trusted literal
    and never come from user input.

    :param filters: The existing filters.
    :param params: The existing parameters.
    :param filter_items: The filter items to construct.

    :return: The constructed filters and parameters.
    """
    for filter_item in filter_items:
        filters, params = construct_filter(
            filters=filters,
            params=params,
            filter_name=filter_item.get("field"),
            filter_value=filter_item.get("value"),
            operator=filter_item.get("op"),
            place_holder=filter_item.get("place_holder"),
        )

    return filters, params
@pytest.mark.parametrize("lte_count", [None, 20])
@pytest.mark.parametrize("gte_count", [None, 5])
def test_get_gs_gense_count_size(lte_count, gte_count, example_genesets, cursor):
    """Test the geneset.get function by count bounds using a mock cursor."""
    cursor.fetchall.return_value = example_genesets

    genesets = get(cursor, gte_count=gte_count, lte_count=lte_count)

    # Exactly one query is executed and its rows are read with a single
    # fetchall(); fetchone() must not be used.
    assert cursor.fetchall.call_count == 1
    assert cursor.fetchone.call_count == 0
    assert cursor.execute.call_count == 1
    assert genesets == example_genesets