Skip to content

Commit

Permalink
Refactor validators registration (#445)
Browse files Browse the repository at this point in the history
* Extract poll_oracles_approval function

* Add items to gitignore

* Undo function move
  • Loading branch information
evgeny-stakewise authored Jan 6, 2025
1 parent 6a1cf53 commit 5dd7c53
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 32 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ database
*.db
.DS_Store
.vscode
.python-version
.coverage
.vscode
.coverage
.python-version
.history
16 changes: 16 additions & 0 deletions src/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import time
from collections import defaultdict
from datetime import datetime, timezone
from decimal import ROUND_FLOOR, Decimal, localcontext
Expand Down Expand Up @@ -116,6 +117,21 @@ def add_fields(self, log_record, record, message_dict): # type: ignore
log_record['level'] = record.levelname


class RateLimiter:
def __init__(self, min_interval: int) -> None:
self.min_interval = min_interval
self.previous_time: float | None = None

async def ensure_interval(self) -> None:
current_time = time.time()

if self.previous_time is not None:
elapsed = current_time - self.previous_time
await asyncio.sleep(self.min_interval - elapsed)

self.previous_time = current_time


def round_down(d: int | Decimal, precision: int) -> Decimal:
if isinstance(d, int):
d = Decimal(d)
Expand Down
86 changes: 54 additions & 32 deletions src/validators/tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import asyncio
import logging
import time
from typing import Sequence, cast

from eth_typing import HexStr
Expand All @@ -19,8 +17,8 @@
from src.common.harvest import get_harvest_params
from src.common.metrics import metrics
from src.common.tasks import BaseTask
from src.common.typings import HarvestParams
from src.common.utils import get_current_timestamp
from src.common.typings import HarvestParams, OraclesApproval
from src.common.utils import RateLimiter, get_current_timestamp
from src.config.settings import DEPOSIT_AMOUNT, settings
from src.validators.database import NetworkValidatorCrud
from src.validators.exceptions import MissingDepositDataValidatorsException
Expand Down Expand Up @@ -85,7 +83,7 @@ async def process_block(self, interrupt_handler: InterruptHandler) -> None:
)


# pylint: disable-next=too-many-locals,too-many-branches,too-many-return-statements,too-many-statements
# pylint: disable-next=too-many-locals,too-many-return-statements
async def process_validators(
keystore: BaseKeystore | None,
deposit_data: DepositData | None,
Expand Down Expand Up @@ -173,61 +171,85 @@ async def process_validators(

logger.info('Started registration of %d validator(s)', len(validators))

registry_root = None
oracles_request = None
oracles_request, oracles_approval = await poll_oracles_approval(
keystore=keystore,
validators=validators,
multi_proof=multi_proof,
validators_manager_signature=validators_manager_signature,
)
validators_registry_root = Bytes32(Web3.to_bytes(hexstr=oracles_request.validators_root))

tx_hash = await register_validators(
approval=oracles_approval,
multi_proof=multi_proof,
validators=validators,
harvest_params=harvest_params,
validators_registry_root=validators_registry_root,
validators_manager_signature=validators_manager_signature,
)
if tx_hash:
pub_keys = ', '.join([val.public_key for val in validators])
logger.info('Successfully registered validator(s) with public key(s) %s', pub_keys)

return tx_hash


async def poll_oracles_approval(
keystore: BaseKeystore | None,
validators: Sequence[Validator],
multi_proof: MultiProof[tuple[bytes, int]] | None = None,
validators_manager_signature: HexStr | None = None,
) -> tuple[ApprovalRequest, OraclesApproval]:
"""
Polls oracles for approval of validator registration
"""
previous_registry_root: Bytes32 | None = None
oracles_request: ApprovalRequest | None = None
protocol_config = await get_protocol_config()
deadline = get_current_timestamp() + protocol_config.signature_validity_period
deadline: int | None = None

approvals_min_interval = 1
rate_limiter = RateLimiter(approvals_min_interval)

while True:
approval_start_time = time.time()
# Keep min interval between requests
await rate_limiter.ensure_interval()

# Create new approvals request or reuse the previous one
current_registry_root = await validators_registry_contract.get_registry_root()
logger.debug('Fetched validators registry root: %s', Web3.to_hex(current_registry_root))

latest_registry_root = await validators_registry_contract.get_registry_root()
current_timestamp = get_current_timestamp()
if (
not registry_root
or registry_root != latest_registry_root
oracles_request is None
or previous_registry_root is None
or previous_registry_root != current_registry_root
or deadline is None
or deadline <= current_timestamp
):
registry_root = latest_registry_root
deadline = current_timestamp + protocol_config.signature_validity_period
logger.debug('Fetched latest validators registry root: %s', Web3.to_hex(registry_root))

oracles_request = await create_approval_request(
protocol_config=protocol_config,
keystore=keystore,
validators=validators,
registry_root=registry_root,
registry_root=current_registry_root,
multi_proof=multi_proof,
deadline=deadline,
validators_manager_signature=validators_manager_signature,
)
previous_registry_root = current_registry_root

# Send approval requests
try:
oracles_approval = await send_approval_requests(protocol_config, oracles_request)
break
return oracles_request, oracles_approval
except NotEnoughOracleApprovalsError as e:
logger.error(
'Not enough oracle approvals for validator registration: %d. Threshold is %d.',
e.num_votes,
e.threshold,
)
approvals_time = time.time() - approval_start_time
await asyncio.sleep(approvals_min_interval - approvals_time)

tx_hash = await register_validators(
approval=oracles_approval,
multi_proof=multi_proof,
validators=validators,
harvest_params=harvest_params,
validators_registry_root=registry_root,
validators_manager_signature=validators_manager_signature,
)
if tx_hash:
pub_keys = ', '.join([val.public_key for val in validators])
logger.info('Successfully registered validator(s) with public key(s) %s', pub_keys)

return tx_hash


async def get_validators_count_from_vault_assets(harvest_params: HarvestParams | None) -> int:
Expand Down

0 comments on commit 5dd7c53

Please sign in to comment.