diff --git a/src/async_kinesis_client/kinesis_consumer.py b/src/async_kinesis_client/kinesis_consumer.py index b62f81a..fb761ae 100644 --- a/src/async_kinesis_client/kinesis_consumer.py +++ b/src/async_kinesis_client/kinesis_consumer.py @@ -112,6 +112,10 @@ async def get_records(self): # FIXME: Could there be empty records in the list? If yes, should we filter them out? self.record_count += len(records) if self.dynamodb and self.record_count > self.checkpoint_interval: + callback_coro = self.consumer._get_checkpoint_callback() + if callback_coro: + if not await callback_coro(self.shard_id, records[-1]['SequenceNumber']): + raise ShardClosedException('Shard closed by application request') await self.dynamodb.checkpoint(seq=records[-1]['SequenceNumber']) self.record_count = 0 self.retries = 0 @@ -161,6 +165,7 @@ def __init__( self.kinesis_client = aioboto3.client('kinesis') self.checkpoint_table = checkpoint_table + self.checkpoint_callback = None self.host_key = host_key self.shard_readers = {} @@ -172,6 +177,23 @@ def __init__( self.lock_duration = AsyncKinesisConsumer.DEFAULT_LOCK_DURATION self.reader_sleep_time = AsyncKinesisConsumer.DEFAULT_SLEEP_TIME + def set_checkpoint_callback(self, callback): + """ + Sets application callback coroutine to be called before checkpointing next batch of records + The callback should return True if the records received from AsyncKinesisReader were + successfully processed by application and can be checkpointed. + The application can try to finish processing received records before returning value from this callback. + If False value is returned, the Shard Reader will exit + The callback is called with following arguments: + ShardId - Shard Id of the shard attempting checkpointing + SequenceNumber - Last SequenceId of the record in batch + :param callback: + """ + self.checkpoint_callback = callback + + def _get_checkpoint_callback(self): + return self.checkpoint_callback + def set_checkpoint_interval(self, interval): self.checkpoint_interval = interval diff --git a/src/async_kinesis_client/kinesis_producer.py b/src/async_kinesis_client/kinesis_producer.py index b7cb4b8..0cac0ea 100644 --- a/src/async_kinesis_client/kinesis_producer.py +++ b/src/async_kinesis_client/kinesis_producer.py @@ -80,6 +80,8 @@ async def put_records(self, records, partition_key=None, explicit_hash_key=None) otherwise Raises ValueError if single record exceeds 1 Mb + Currently application should check for ProvisionedThroughputExceededException + in response structure itself. """ resp = [] n = 1 diff --git a/tests/test_consumer.py b/tests/test_consumer.py index 28d023d..1a00ee8 100644 --- a/tests/test_consumer.py +++ b/tests/test_consumer.py @@ -63,7 +63,7 @@ async def mock_get_shard_iterator(self, StreamName, ShardId, **kwargs): 'ShardIterator': {} } - def test_consmuer(self): + def test_consumer(self): async def read(): cnt = 0