Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed issues #92 and #95

Merged
merged 2 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions semanticscholar/PaginatedResults.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Union, List
import asyncio
import nest_asyncio

from semanticscholar.ApiRequester import ApiRequester
from semanticscholar.SemanticScholarException import NoMorePagesException
Expand Down Expand Up @@ -40,7 +39,6 @@ def __init__(
self._parameters = ''
self._items = []
self._continuation_token = None
nest_asyncio.apply()

@classmethod
async def create(
Expand Down Expand Up @@ -113,6 +111,13 @@ def __iter__(self) -> Any:
while self._has_next_page():
yield from self._get_next_page()

async def __aiter__(self) -> Any:
    """Asynchronously iterate over every item of the result set.

    Items already held in ``self._items`` are yielded first. Afterwards,
    for as long as ``self._has_next_page()`` is true, the next page is
    awaited through ``self._async_get_next_page()`` and each of its items
    is yielded in turn.
    """
    for cached in self._items:
        yield cached
    while self._has_next_page():
        page = await self._async_get_next_page()
        for fetched in page:
            yield fetched

def __len__(self) -> int:
    """Return the number of items currently held in ``self._items``."""
    held = self._items
    return len(held)

Expand All @@ -134,6 +139,10 @@ async def _request_data(self) -> Union[dict, List[dict]]:
)

async def _async_get_next_page(self) -> Union[dict, List[dict]]:

if not self._has_next_page():
raise NoMorePagesException('No more pages to fetch.')

self._build_params()

results = await self._request_data()
Expand Down
3,246 changes: 1,645 additions & 1,601 deletions tests/data/test_get_author_papers_async.yaml

Large diffs are not rendered by default.

9,065 changes: 4,625 additions & 4,440 deletions tests/data/test_get_paper_citations_async.yaml

Large diffs are not rendered by default.

3,629 changes: 1,843 additions & 1,786 deletions tests/data/test_search_paper_bulk_retrieval_next_page_async.yaml

Large diffs are not rendered by default.

3,562 changes: 1,856 additions & 1,706 deletions tests/data/test_search_paper_bulk_retrieval_sorted_results_asc_async.yaml

Large diffs are not rendered by default.

3,562 changes: 1,856 additions & 1,706 deletions tests/data/test_search_paper_bulk_retrieval_sorted_results_default_order_async.yaml

Large diffs are not rendered by default.

3,549 changes: 1,841 additions & 1,708 deletions tests/data/test_search_paper_bulk_retrieval_sorted_results_desc_async.yaml

Large diffs are not rendered by default.

6,301 changes: 3,322 additions & 2,979 deletions tests/data/test_search_paper_bulk_retrieval_traversing_results_async.yaml

Large diffs are not rendered by default.

29 changes: 16 additions & 13 deletions tests/test_semanticscholar.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ async def test_get_paper_authors_async(self):
data = await self.sch.get_paper_authors('10.2139/ssrn.2250500')
self.assertEqual(data.offset, 0)
self.assertEqual(data.next, 0)
self.assertEqual(len([item for item in data]), 4)
self.assertEqual(len([item async for item in data]), 4)
self.assertEqual(data[0].name, 'E. Duflo')

@test_vcr.use_cassette
Expand Down Expand Up @@ -743,7 +743,7 @@ async def test_get_author_papers_async(self):
1723755, limit=100, fields=['title'])
self.assertEqual(data.offset, 0)
self.assertEqual(data.next, 100)
self.assertEqual(len([item for item in data]), 875)
self.assertEqual(len([item async for item in data]), 875)
self.assertEqual(
data[0].title,
'SARS-CoV-2 hijacks p38\u03b2/MAPK11 to promote virus replication')
Expand All @@ -754,7 +754,7 @@ async def test_get_paper_citations_async(self):
'10.2139/ssrn.2250500', fields=['title'])
self.assertEqual(data.offset, 0)
self.assertEqual(data.next, 100)
self.assertEqual(len([item.paper.title for item in data]), 2135)
self.assertEqual(len([item.paper.title async for item in data]), 2167)
self.assertEqual(
data[0].paper.title,
'Financial inclusion and roof quality: '
Expand Down Expand Up @@ -793,9 +793,9 @@ async def test_search_paper_next_page_async(self):
@test_vcr.use_cassette
async def test_search_paper_traversing_results_async(self):
    """Traverse all search results asynchronously and check exhaustion.

    After every title has been collected with ``async for``, awaiting one
    more page must raise ``NoMorePagesException``, and the number of
    collected titles must equal ``len(data.items)``.
    """
    data = await self.sch.search_paper('sublinear near optimal edit distance')
    # Async comprehension: pages are awaited through ``__aiter__`` rather
    # than pulled through the synchronous ``__iter__``.
    all_results = [item.title async for item in data]
    # Only the awaited call lives in the context manager, so it is the
    # statement whose exception is being asserted.
    with self.assertRaises(NoMorePagesException):
        await data.async_next_page()
    self.assertEqual(len(all_results), len(data.items))

@test_vcr.use_cassette
Expand All @@ -811,7 +811,7 @@ async def test_search_paper_year_async(self):
@test_vcr.use_cassette
async def test_search_paper_year_range_async(self):
    """Check that every result of a year-range search lies in 1936-1937."""
    data = await self.sch.search_paper('turing', year='1936-1937')
    # ``all()`` cannot consume an async generator, so the flags are first
    # materialised with an async list comprehension.
    in_range = [1936 <= item.year <= 1937 async for item in data]
    self.assertTrue(all(in_range))

@test_vcr.use_cassette
async def test_search_paper_publication_types_async(self):
Expand Down Expand Up @@ -870,7 +870,7 @@ async def test_search_paper_publication_date_or_year_invalid_async(self):
@test_vcr.use_cassette
async def test_search_paper_min_citation_count_async(self):
    """Check that every result has at least the requested citation count."""
    data = await self.sch.search_paper('turing', min_citation_count=1000)
    # ``all()`` cannot consume an async generator, so the flags are first
    # materialised with an async list comprehension.
    meets_minimum = [item.citationCount >= 1000 async for item in data]
    self.assertTrue(all(meets_minimum))

@test_vcr.use_cassette
async def test_search_paper_bulk_retrieval_async(self):
Expand All @@ -887,15 +887,18 @@ async def test_search_paper_bulk_retrieval_async(self):
async def test_search_paper_bulk_retrieval_next_page_async(self):
    """Fetch one further page of a bulk search and expect 2000 items."""
    data = await self.sch.search_paper(
        'kubernetes', bulk=True, fields=['title'])
    # Advance exactly once, through the awaitable variant, so the length
    # assertion below reflects a single extra page.
    await data.async_next_page()
    self.assertEqual(len(data), 2000)

@test_vcr.use_cassette
async def test_search_paper_bulk_retrieval_traversing_results_async(self):
    """Traverse all bulk-search results asynchronously and check exhaustion.

    After every title has been collected with ``async for``, awaiting one
    more page must raise ``NoMorePagesException``, and the number of
    collected titles must equal ``len(data.items)``.
    """
    data = await self.sch.search_paper(
        'kubernetes', bulk=True, fields=['title'])
    # Async comprehension: pages are awaited through ``__aiter__`` rather
    # than pulled through the synchronous ``__iter__``.
    all_results = [item.title async for item in data]
    with self.assertRaises(NoMorePagesException):
        await data.async_next_page()
    self.assertEqual(len(all_results), len(data.items))

@test_vcr.use_cassette
Expand All @@ -905,7 +908,7 @@ async def test_search_paper_bulk_retrieval_sorted_results_default_order_async(se
bulk=True,
sort='citationCount',
fields=['citationCount'])
all_data = [item.citationCount for item in data]
all_data = [item.citationCount async for item in data]
self.assertTrue(sorted(all_data) == all_data)

@test_vcr.use_cassette
Expand All @@ -915,7 +918,7 @@ async def test_search_paper_bulk_retrieval_sorted_results_asc_async(self):
bulk=True,
sort='citationCount:asc',
fields=['citationCount'])
all_data = [item.citationCount for item in data]
all_data = [item.citationCount async for item in data]
self.assertTrue(sorted(all_data) == all_data)

@test_vcr.use_cassette
Expand All @@ -925,7 +928,7 @@ async def test_search_paper_bulk_retrieval_sorted_results_desc_async(self):
bulk=True,
sort='citationCount:desc',
fields=['citationCount'])
all_data = [item.citationCount for item in data]
all_data = [item.citationCount async for item in data]
self.assertTrue(sorted(all_data, reverse=True) == all_data)

@test_vcr.use_cassette
Expand Down
Loading