From 51af3767dbed20cacd6e62620bc95bf0730f5957 Mon Sep 17 00:00:00 2001 From: Evgeny Gusarov Date: Sat, 21 Dec 2024 02:09:16 +0300 Subject: [PATCH] Add HashiKeys class --- src/validators/keystores/hashi_vault.py | 50 ++++++++++++++++--------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/src/validators/keystores/hashi_vault.py b/src/validators/keystores/hashi_vault.py index d5419abd..8fddc4e1 100644 --- a/src/validators/keystores/hashi_vault.py +++ b/src/validators/keystores/hashi_vault.py @@ -62,6 +62,28 @@ def secret_url(self, key_path: str, location: str = 'data') -> str: ) +class HashiKeys: + def __init__(self, keys: Keys): + self.keys = keys + + def __getitem__(self, key: HexStr) -> BLSPrivkey: + return self.keys[key] + + def __setitem__(self, key: HexStr, value: BLSPrivkey) -> None: + """Add new key/value pair proactively searching for duplicate keys to prevent + potential slashing.""" + if key in self.keys: + raise RuntimeError(f'Duplicate validator key {key} found in hashi vault') + self.keys[key] = value + + def update(self, new_keys: Keys) -> None: + for key, value in new_keys.items(): + self[key] = value + + def __repr__(self) -> str: + return f'HashiKeys({self.keys})' + + @dataclass class HashiVaultKeysLoader(metaclass=abc.ABCMeta): config: HashiVaultConfiguration @@ -73,25 +95,14 @@ def session(self) -> ClientSession: headers={'X-Vault-Token': self.config.token}, ) - @staticmethod - def merge_keys_responses(keys_responses: list[Keys], merged_keys: Keys) -> None: - """Merge keys objects, proactively searching for duplicate keys to prevent - potential slashing.""" - for keys in keys_responses: - for pk, sk in keys.items(): - if pk in merged_keys: - logger.error('Duplicate validator key %s found in hashi vault', pk) - raise RuntimeError('Found duplicate key in path') - merged_keys[pk] = sk - @abc.abstractmethod - async def load(self, merged_keys: Keys) -> None: + async def load(self, merged_keys: HashiKeys) -> None: """Populate merged_keys structure with validator keys from given loader.""" raise NotImplementedError class HashiVaultBundledKeysLoader(HashiVaultKeysLoader): - async def load(self, merged_keys: Keys) -> None: + async def load(self, merged_keys: HashiKeys) -> None: """Load all the key bundles from input locations.""" while key_chunk := list(itertools.islice(self.input_iter, self.config.parallelism)): async with self.session() as session: @@ -104,7 +115,8 @@ async def load(self, merged_keys: Keys) -> None: for key_path in key_chunk ] ) - self.merge_keys_responses(keys_responses, merged_keys) + for keys_response in keys_responses: + merged_keys.update(keys_response) @staticmethod async def _load_bundled_hashi_vault_keys(session: ClientSession, secret_url: str) -> Keys: @@ -136,7 +148,7 @@ async def _load_bundled_hashi_vault_keys(session: ClientSession, secret_url: str class HashiVaultPrefixedKeysLoader(HashiVaultKeysLoader): - async def load(self, merged_keys: Keys) -> None: + async def load(self, merged_keys: HashiKeys) -> None: """Discover all the keys under given prefix. Then, load the keys into merged structure.""" prefix_leaf_location_tuples = [] while prefix_chunk := list(itertools.islice(self.input_iter, self.config.parallelism)): @@ -169,7 +181,8 @@ async def load(self, merged_keys: Keys) -> None: for (key_prefix, key_path) in prefixed_chunk ] ) - self.merge_keys_responses(keys_responses, merged_keys) + for keys_response in keys_responses: + merged_keys.update(keys_response) @staticmethod async def _find_prefixed_hashi_vault_keys( @@ -218,12 +231,15 @@ async def _load_prefixed_hashi_vault_key(session: ClientSession, secret_url: str class HashiVaultKeystore(LocalKeystore): + def __init__(self, keys: HashiKeys): + super().__init__(keys.keys) + @staticmethod async def load() -> 'HashiVaultKeystore': """Extracts private keys from the keystores.""" hashi_vault_config = HashiVaultConfiguration.from_settings() # noqa: NEW100 - merged_keys = Keys({}) + merged_keys = HashiKeys(Keys({})) for loader_class, input_iter in { HashiVaultBundledKeysLoader: iter(hashi_vault_config.key_paths),