diff --git a/CHANGELOG.md b/CHANGELOG.md index 764c3ae..d5df052 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ ## Current (in progress) -- Nothing yet +- feat: handle multiple tags filter on datasets [#50](https://github.com/opendatateam/udata-search-service/pull/50) ## 2.2.0 (2024-11-07) @@ -10,7 +10,7 @@ - Add dataservices search [#48](https://github.com/opendatateam/udata-search-service/pull/48) :warning: To use these new features, you will need to: - - init ES indices with `udata-search-service init-es` + - init ES indices with `udata-search-service init-es` - index datasets and dataservices on udata side with `udata search index dataset` and `udata search index dataservice` ## 2.1.0 (2024-10-07) diff --git a/pyproject.toml b/pyproject.toml index c8f3cb1..cc861ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,8 @@ dependencies = [ "pytest-flask==1.2.0", "markdown==3.3.3", "beautifulsoup4==4.10.0", + # pinned to a known working version, since 3.1.3 will fail while running tests + "werkzeug==3.0.4", ] [project.urls] diff --git a/tests/test_api.py b/tests/test_api.py index 9d36c59..2b9fbfe 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -678,7 +678,7 @@ def test_api_dataset_search_with_temporal_filter(app, client, search_client, fak def test_api_search_with_tag_filter(app, client, search_client, faker): for i in range(4): - search_client.index_dataset(DatasetFactory(tags=['test-tag'] if i % 2 else ['not-test-tag'])) + search_client.index_dataset(DatasetFactory(tags=['test-tag', 'test-tag-2'] if i % 2 else ['not-test-tag'])) search_client.index_reuse(ReuseFactory(tags=['test-tag'] if i % 2 else ['not-test-tag'])) # Without this, ElasticSearch does not seem to have the time to index. @@ -688,7 +688,10 @@ def test_api_search_with_tag_filter(app, client, search_client, faker): resp = client.get(url_for('api.dataset_search')) assert resp.json['total'] == 4 - resp = client.get(url_for('api.dataset_search', tag='test-tag')) + resp = client.get(url_for('api.dataset_search', tag=['test-tag', 'test-tag-2'])) + assert resp.json['total'] == 2 + + resp = client.get(url_for('api.dataset_search', tag=['not-test-tag'])) assert resp.json['total'] == 2 resp = client.get(url_for('api.reuse_search')) diff --git a/tests/test_search_client.py b/tests/test_search_client.py index 2f5dc37..09edcc8 100644 --- a/tests/test_search_client.py +++ b/tests/test_search_client.py @@ -190,7 +190,7 @@ def test_search_datasets_with_temporal_filters(app, client, search_client, faker def test_search_with_tag_filter(app, client, search_client, faker): for i in range(4): search_client.index_dataset(DatasetFactory( - tags=['test-tag'] if i % 2 else ['not-test-tag'] + tags=['test-tag', f'test-tag-{i}'] if i % 2 else ['not-test-tag'] )) search_client.index_reuse(ReuseFactory( @@ -206,7 +206,9 @@ def test_search_with_tag_filter(app, client, search_client, faker): results_number, res = search_client.query_datasets(None, 0, 20, {}) assert results_number == 4 - results_number, res = search_client.query_datasets(None, 0, 20, {'tags': 'test-tag'}) + results_number, res = search_client.query_datasets(None, 0, 20, {'tags': ['test-tag', 'test-tag-1']}) + assert results_number == 1 + results_number, res = search_client.query_datasets(None, 0, 20, {'tags': ['not-test-tag']}) assert results_number == 2 results_number, res = search_client.query_reuses(None, 0, 20, {}) assert results_number == 4 diff --git a/udata_search_service/infrastructure/search_clients.py b/udata_search_service/infrastructure/search_clients.py index 162c5db..f8eed2d 100644 --- a/udata_search_service/infrastructure/search_clients.py +++ b/udata_search_service/infrastructure/search_clients.py @@ -258,6 +258,12 @@ def query_datasets(self, query_text: str, offset: int, page_size: int, filters: search = search.filter('range', **{'temporal_coverage_start': {'lte': value}}) elif key == 'temporal_coverage_end': search = search.filter('range', **{'temporal_coverage_end': {'gte': value}}) + elif key == 'tags': + # build an AND filter from tags list + tag_filters = [query.Q('term', tags=tag) for tag in value] + search = search.filter( + query.Bool(must=tag_filters) + ) else: search = search.filter('term', **{key: value}) diff --git a/udata_search_service/presentation/api.py b/udata_search_service/presentation/api.py index 0d2c02f..7e118cf 100644 --- a/udata_search_service/presentation/api.py +++ b/udata_search_service/presentation/api.py @@ -10,6 +10,7 @@ from udata_search_service.infrastructure.search_clients import ElasticClient from udata_search_service.infrastructure.consumers import DatasetConsumer, ReuseConsumer, OrganizationConsumer, DataserviceConsumer from udata_search_service.infrastructure.migrate import set_alias as set_alias_func +from udata_search_service.presentation.utils import is_list_type bp = Blueprint('api', __name__, url_prefix='/api/1') @@ -38,7 +39,7 @@ class DatasetArgs(BaseModel): page: Optional[int] = 1 page_size: Optional[int] = 20 sort: Optional[str] = None - tag: Optional[str] = None + tag: Optional[list[str]] = None badge: Optional[str] = None organization: Optional[str] = None organization_badge: Optional[str] = None @@ -69,6 +70,22 @@ def sort_validate(cls, value): raise ValueError('Temporal coverage does not match the right pattern.') return value + @classmethod + def from_request_args(cls, request_args) -> 'DatasetArgs': + def get_list_args() -> dict: + return { + key: value + for key, value in request_args.to_dict(flat=False).items() + if key in cls.__fields__ + and is_list_type(cls.__fields__[key].outer_type_) + } + + return cls( + **{ + **request_args.to_dict(), + **get_list_args(), + } + ) class ReuseArgs(BaseModel): q: Optional[str] = None @@ -192,7 +209,7 @@ def dataset_unindex(dataset_id: str, dataset_service: DatasetService = Provide[C @inject def datasets_search(dataset_service: DatasetService = Provide[Container.dataset_service]): try: - request_args = DatasetArgs(**request.args) + request_args = DatasetArgs.from_request_args(request.args) except ValidationError as e: abort(400, e) diff --git a/udata_search_service/presentation/utils.py b/udata_search_service/presentation/utils.py new file mode 100644 index 0000000..1937257 --- /dev/null +++ b/udata_search_service/presentation/utils.py @@ -0,0 +1,10 @@ +from typing import Type, get_origin, Union, get_args + + +def is_list_type(type_: Type) -> bool: + origin = get_origin(type_) + if origin is list: + return True + if origin is Union: + return any(is_list_type(t) for t in get_args(type_)) + return False