Skip to content

Commit

Permalink
Refactor HashiVaultKeysLoader.load
Browse files Browse the repository at this point in the history
  • Loading branch information
evgeny-stakewise committed Dec 20, 2024
1 parent 51af376 commit 4da8998
Showing 1 changed file with 32 additions and 22 deletions.
54 changes: 32 additions & 22 deletions src/validators/keystores/hashi_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import urllib.parse
from dataclasses import dataclass
from typing import Iterator
from typing import Iterable

from aiohttp import ClientSession, ClientTimeout
from eth_typing import HexStr
Expand Down Expand Up @@ -63,8 +63,8 @@ def secret_url(self, key_path: str, location: str = 'data') -> str:


class HashiKeys:
def __init__(self, keys: Keys):
self.keys = keys
def __init__(self, keys: Keys | None = None):
self.keys = keys or Keys({})

def __getitem__(self, key: HexStr) -> BLSPrivkey:
return self.keys[key]
Expand All @@ -76,7 +76,10 @@ def __setitem__(self, key: HexStr, value: BLSPrivkey) -> None:
raise RuntimeError(f'Duplicate validator key {key} found in hashi vault')
self.keys[key] = value

def update(self, new_keys: Keys) -> None:
def items(self) -> Iterable[tuple[HexStr, BLSPrivkey]]:
return self.keys.items()

def update(self, new_keys: 'HashiKeys' | Keys) -> None:
for key, value in new_keys.items():
self[key] = value

Expand All @@ -87,7 +90,6 @@ def __repr__(self) -> str:
@dataclass
class HashiVaultKeysLoader(metaclass=abc.ABCMeta):
config: HashiVaultConfiguration
input_iter: Iterator[str]

def session(self) -> ClientSession:
return ClientSession(
Expand All @@ -96,15 +98,20 @@ def session(self) -> ClientSession:
)

@abc.abstractmethod
async def load(self, merged_keys: HashiKeys) -> None:
"""Populate merged_keys structure with validator keys from given loader."""
async def load(self) -> HashiKeys:
"""Load validator keys"""
raise NotImplementedError


class HashiVaultBundledKeysLoader(HashiVaultKeysLoader):
async def load(self, merged_keys: HashiKeys) -> None:
"""Load all the key bundles from input locations."""
while key_chunk := list(itertools.islice(self.input_iter, self.config.parallelism)):
async def load(self) -> HashiKeys:
return await self._load_from_key_paths(self.config.key_paths)

async def _load_from_key_paths(self, key_paths: list[str]) -> HashiKeys:
"""Load all the key bundles from key paths."""
merged_keys = HashiKeys()

while key_chunk := list(itertools.islice(key_paths, self.config.parallelism)):
async with self.session() as session:
keys_responses = await asyncio.gather(
*[
Expand All @@ -118,6 +125,8 @@ async def load(self, merged_keys: HashiKeys) -> None:
for keys_response in keys_responses:
merged_keys.update(keys_response)

return merged_keys

@staticmethod
async def _load_bundled_hashi_vault_keys(session: ClientSession, secret_url: str) -> Keys:
"""
Expand Down Expand Up @@ -148,10 +157,13 @@ async def _load_bundled_hashi_vault_keys(session: ClientSession, secret_url: str


class HashiVaultPrefixedKeysLoader(HashiVaultKeysLoader):
async def load(self, merged_keys: HashiKeys) -> None:
async def load(self) -> HashiKeys:
return await self._load_from_key_prefixes(self.config.key_prefixes)

async def _load_from_key_prefixes(self, key_prefixes: list[str]) -> HashiKeys:
"""Discover all the keys under given prefix. Then, load the keys into merged structure."""
prefix_leaf_location_tuples = []
while prefix_chunk := list(itertools.islice(self.input_iter, self.config.parallelism)):
while prefix_chunk := list(itertools.islice(key_prefixes, self.config.parallelism)):
async with self.session() as session:
prefix_leaf_location_tuples += await asyncio.gather(
*[
Expand All @@ -169,6 +181,7 @@ async def load(self, merged_keys: HashiKeys) -> None:
prefix_leaf_location_tuples,
[],
)
merged_keys = HashiKeys()
prefixed_keys_iter = iter(keys_paired_with_prefix)
while prefixed_chunk := list(itertools.islice(prefixed_keys_iter, self.config.parallelism)):
async with self.session() as session:
Expand All @@ -184,6 +197,8 @@ async def load(self, merged_keys: HashiKeys) -> None:
for keys_response in keys_responses:
merged_keys.update(keys_response)

return merged_keys

@staticmethod
async def _find_prefixed_hashi_vault_keys(
session: ClientSession, prefix: str, prefix_url: str
Expand Down Expand Up @@ -239,16 +254,11 @@ async def load() -> 'HashiVaultKeystore':
"""Extracts private keys from the keystores."""
hashi_vault_config = HashiVaultConfiguration.from_settings() # noqa: NEW100

merged_keys = HashiKeys(Keys({}))
merged_keys = HashiKeys()

for loader_class, input_iter in {
HashiVaultBundledKeysLoader: iter(hashi_vault_config.key_paths),
HashiVaultPrefixedKeysLoader: iter(hashi_vault_config.key_prefixes),
}.items():
loader = loader_class(
config=hashi_vault_config,
input_iter=input_iter,
)
await loader.load(merged_keys)
for loader_class in [HashiVaultBundledKeysLoader, HashiVaultPrefixedKeysLoader]:
loader = loader_class(config=hashi_vault_config)
keys = await loader.load()
merged_keys.update(keys)

return HashiVaultKeystore(merged_keys)

0 comments on commit 4da8998

Please sign in to comment.