Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wait_for_block method #2489

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 74 additions & 21 deletions bittensor/utils/async_substrate_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import asyncio
import inspect
import json
import random
from collections import defaultdict
Expand Down Expand Up @@ -1171,14 +1172,14 @@ async def _get_block_handler(
include_author: bool = False,
header_only: bool = False,
finalized_only: bool = False,
subscription_handler: Optional[Callable] = None,
subscription_handler: Optional[Callable[[dict], Awaitable[Any]]] = None,
):
try:
await self.init_runtime(block_hash=block_hash)
except BlockNotFound:
return None

async def decode_block(block_data, block_data_hash=None):
async def decode_block(block_data, block_data_hash=None) -> dict[str, Any]:
if block_data:
if block_data_hash:
block_data["header"]["hash"] = block_data_hash
Expand All @@ -1193,12 +1194,12 @@ async def decode_block(block_data, block_data_hash=None):

if "extrinsics" in block_data:
for idx, extrinsic_data in enumerate(block_data["extrinsics"]):
extrinsic_decoder = extrinsic_cls(
data=ScaleBytes(extrinsic_data),
metadata=self.__metadata,
runtime_config=self.runtime_config,
)
try:
extrinsic_decoder = extrinsic_cls(
data=ScaleBytes(extrinsic_data),
metadata=self.__metadata,
runtime_config=self.runtime_config,
)
extrinsic_decoder.decode(check_remaining=True)
block_data["extrinsics"][idx] = extrinsic_decoder

Expand Down Expand Up @@ -1314,23 +1315,29 @@ async def decode_block(block_data, block_data_hash=None):
if callable(subscription_handler):
rpc_method_prefix = "Finalized" if finalized_only else "New"

async def result_handler(message, update_nr, subscription_id):
new_block = await decode_block({"header": message["params"]["result"]})
async def result_handler(
message: dict, subscription_id: str
) -> tuple[Any, bool]:
reached = False
subscription_result = None
if "params" in message:
new_block = await decode_block(
{"header": message["params"]["result"]}
)

subscription_result = subscription_handler(
new_block, update_nr, subscription_id
)
subscription_result = await subscription_handler(new_block)

if subscription_result is not None:
# Handler returned end result: unsubscribe from further updates
self._forgettable_task = asyncio.create_task(
self.rpc_request(
f"chain_unsubscribe{rpc_method_prefix}Heads",
[subscription_id],
if subscription_result is not None:
reached = True
# Handler returned end result: unsubscribe from further updates
self._forgettable_task = asyncio.create_task(
self.rpc_request(
f"chain_unsubscribe{rpc_method_prefix}Heads",
[subscription_id],
)
)
)

return subscription_result
return subscription_result, reached

result = await self._make_rpc_request(
[
Expand All @@ -1343,7 +1350,7 @@ async def result_handler(message, update_nr, subscription_id):
result_handler=result_handler,
)

return result
return result["_get_block_handler"][-1]

else:
if header_only:
Expand Down Expand Up @@ -2770,3 +2777,49 @@ async def close(self):
await self.ws.shutdown()
except AttributeError:
pass

async def wait_for_block(
self,
block: int,
result_handler: Callable[[dict], Awaitable[Any]],
task_return: bool = True,
) -> Union[asyncio.Task, Union[bool, Any]]:
"""
Executes the result_handler when the chain has reached the block specified.

Args:
block: block number
result_handler: coroutine executed upon reaching the block number. This can be basically anything, but
must accept one single arg, a dict with the block data; whether you use this data or not is entirely
up to you.
task_return: True to immediately return the result of wait_for_block as an asyncio Task, False to wait
for the block to be reached, and return the result of the result handler.

Returns:
Either an asyncio.Task (which contains the running subscription, and whose `result()` will contain the
return of the result_handler), or the result itself, depending on `task_return` flag.
Note that if your result_handler returns `None`, this method will return `True`, otherwise
the return will be the result of your result_handler.
"""

async def _handler(block_data: dict[str, Any]):
required_number = block
number = block_data["header"]["number"]
if number >= required_number:
return (
r if (r := await result_handler(block_data)) is not None else True
)

args = inspect.getfullargspec(result_handler).args
if len(args) != 1:
raise ValueError(
"result_handler must take exactly one arg: the dict block data."
)

co = self._get_block_handler(
self.last_block_hash, subscription_handler=_handler
)
if task_return is True:
return asyncio.create_task(co)
else:
return await co
38 changes: 38 additions & 0 deletions tests/unit_tests/utils/test_async_substrate_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
import asyncio
from bittensor.utils import async_substrate_interface
from typing import Any


@pytest.mark.asyncio
async def test_wait_for_block_invalid_result_handler():
chain_interface = async_substrate_interface.AsyncSubstrateInterface(
"dummy_endpoint"
)

with pytest.raises(ValueError):

async def dummy_handler(
block_data: dict[str, Any], extra_arg
): # extra argument
return block_data.get("header", {}).get("number", -1) == 2

await chain_interface.wait_for_block(
block=2, result_handler=dummy_handler, task_return=False
)


@pytest.mark.asyncio
async def test_wait_for_block_async_return():
chain_interface = async_substrate_interface.AsyncSubstrateInterface(
"dummy_endpoint"
)

async def dummy_handler(block_data: dict[str, Any]) -> bool:
return block_data.get("header", {}).get("number", -1) == 2

result = await chain_interface.wait_for_block(
block=2, result_handler=dummy_handler, task_return=True
)

assert isinstance(result, asyncio.Task)
Loading