Skip to content

Commit

Permalink
Add HashiKeys class
Browse files Browse the repository at this point in the history
  • Loading branch information
evgeny-stakewise committed Dec 20, 2024
1 parent ec6cfa4 commit 51af376
Showing 1 changed file with 33 additions and 17 deletions.
50 changes: 33 additions & 17 deletions src/validators/keystores/hashi_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,28 @@ def secret_url(self, key_path: str, location: str = 'data') -> str:
)


class HashiKeys:
def __init__(self, keys: Keys):
self.keys = keys

def __getitem__(self, key: HexStr) -> BLSPrivkey:
return self.keys[key]

def __setitem__(self, key: HexStr, value: BLSPrivkey) -> None:
"""Add new key/value pair proactively searching for duplicate keys to prevent
potential slashing."""
if key in self.keys:
raise RuntimeError(f'Duplicate validator key {key} found in hashi vault')
self.keys[key] = value

def update(self, new_keys: Keys) -> None:
for key, value in new_keys.items():
self[key] = value

def __repr__(self) -> str:
return f'HashiKeys({self.keys})'


@dataclass
class HashiVaultKeysLoader(metaclass=abc.ABCMeta):
config: HashiVaultConfiguration
Expand All @@ -73,25 +95,14 @@ def session(self) -> ClientSession:
headers={'X-Vault-Token': self.config.token},
)

@staticmethod
def merge_keys_responses(keys_responses: list[Keys], merged_keys: Keys) -> None:
"""Merge keys objects, proactively searching for duplicate keys to prevent
potential slashing."""
for keys in keys_responses:
for pk, sk in keys.items():
if pk in merged_keys:
logger.error('Duplicate validator key %s found in hashi vault', pk)
raise RuntimeError('Found duplicate key in path')
merged_keys[pk] = sk

@abc.abstractmethod
async def load(self, merged_keys: Keys) -> None:
async def load(self, merged_keys: HashiKeys) -> None:
"""Populate merged_keys structure with validator keys from given loader."""
raise NotImplementedError


class HashiVaultBundledKeysLoader(HashiVaultKeysLoader):
async def load(self, merged_keys: Keys) -> None:
async def load(self, merged_keys: HashiKeys) -> None:
"""Load all the key bundles from input locations."""
while key_chunk := list(itertools.islice(self.input_iter, self.config.parallelism)):
async with self.session() as session:
Expand All @@ -104,7 +115,8 @@ async def load(self, merged_keys: Keys) -> None:
for key_path in key_chunk
]
)
self.merge_keys_responses(keys_responses, merged_keys)
for keys_response in keys_responses:
merged_keys.update(keys_response)

@staticmethod
async def _load_bundled_hashi_vault_keys(session: ClientSession, secret_url: str) -> Keys:
Expand Down Expand Up @@ -136,7 +148,7 @@ async def _load_bundled_hashi_vault_keys(session: ClientSession, secret_url: str


class HashiVaultPrefixedKeysLoader(HashiVaultKeysLoader):
async def load(self, merged_keys: Keys) -> None:
async def load(self, merged_keys: HashiKeys) -> None:
"""Discover all the keys under given prefix. Then, load the keys into merged structure."""
prefix_leaf_location_tuples = []
while prefix_chunk := list(itertools.islice(self.input_iter, self.config.parallelism)):
Expand Down Expand Up @@ -169,7 +181,8 @@ async def load(self, merged_keys: Keys) -> None:
for (key_prefix, key_path) in prefixed_chunk
]
)
self.merge_keys_responses(keys_responses, merged_keys)
for keys_response in keys_responses:
merged_keys.update(keys_response)

@staticmethod
async def _find_prefixed_hashi_vault_keys(
Expand Down Expand Up @@ -218,12 +231,15 @@ async def _load_prefixed_hashi_vault_key(session: ClientSession, secret_url: str


class HashiVaultKeystore(LocalKeystore):
def __init__(self, keys: HashiKeys):
super().__init__(keys.keys)

@staticmethod
async def load() -> 'HashiVaultKeystore':
"""Extracts private keys from the keystores."""
hashi_vault_config = HashiVaultConfiguration.from_settings() # noqa: NEW100

merged_keys = Keys({})
merged_keys = HashiKeys(Keys({}))

for loader_class, input_iter in {
HashiVaultBundledKeysLoader: iter(hashi_vault_config.key_paths),
Expand Down

0 comments on commit 51af376

Please sign in to comment.