Skip to content

Commit

Permalink
Rework remote db (#293)
Browse files Browse the repository at this point in the history
* Del parent-public-key in remote-db

* Rename test_tasks -> test_commands

* Fixes
  • Loading branch information
evgeny-stakewise authored Feb 15, 2024
1 parent 0e72967 commit dfff5c6
Show file tree
Hide file tree
Showing 6 changed files with 383 additions and 460 deletions.
16 changes: 10 additions & 6 deletions src/remote_db/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,7 @@ def cleanup(ctx: Context) -> None:
click.echo(f'Successfully removed all the entries for the {greenify(settings.vault)} vault.')


@remote_db_group.command(
help='Generates shares for the local keypairs, updates configs in the remote DB.'
)
@remote_db_group.command(help='Uploads key-pairs to remote DB. Updates configs in the remote DB.')
@click.option(
'--encrypt-key',
envvar='REMOTE_DB_ENCRYPT_KEY',
Expand All @@ -133,12 +131,19 @@ def cleanup(ctx: Context) -> None:
help='Path to the deposit_data.json file. '
'Default is the file generated with "create-keys" command.',
)
@click.option(
'--pool-size',
help='Number of processes in a pool.',
envvar='POOL_SIZE',
type=int,
)
@click.pass_context
def upload_keypairs(
ctx: Context,
encrypt_key: str,
execution_endpoints: str,
deposit_data_file: str | None,
pool_size: int | None,
) -> None:
settings.set(
vault=settings.vault,
Expand All @@ -148,12 +153,11 @@ def upload_keypairs(
deposit_data_file=deposit_data_file,
verbose=settings.verbose,
execution_endpoints=execution_endpoints,
pool_size=pool_size,
)
try:
asyncio.run(tasks.upload_keypairs(ctx.obj['db_url'], encrypt_key))
click.echo(
f'Successfully uploaded keypairs and shares for the {greenify(settings.vault)} vault.'
)
click.echo(f'Successfully uploaded keypairs for the {greenify(settings.vault)} vault.')
except Exception as e:
log_verbose(e)

Expand Down
70 changes: 14 additions & 56 deletions src/remote_db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def get_first_keypair(self) -> RemoteDatabaseKeyPair | None:
with self.db_connection.cursor() as cur:
cur.execute(
f'''
SELECT parent_public_key, public_key, private_key, nonce
SELECT public_key, private_key, nonce
FROM {self.table}
WHERE vault = %s
ORDER BY public_key
Expand All @@ -63,32 +63,20 @@ def get_first_keypair(self) -> RemoteDatabaseKeyPair | None:
return None
return RemoteDatabaseKeyPair(
vault=settings.vault,
parent_public_key=row[0],
public_key=row[1],
private_key=row[2],
nonce=row[3],
public_key=row[0],
private_key=row[1],
nonce=row[2],
)

def get_keypairs(
self, has_parent_public_key: bool | None = None
) -> list[RemoteDatabaseKeyPair]:
def get_keypairs(self) -> list[RemoteDatabaseKeyPair]:
"""Returns keypairs from the database."""
where_list = []
params: dict = {}

if has_parent_public_key is True:
where_list.append('parent_public_key IS NOT NULL')
elif has_parent_public_key is False:
where_list.append('parent_public_key IS NULL')

query = f'''
SELECT parent_public_key, public_key, private_key, nonce
SELECT public_key, private_key, nonce
FROM {self.table}
'''

if where_list:
query += f'WHERE {" AND ".join(where_list)}\n'

query += 'ORDER BY public_key'

with self.db_connection.cursor() as cur:
Expand All @@ -98,22 +86,21 @@ def get_keypairs(
return [
RemoteDatabaseKeyPair(
vault=settings.vault,
parent_public_key=row[0],
public_key=row[1],
private_key=row[2],
nonce=row[3],
public_key=row[0],
private_key=row[1],
nonce=row[2],
)
for row in res
]

def remove_keypairs(self, in_parent_public_keys: set[HexStr] | None = None) -> None:
def remove_keypairs(self, in_public_keys: set[HexStr] | None = None) -> None:
"""Removes keypairs from the database."""
where_list = ['vault = %(vault)s']
params: dict = {'vault': settings.vault}

if in_parent_public_keys is not None:
where_list.append('parent_public_key IN %(in_parent_public_keys)s')
params['in_parent_public_keys'] = tuple(in_parent_public_keys)
if in_public_keys is not None:
where_list.append('public_key IN %(in_public_keys)s')
params['in_public_keys'] = tuple(in_public_keys)

query = f'''
DELETE FROM {self.table}
Expand All @@ -130,18 +117,16 @@ def upload_keypairs(self, keypairs: list[RemoteDatabaseKeyPair]) -> None:
f'''
INSERT INTO {self.table} (
vault,
parent_public_key,
public_key,
private_key,
nonce
)
VALUES (%s, %s, %s, %s, %s)
VALUES (%s, %s, %s, %s)
ON CONFLICT DO NOTHING
''',
[
(
keypair.vault,
keypair.parent_public_key,
keypair.public_key,
keypair.private_key,
keypair.nonce,
Expand All @@ -157,7 +142,6 @@ def create_table(self) -> None:
f'''
CREATE TABLE IF NOT EXISTS {self.table} (
vault VARCHAR(42) NOT NULL,
parent_public_key VARCHAR(98),
public_key VARCHAR(98) UNIQUE NOT NULL,
private_key VARCHAR(66) UNIQUE NOT NULL,
nonce VARCHAR(34) UNIQUE NOT NULL
Expand All @@ -167,7 +151,6 @@ def create_table(self) -> None:


class ConfigsCrud:
remote_signer_config_name = 'remote_signer_config.json'
deposit_data_name = 'deposit_data.json'

def __init__(self, db_connection: Any | None = None, db_url: str | None = None):
Expand All @@ -187,18 +170,6 @@ def get_configs_count(self) -> int:
row = cur.fetchone()
return row[0]

def get_remote_signer_config(self) -> dict | None:
"""Returns the remote signer config from the database."""
with self.db_connection.cursor() as cur:
cur.execute(
f'SELECT data FROM {self.table} WHERE vault = %s AND name = %s',
(settings.vault, self.remote_signer_config_name),
)
row = cur.fetchone()
if row is None:
return None
return json.loads(row[0])

def get_deposit_data(self) -> list | None:
"""Returns the deposit data from the database."""
with self.db_connection.cursor() as cur:
Expand All @@ -211,19 +182,6 @@ def get_deposit_data(self) -> list | None:
return None
return json.loads(row[0])

def update_remote_signer_config(self, data: dict) -> None:
"""Updates the remote signer config in the database."""
data_string = json.dumps(data)
with self.db_connection.cursor() as cur:
cur.execute(
f'''
INSERT INTO {self.table} (vault, name, data)
VALUES (%s, %s, %s)
ON CONFLICT (vault, name) DO UPDATE SET data = %s
''',
(settings.vault, self.remote_signer_config_name, data_string, data_string),
)

def update_deposit_data(self, deposit_data: list[dict]) -> None:
"""Updates the deposit data in the database."""
data_string = json.dumps(deposit_data)
Expand Down
24 changes: 14 additions & 10 deletions src/remote_db/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def cleanup(db_url: str) -> None:

# pylint: disable=too-many-locals
async def upload_keypairs(db_url: str, b64_encrypt_key: str) -> None:
"""Generates shares for the local keypairs, updates configs in the remote DB."""
"""Uploads key-pairs to remote DB. Updates configs in the remote DB."""
encryption_key = _check_encryption_key(db_url, b64_encrypt_key)

# load and check deposit data file
Expand All @@ -79,7 +79,7 @@ async def upload_keypairs(db_url: str, b64_encrypt_key: str) -> None:
if len(keystore) == 0:
raise click.ClickException('Keystore not found.')

click.echo(f'Calculating and encrypting shares for {len(keystore)} keystores...')
click.echo(f'Encrypting {len(keystore)} keystores...')
key_records: list[RemoteDatabaseKeyPair] = []
for public_key, private_key in keystore.keys.items(): # pylint: disable=no-member
encrypted_priv_key, nonce = _encrypt_private_key(private_key, encryption_key)
Expand Down Expand Up @@ -149,7 +149,7 @@ def setup_validator(
output_dir: Path,
) -> None:
"""Generate validator configs for Lighthouse, Teku and Prysm clients."""
keypairs = KeyPairsCrud(db_url=db_url).get_keypairs(has_parent_public_key=False)
keypairs = KeyPairsCrud(db_url=db_url).get_keypairs()
if not keypairs:
raise click.ClickException('No keypairs found in the remote db.')

Expand Down Expand Up @@ -234,20 +234,24 @@ def _check_encryption_key(db_url: str, b64_encrypt_key: str) -> bytes:
encryption_key = base64.b64decode(b64_encrypt_key)
if len(encryption_key) != CIPHER_KEY_LENGTH:
raise click.ClickException('Invalid encryption key length.')
except Exception as exc:
raise click.ClickException('Invalid encryption key.') from exc

keypair = KeyPairsCrud(db_url=db_url).get_first_keypair()
if keypair is None:
return encryption_key
keypair = KeyPairsCrud(db_url=db_url).get_first_keypair()
if keypair is None:
return encryption_key

try:
decrypted_private_key = _decrypt_private_key(
private_key=Web3.to_bytes(hexstr=keypair.private_key),
encryption_key=encryption_key,
nonce=Web3.to_bytes(hexstr=keypair.nonce),
)
if bls.SkToPk(decrypted_private_key) != Web3.to_bytes(hexstr=keypair.public_key):
raise click.ClickException('Failed to decrypt first private key.')
except Exception as exc:
raise click.ClickException('Invalid encryption key.') from exc
except Exception as e:
raise click.ClickException('Failed to decrypt first private key.') from e

if bls.SkToPk(decrypted_private_key) != Web3.to_bytes(hexstr=keypair.public_key):
raise click.ClickException('Failed to decrypt first private key.')

return encryption_key

Expand Down
Loading

0 comments on commit dfff5c6

Please sign in to comment.