diff --git a/nmdc_server/query.py b/nmdc_server/query.py index 2665a2dc..eee0f394 100644 --- a/nmdc_server/query.py +++ b/nmdc_server/query.py @@ -653,6 +653,18 @@ def _inject_omics_data_summary(self, db: Session, query: Query) -> Query: ), ) + def query(self, db: Session): + study_query = super().query(db) + if any([condition.table == Table.biosample for condition in self.conditions]): + sample_query = BiosampleQuerySchema(conditions=self.conditions).query(db) + studies_from_sample_query = sample_query.with_entities( + models.Biosample.study_id + ).distinct() + study_query = study_query.where( # type: ignore + self.table.model.id.in_(studies_from_sample_query) # type: ignore + ) + return study_query + def execute(self, db: Session) -> Query: sample_subquery = BiosampleQuerySchema(conditions=self.conditions).query(db).subquery() sample_count = ( diff --git a/tests/test_query.py b/tests/test_query.py index 3a2157f2..23f76e01 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -99,6 +99,31 @@ def test_basic_query(db: Session, table): assert tests[table][0].id in {r.id for r in q.all()} +def test_study_search_biosample_conditions(db: Session): + test_study = fakes.StudyFactory() + _ = fakes.BiosampleFactory(longitude=10, latitude=0, study=test_study) + _ = fakes.BiosampleFactory(longitude=0, latitude=50, study=test_study) + sample_3 = fakes.BiosampleFactory(longitude=10, latitude=50) + db.commit() + + condition_lat_range = { + "table": "biosample", + "field": "latitude", + "op": "between", + "value": [49, 51], + } + condition_long_range = { + "table": "biosample", + "field": "longitude", + "op": "between", + "value": [9, 11], + } + q = query.StudyQuerySchema(conditions=[condition_lat_range, condition_long_range]) + results = {s.id for s in q.execute(db)} + assert len(results) == 1 + assert sample_3.study_id in results + + @pytest.mark.parametrize( "op,value,expected", [