diff --git a/src/magentic/streaming.py b/src/magentic/streaming.py
index 809dbe0f..cff97a0e 100644
--- a/src/magentic/streaming.py
+++ b/src/magentic/streaming.py
@@ -33,6 +33,46 @@ async def achain(*aiterables: AsyncIterable[T]) -> AsyncIterator[T]:
             yield item
 
 
+def peek(iterator: Iterator[T]) -> tuple[T, Iterator[T]]:
+    """Returns the first item and an Iterator equivalent to the original.
+
+    The passed `iterator` is advanced by one item, so continue with the
+    returned Iterator, which yields the first item again followed by the rest.
+
+    Raises:
+        StopIteration: If `iterator` is empty.
+    """
+    first_item = next(iterator)
+    return first_item, chain([first_item], iterator)
+
+
+async def apeek(aiterator: AsyncIterator[T]) -> tuple[T, AsyncIterator[T]]:
+    """Async version of `peek`.
+
+    Raises:
+        StopAsyncIteration: If `aiterator` is empty.
+    """
+    first_item = await anext(aiterator)
+    return first_item, achain(async_iter([first_item]), aiterator)
+
+
+async def adropwhile(
+    predicate: Callable[[T], object], aiterable: AsyncIterable[T]
+) -> AsyncIterator[T]:
+    """Async version of `itertools.dropwhile`.
+
+    Skips items while `predicate` is truthy, then yields the first item for
+    which it is falsy and all later items without calling `predicate` again.
+    """
+    aiterator = aiter(aiterable)
+    async for item in aiterator:
+        if not predicate(item):
+            yield item
+            break
+    async for item in aiterator:
+        yield item
+
+
 async def atakewhile(
     predicate: Callable[[T], object], aiterable: AsyncIterable[T]
 ) -> AsyncIterator[T]:
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index ff5a1123..dded9d27 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -6,12 +6,15 @@
 from magentic.streaming import (
     CachedAsyncIterable,
     CachedIterable,
+    adropwhile,
     agroupby,
     aiter_streamed_json_array,
+    apeek,
     async_iter,
     atakewhile,
     azip,
     iter_streamed_json_array,
+    peek,
 )
 
 
@@ -34,6 +37,46 @@ async def test_azip(aiterable, expected):
     assert [x async for x in aiterable] == expected
 
 
+@pytest.mark.parametrize(
+    ("iterator", "expected_first", "expected_remaining"),
+    [
+        (iter([1, 2, 3]), 1, [1, 2, 3]),
+        (iter([1]), 1, [1]),
+    ],
+)
+def test_peek(iterator, expected_first, expected_remaining):
+    first, remaining = peek(iterator)
+    assert first == expected_first
+    assert list(remaining) == expected_remaining
+
+
+@pytest.mark.parametrize(
+    ("aiterator", "expected_first", "expected_remaining"),
+    [
+        (async_iter([1, 2, 3]), 1, [1, 2, 3]),
+        (async_iter([1]), 1, [1]),
+    ],
+)
+@pytest.mark.asyncio
+async def test_apeek(aiterator, expected_first, expected_remaining):
+    first, remaining = await apeek(aiterator)
+    assert first == expected_first
+    assert [x async for x in remaining] == expected_remaining
+
+
+@pytest.mark.parametrize(
+    ("predicate", "input", "expected"),
+    [
+        (lambda x: x < 3, async_iter(range(5)), [3, 4]),
+        (lambda x: x < 6, async_iter(range(5)), []),
+        (lambda x: x < 0, async_iter(range(5)), [0, 1, 2, 3, 4]),
+    ],
+)
+@pytest.mark.asyncio
+async def test_adropwhile(predicate, input, expected):
+    assert [x async for x in adropwhile(predicate, input)] == expected
+
+
 @pytest.mark.parametrize(
     ("predicate", "input", "expected"),
     [