diff --git a/.gitignore b/.gitignore index 371989c1..1ce7fa9b 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,9 @@ database *.db .DS_Store .vscode +.python-version +.coverage +.vscode .coverage .python-version .history diff --git a/src/common/utils.py b/src/common/utils.py index 6712e41f..40760e14 100644 --- a/src/common/utils.py +++ b/src/common/utils.py @@ -1,5 +1,6 @@ import asyncio import logging +import time from collections import defaultdict from datetime import datetime, timezone from decimal import ROUND_FLOOR, Decimal, localcontext @@ -116,6 +117,21 @@ def add_fields(self, log_record, record, message_dict): # type: ignore log_record['level'] = record.levelname +class RateLimiter: + def __init__(self, min_interval: int) -> None: + self.min_interval = min_interval + self.previous_time: float | None = None + + async def ensure_interval(self) -> None: + current_time = time.time() + + if self.previous_time is not None: + elapsed = current_time - self.previous_time + await asyncio.sleep(self.min_interval - elapsed) + + self.previous_time = current_time + + def round_down(d: int | Decimal, precision: int) -> Decimal: if isinstance(d, int): d = Decimal(d) diff --git a/src/validators/tasks.py b/src/validators/tasks.py index d6165bad..4a28b9fa 100644 --- a/src/validators/tasks.py +++ b/src/validators/tasks.py @@ -1,6 +1,4 @@ -import asyncio import logging -import time from typing import Sequence, cast from eth_typing import HexStr @@ -19,8 +17,8 @@ from src.common.harvest import get_harvest_params from src.common.metrics import metrics from src.common.tasks import BaseTask -from src.common.typings import HarvestParams -from src.common.utils import get_current_timestamp +from src.common.typings import HarvestParams, OraclesApproval +from src.common.utils import RateLimiter, get_current_timestamp from src.config.settings import DEPOSIT_AMOUNT, settings from src.validators.database import NetworkValidatorCrud from src.validators.exceptions import MissingDepositDataValidatorsException @@ -85,7 +83,7 @@ async def process_block(self, interrupt_handler: InterruptHandler) -> None: ) -# pylint: disable-next=too-many-locals,too-many-branches,too-many-return-statements,too-many-statements +# pylint: disable-next=too-many-locals,too-many-return-statements async def process_validators( keystore: BaseKeystore | None, deposit_data: DepositData | None, @@ -173,61 +171,85 @@ async def process_validators( logger.info('Started registration of %d validator(s)', len(validators)) - registry_root = None - oracles_request = None + oracles_request, oracles_approval = await poll_oracles_approval( + keystore=keystore, + validators=validators, + multi_proof=multi_proof, + validators_manager_signature=validators_manager_signature, + ) + validators_registry_root = Bytes32(Web3.to_bytes(hexstr=oracles_request.validators_root)) + + tx_hash = await register_validators( + approval=oracles_approval, + multi_proof=multi_proof, + validators=validators, + harvest_params=harvest_params, + validators_registry_root=validators_registry_root, + validators_manager_signature=validators_manager_signature, + ) + if tx_hash: + pub_keys = ', '.join([val.public_key for val in validators]) + logger.info('Successfully registered validator(s) with public key(s) %s', pub_keys) + + return tx_hash + + +async def poll_oracles_approval( + keystore: BaseKeystore | None, + validators: Sequence[Validator], + multi_proof: MultiProof[tuple[bytes, int]] | None = None, + validators_manager_signature: HexStr | None = None, +) -> tuple[ApprovalRequest, OraclesApproval]: + """ + Polls oracles for approval of validator registration + """ + previous_registry_root: Bytes32 | None = None + oracles_request: ApprovalRequest | None = None protocol_config = await get_protocol_config() - deadline = get_current_timestamp() + protocol_config.signature_validity_period + deadline: int | None = None + approvals_min_interval = 1 + rate_limiter = RateLimiter(approvals_min_interval) while True: - approval_start_time = time.time() + # Keep min interval between requests + await rate_limiter.ensure_interval() + + # Create new approvals request or reuse the previous one + current_registry_root = await validators_registry_contract.get_registry_root() + logger.debug('Fetched validators registry root: %s', Web3.to_hex(current_registry_root)) - latest_registry_root = await validators_registry_contract.get_registry_root() current_timestamp = get_current_timestamp() if ( - not registry_root - or registry_root != latest_registry_root + oracles_request is None + or previous_registry_root is None + or previous_registry_root != current_registry_root + or deadline is None or deadline <= current_timestamp ): - registry_root = latest_registry_root deadline = current_timestamp + protocol_config.signature_validity_period - logger.debug('Fetched latest validators registry root: %s', Web3.to_hex(registry_root)) oracles_request = await create_approval_request( protocol_config=protocol_config, keystore=keystore, validators=validators, - registry_root=registry_root, + registry_root=current_registry_root, multi_proof=multi_proof, deadline=deadline, validators_manager_signature=validators_manager_signature, ) + previous_registry_root = current_registry_root + # Send approval requests try: oracles_approval = await send_approval_requests(protocol_config, oracles_request) - break + return oracles_request, oracles_approval except NotEnoughOracleApprovalsError as e: logger.error( 'Not enough oracle approvals for validator registration: %d. Threshold is %d.', e.num_votes, e.threshold, ) - approvals_time = time.time() - approval_start_time - await asyncio.sleep(approvals_min_interval - approvals_time) - - tx_hash = await register_validators( - approval=oracles_approval, - multi_proof=multi_proof, - validators=validators, - harvest_params=harvest_params, - validators_registry_root=registry_root, - validators_manager_signature=validators_manager_signature, - ) - if tx_hash: - pub_keys = ', '.join([val.public_key for val in validators]) - logger.info('Successfully registered validator(s) with public key(s) %s', pub_keys) - - return tx_hash async def get_validators_count_from_vault_assets(harvest_params: HarvestParams | None) -> int: