From 2f4d2725e99b31e04a4e6e612733ea424b9b1653 Mon Sep 17 00:00:00 2001 From: Siddhant Goel Date: Thu, 18 Apr 2024 21:08:30 +0200 Subject: [PATCH] chore: split up targets into sync/async --- streaming_form_data/targets.py | 77 ++++++++++++++++++++++++++-------- tests/test_parser.py | 4 +- tests/test_targets.py | 16 +++++++ 3 files changed, 77 insertions(+), 20 deletions(-) diff --git a/streaming_form_data/targets.py b/streaming_form_data/targets.py index 63a354bc..b222eb46 100644 --- a/streaming_form_data/targets.py +++ b/streaming_form_data/targets.py @@ -1,26 +1,21 @@ import hashlib from pathlib import Path -import smart_open # type: ignore from typing import Callable, List, Optional +import smart_open # type: ignore + class BaseTarget: """ Targets determine what to do with some input once the parser is done - processing it. Any new Target should inherit from this base class and - override the :code:`data_received` function. - - Attributes: - multipart_filename: the name of the file advertised by the user, - extracted from the :code:`Content-Disposition` header. Please note - that this value comes directly from the user input and is not - sanitized, so be careful in using it directly. - multipart_content_type: MIME Content-Type of the file, extracted from - the :code:`Content-Type` HTTP header + processing it. """ def __init__(self, validator: Optional[Callable] = None): + # the name of the file extracted from the Content-Disposition header self.multipart_filename = None + + # the MIME Content-Type extracted from the Content-Type header self.multipart_content_type = None self._started = False @@ -31,6 +26,13 @@ def _validate(self, chunk: bytes): if self._validator: self._validator(chunk) + +class SyncTarget(BaseTarget): + """ + SyncTarget handles inputs in a synchronous manner. Child classes should override the + on_data_received method to do the actual work. 
+ """ + def start(self): self._started = True self.on_start() @@ -53,7 +55,35 @@ def on_finish(self): pass -class NullTarget(BaseTarget): +class AsyncTarget(BaseTarget): + """ + AsyncTarget handles inputs in an asynchronous manner. Child classes should override + the on_data_received method to do the actual work. + """ + + async def start(self): + self._started = True + await self.on_start() + + async def on_start(self): + pass + + async def data_received(self, chunk: bytes): + self._validate(chunk) + await self.on_data_received(chunk) + + async def on_data_received(self, chunk: bytes): + raise NotImplementedError() + + async def finish(self): + await self.on_finish() + self._finished = True + + async def on_finish(self): + pass + + +class NullTarget(SyncTarget): """NullTarget ignores whatever input is passed in. This is mostly useful for internal use and should (normally) not be @@ -64,7 +94,18 @@ def on_data_received(self, chunk: bytes): pass -class ValueTarget(BaseTarget): +class AsyncNullTarget(AsyncTarget): + """AsyncNullTarget ignores whatever input is passed in. + + This is mostly useful for internal use and should (normally) not be + required by external users. + """ + + async def on_data_received(self, chunk: bytes): + pass + + +class ValueTarget(SyncTarget): """ValueTarget stores the input in an in-memory list of bytes. 
This is useful in case you'd like to have the value contained in an @@ -84,7 +125,7 @@ def value(self): return b"".join(self._values) -class FileTarget(BaseTarget): +class FileTarget(SyncTarget): """FileTarget writes (streams) the input to an on-disk file.""" def __init__(self, filename: str, allow_overwrite: bool = True, *args, **kwargs): @@ -107,7 +148,7 @@ def on_finish(self): self._fd.close() -class DirectoryTarget(BaseTarget): +class DirectoryTarget(SyncTarget): """DirectoryTarget writes (streams) the different inputs to an on-disk directory.""" @@ -143,7 +184,7 @@ def on_finish(self): self._fd.close() -class SHA256Target(BaseTarget): +class SHA256Target(SyncTarget): """SHA256Target calculates the SHA256 hash of the given input.""" def __init__(self, *args, **kwargs): @@ -159,7 +200,7 @@ def value(self): return self._hash.hexdigest() -class S3Target(BaseTarget): +class S3Target(SyncTarget): """ S3Target enables chunked uploads to S3 buckets (using smart_open)""" @@ -187,7 +228,7 @@ def on_finish(self): self._fd.close() -class CSVTarget(BaseTarget): +class CSVTarget(SyncTarget): """ CSVTarget enables the processing and release of CSV lines as soon as they are completed by a chunk. 
diff --git a/tests/test_parser.py b/tests/test_parser.py index f77dc95a..d1669170 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -7,7 +7,7 @@ from streaming_form_data import ParseFailedException, StreamingFormDataParser from streaming_form_data.targets import ( - BaseTarget, + SyncTarget, FileTarget, DirectoryTarget, SHA256Target, @@ -732,7 +732,7 @@ def test_target_raises_exception(): content_type, body = encoded_dataset(filename) - class BadTarget(BaseTarget): + class BadTarget(SyncTarget): def data_received(self, data): raise ValueError() diff --git a/tests/test_targets.py b/tests/test_targets.py index 3ee555da..5086f34b 100644 --- a/tests/test_targets.py +++ b/tests/test_targets.py @@ -6,6 +6,7 @@ import boto3 from streaming_form_data.targets import ( + AsyncNullTarget, BaseTarget, FileTarget, DirectoryTarget, @@ -40,6 +41,21 @@ def test_null_target_basic(): assert target.multipart_filename == "file001.txt" +@pytest.mark.asyncio +async def test_async_null_target_basic(): + target = AsyncNullTarget() + + target.multipart_filename = "file001.txt" + + await target.start() + assert target.multipart_filename == "file001.txt" + + await target.data_received(b"hello") + await target.finish() + + assert target.multipart_filename == "file001.txt" + + def test_value_target_basic(): target = ValueTarget()