diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 00d5107b..f6bbbee1 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -71,3 +71,4 @@ jobs: export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python # sh -ex ./snet/cli/test/utils/run_all_functional.sh python3 ./snet/cli/test/functional_tests/test_entry_point.py + python3 ./snet/cli/test/functional_tests/func_tests.py diff --git a/requirements.txt b/requirements.txt index 5e61ba38..82a8e571 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ trezor==0.13.8 ledgerblue==0.1.48 snet.contracts==0.1.1 lighthouseweb3==0.1.4 +cryptography==43.0.3 \ No newline at end of file diff --git a/snet/cli/arguments.py b/snet/cli/arguments.py index 02f0eb8d..1bc5d45f 100644 --- a/snet/cli/arguments.py +++ b/snet/cli/arguments.py @@ -125,7 +125,8 @@ def add_identity_options(parser, config): p.set_defaults(fn="list") p = subparsers.add_parser("create", - help="Create a new identity") + help="Create a new identity. For 'mnemonic' and 'key' identity_type, " + "secret encryption is enabled by default.") p.set_defaults(fn="create") p.add_argument("identity_name", help="Name of identity to create", @@ -135,6 +136,10 @@ def add_identity_options(parser, config): help="Type of identity to create from {}".format( get_identity_types()), metavar="IDENTITY_TYPE") + p.add_argument("-de", "--do-not-encrypt", + default=False, + action="store_true", + help="Do not encrypt the identity's private key or mnemonic. For 'key' and 'mnemonic' identity_type.") p.add_argument("--mnemonic", help="BIP39 mnemonic for 'mnemonic' identity_type") p.add_argument("--private-key", @@ -451,7 +456,8 @@ def add_contract_function_options(parser, contract_name): fns.append({ "name": fn["name"], "named_inputs": [(i["name"], i["type"]) for i in fn["inputs"] if i["name"] != ""], - "positional_inputs": [i["type"] for i in fn["inputs"] if i["name"] == ""] + "positional_inputs": [i["type"] for i in fn["inputs"] if i["name"] == ""], + "stateMutability": fn["stateMutability"] }) if len(fns) > 0: @@ -462,7 +468,10 @@ def add_contract_function_options(parser, contract_name): for fn in fns: fn_p = subparsers.add_parser( fn["name"], help="{} function".format(fn["name"])) - fn_p.set_defaults(fn="call") + if fn["stateMutability"] == "view": + fn_p.set_defaults(fn="call") + else: + fn_p.set_defaults(fn="transact") fn_p.set_defaults(contract_function=fn["name"]) for i in fn["positional_inputs"]: fn_p.add_argument(i, @@ -473,12 +482,8 @@ def add_contract_function_options(parser, contract_name): fn_p.add_argument("contract_named_input_{}".format(i[0]), type=type_converter(i[1]), metavar="{}_{}".format(i[0].lstrip("_"), i[1].upper())) - fn_p.add_argument("--transact", - action="store_const", - const="transact", - dest="fn", - help="Invoke contract function as transaction") - add_transaction_arguments(fn_p) + if fn["stateMutability"] != "view": + add_transaction_arguments(fn_p) def add_contract_identity_arguments(parser, names_and_destinations=(("", "at"),)): diff --git a/snet/cli/commands/commands.py b/snet/cli/commands/commands.py index df0d38cd..aa00f3c1 100644 --- a/snet/cli/commands/commands.py +++ b/snet/cli/commands/commands.py @@ -11,6 +11,7 @@ from lighthouseweb3 import Lighthouse import yaml import web3 +from cryptography.fernet import InvalidToken from snet.contracts import get_contract_def from snet.cli.contract import Contract @@ -18,7 +19,7 @@ MnemonicIdentityProvider, RpcIdentityProvider, TrezorIdentityProvider, get_kws_for_identity_type from snet.cli.metadata.organization import OrganizationMetadata, PaymentStorageClient, Payment, Group from snet.cli.utils.config import get_contract_address, get_field_from_args_or_session, \ - read_default_contract_address + read_default_contract_address, decrypt_secret from snet.cli.utils.ipfs_utils import get_from_ipfs_and_checkhash, \ hash_to_bytesuri, publish_file_in_ipfs, publish_file_in_filecoin from snet.cli.utils.utils import DefaultAttributeObject, get_web3, is_valid_url, serializable, type_converter, \ @@ -174,6 +175,32 @@ def get_identity(self): if identity_type == "keystore": return KeyStoreIdentityProvider(self.w3, self.config.get_session_field("keystore_path")) + def check_ident(self): + identity_type = self.config.get_session_field("identity_type") + if get_kws_for_identity_type(identity_type)[0][1] and not self.ident.private_key: + if identity_type == "key": + secret = self.config.get_session_field("private_key") + else: + secret = self.config.get_session_field("mnemonic") + decrypted_secret = self._get_decrypted_secret(secret) + self.ident.set_secret(decrypted_secret) + + def _get_decrypted_secret(self, secret): + decrypted_secret = None + try: + password = getpass.getpass("Password: ") + decrypted_secret = decrypt_secret(secret, password) + except InvalidToken: + self._printout("Wrong password! Try again") + if not decrypted_secret: + try: + password = getpass.getpass("Password: ") + decrypted_secret = decrypt_secret(secret, password) + except InvalidToken: + self._printerr("Wrong password! Operation failed.") + exit(1) + return decrypted_secret + def get_contract_argser(self, contract_address, contract_function, contract_def, **kwargs): def f(*positional_inputs, **named_inputs): args_dict = self.args.__dict__.copy() @@ -243,7 +270,19 @@ def create(self): if self.args.network: identity["network"] = self.args.network identity["default_wallet_index"] = self.args.wallet_index - self.config.add_identity(identity_name, identity, self.out_f) + + password = None + if not self.args.do_not_encrypt and get_kws_for_identity_type(identity_type)[0][1]: + self._printout("For 'mnemonic' and 'key' identity_type, secret encryption is enabled by default, " + "so you need to come up with a password that you then need to enter on every transaction. " + "To disable encryption, use the '-de' or '--do-not-encrypt' argument.") + password = getpass.getpass("Password: ") + self._ensure(password is not None, "Password cannot be empty") + pwd_confirm = getpass.getpass("Confirm password: ") + self._ensure(password == pwd_confirm, "Passwords do not match") + + self.config.add_identity(identity_name, identity, self.out_f, password) + def list(self): for identity_section in filter(lambda x: x.startswith("identity."), self.config.sections()): @@ -302,11 +341,11 @@ def unset(self): class SessionShowCommand(BlockchainCommand): + def show(self): rez = self.config.session_to_dict() key = "network.%s" % rez['session']['network'] self.populate_contract_address(rez, key) - # we don't want to who private_key and mnemonic for d in rez.values(): d.pop("private_key", None) @@ -348,6 +387,7 @@ def call(self): return result def transact(self): + self.check_ident() contract_address = get_contract_address(self, self.args.contract_name, "--at is required to specify target contract address") @@ -402,7 +442,8 @@ def add_group(self): raise Exception(f"Invalid {endpoint} endpoint passed") payment_storage_client = PaymentStorageClient(self.args.payment_channel_connection_timeout, - self.args.payment_channel_request_timeout, self.args.endpoints) + self.args.payment_channel_request_timeout, + self.args.endpoints) payment = Payment(self.args.payment_address, self.args.payment_expiration_threshold, self.args.payment_channel_storage_type, payment_storage_client) group_id = base64.b64encode(secrets.token_bytes(32)) @@ -424,8 +465,7 @@ def remove_group(self): raise e existing_groups = org_metadata.groups - updated_groups = [ - group for group in existing_groups if not group_id == group.group_id] + updated_groups = [group for group in existing_groups if not group_id == group.group_id] org_metadata.groups = updated_groups org_metadata.save_pretty(metadata_file) @@ -437,17 +477,13 @@ def set_changed_values_for_group(self, group): if self.args.payment_address: group.update_payment_address(self.args.payment_address) if self.args.payment_expiration_threshold: - group.update_payment_expiration_threshold( - self.args.payment_expiration_threshold) + group.update_payment_expiration_threshold(self.args.payment_expiration_threshold) if self.args.payment_channel_storage_type: - group.update_payment_channel_storage_type( - self.args.payment_channel_storage_type) + group.update_payment_channel_storage_type(self.args.payment_channel_storage_type) if self.args.payment_channel_connection_timeout: - group.update_connection_timeout( - self.args.payment_channel_connection_timeout) + group.update_connection_timeout(self.args.payment_channel_connection_timeout) if self.args.payment_channel_request_timeout: - group.update_request_timeout( - self.args.payment_channel_request_timeout) + group.update_request_timeout(self.args.payment_channel_request_timeout) def update_group(self): group_id = self.args.group_id @@ -667,7 +703,6 @@ def get_path(err): return {"status": 0, "msg": "Organization metadata is valid and ready to publish."} def create(self): - self._metadata_validate() metadata_file = self.args.metadata_file diff --git a/snet/cli/commands/mpe_account.py b/snet/cli/commands/mpe_account.py index e01f35b6..c5d29ea5 100644 --- a/snet/cli/commands/mpe_account.py +++ b/snet/cli/commands/mpe_account.py @@ -5,10 +5,12 @@ class MPEAccountCommand(BlockchainCommand): def print_account(self): + self.check_ident() self._printout(self.ident.address) def print_agix_and_mpe_balances(self): """ Print balance of ETH, AGIX, and MPE wallet """ + self.check_ident() if self.args.account: account = self.args.account else: @@ -24,6 +26,7 @@ def print_agix_and_mpe_balances(self): self._printout(" MPE: %s"%cogs2stragix(mpe_cogs)) def deposit_to_mpe(self): + self.check_ident() amount = self.args.amount mpe_address = self.get_mpe_address() @@ -33,7 +36,9 @@ def deposit_to_mpe(self): self.transact_contract_command("MultiPartyEscrow", "deposit", [amount]) def withdraw_from_mpe(self): + self.check_ident() self.transact_contract_command("MultiPartyEscrow", "withdraw", [self.args.amount]) def transfer_in_mpe(self): + self.check_ident() self.transact_contract_command("MultiPartyEscrow", "transfer", [self.args.receiver, self.args.amount]) diff --git a/snet/cli/commands/mpe_channel.py b/snet/cli/commands/mpe_channel.py index 507534a6..53fae089 100644 --- a/snet/cli/commands/mpe_channel.py +++ b/snet/cli/commands/mpe_channel.py @@ -8,7 +8,6 @@ from pathlib import Path from eth_abi.codec import ABICodec -from web3._utils.encoding import pad_hex from web3._utils.events import get_event_data from snet.contracts import get_contract_def, get_contract_deployment_block @@ -504,8 +503,10 @@ def _print_channels(self, channels, filters: list[str] = None): def get_address_from_arg_or_ident(self, arg): if arg: return arg + self.check_ident() return self.ident.address + def print_channels_filter_sender(self): # we don't need to return other channel fields if we only need channel_id or if we'll sync channels state return_only_id = self.args.only_id or not self.args.do_not_sync diff --git a/snet/cli/commands/mpe_client.py b/snet/cli/commands/mpe_client.py index 64b9de66..49bb90ad 100644 --- a/snet/cli/commands/mpe_client.py +++ b/snet/cli/commands/mpe_client.py @@ -160,7 +160,7 @@ def _get_endpoint_from_metadata_or_args(self, metadata): return endpoints[0] def call_server_lowlevel(self): - + self.check_ident() self._init_or_update_registered_org_if_needed() self._init_or_update_registered_service_if_needed() @@ -263,6 +263,8 @@ def _get_channel_state_statelessly(self, grpc_channel, channel_id): return server["current_nonce"], server["current_signed_amount"], unspent_amount def print_channel_state_statelessly(self): + self.check_ident() + grpc_channel = open_grpc_channel(self.args.endpoint) current_nonce, current_amount, unspent_amount = self._get_channel_state_statelessly( @@ -308,6 +310,7 @@ def call_server_statelessly_with_params(self, params, group_name): return self._call_server_via_grpc_channel(grpc_channel, channel_id, server_state["current_nonce"], server_state["current_signed_amount"] + price, params, service_metadata) def call_server_statelessly(self): + self.check_ident() group_name = self.args.group_name params = self._get_call_params() response = self.call_server_statelessly_with_params(params, group_name) diff --git a/snet/cli/commands/mpe_service.py b/snet/cli/commands/mpe_service.py index 61e3ea89..442b6a45 100644 --- a/snet/cli/commands/mpe_service.py +++ b/snet/cli/commands/mpe_service.py @@ -632,7 +632,6 @@ def extract_service_api_from_metadata(self): service_api_source = metadata.get("service_api_source") or metadata.get("model_ipfs_hash") download_and_safe_extract_proto(service_api_source, self.args.protodir, self._get_ipfs_client()) - def extract_service_api_from_registry(self): metadata = self._get_service_metadata_from_registry() service_api_source = metadata.get("service_api_source") or metadata.get("model_ipfs_hash") diff --git a/snet/cli/commands/mpe_treasurer.py b/snet/cli/commands/mpe_treasurer.py index db1d959b..a0855d46 100644 --- a/snet/cli/commands/mpe_treasurer.py +++ b/snet/cli/commands/mpe_treasurer.py @@ -150,10 +150,12 @@ def _claim_in_progress_and_claim_channels(self, grpc_channel, channels): self._blockchain_claim(payments) def claim_channels(self): + self.check_ident() grpc_channel = open_grpc_channel(self.args.endpoint) self._claim_in_progress_and_claim_channels(grpc_channel, self.args.channels) def claim_all_channels(self): + self.check_ident() grpc_channel = open_grpc_channel(self.args.endpoint) # we take list of all channels unclaimed_payments = self._call_GetListUnclaimed(grpc_channel) @@ -161,6 +163,7 @@ def claim_all_channels(self): self._claim_in_progress_and_claim_channels(grpc_channel, channels) def claim_almost_expired_channels(self): + self.check_ident() grpc_channel = open_grpc_channel(self.args.endpoint) # we take list of all channels unclaimed_payments = self._call_GetListUnclaimed(grpc_channel) diff --git a/snet/cli/config.py b/snet/cli/config.py index 140fcff7..9c62382c 100644 --- a/snet/cli/config.py +++ b/snet/cli/config.py @@ -2,6 +2,8 @@ from pathlib import Path import sys +from snet.cli.utils.config import encrypt_secret + default_snet_folder = Path("~").expanduser().joinpath(".snet") DEFAULT_NETWORK = "sepolia" @@ -138,12 +140,19 @@ def set_network_field(self, network, key, value): self._get_network_section(network)[key] = str(value) self._persist() - def add_identity(self, identity_name, identity, out_f=sys.stdout): + def add_identity(self, identity_name, identity, out_f=sys.stdout, password=None): identity_section = "identity.%s" % identity_name if identity_section in self: raise Exception("Identity section %s already exists in config" % identity_section) if "network" in identity and identity["network"] not in self.get_all_networks_names(): raise Exception("Network %s is not in config" % identity["network"]) + + if password: + if "mnemonic" in identity: + identity["mnemonic"] = encrypt_secret(identity["mnemonic"], password) + elif "private_key" in identity: + identity["private_key"] = encrypt_secret(identity["private_key"], password) + self[identity_section] = identity self._persist() # switch to it, if it was the first identity diff --git a/snet/cli/identity.py b/snet/cli/identity.py index 3609daba..0eb4cd4e 100644 --- a/snet/cli/identity.py +++ b/snet/cli/identity.py @@ -38,6 +38,13 @@ def sign_message_after_solidity_keccak(self, message): class KeyIdentityProvider(IdentityProvider): def __init__(self, w3, private_key): self.w3 = w3 + if private_key.startswith("::"): + self.private_key = None + self.address = None + return + self.set_secret(private_key) + + def set_secret(self, private_key): self.private_key = normalize_private_key(private_key) self.address = get_address_from_private(self.private_key) @@ -109,8 +116,16 @@ def sign_message_after_solidity_keccak(self, message): class MnemonicIdentityProvider(IdentityProvider): def __init__(self, w3, mnemonic, index): self.w3 = w3 + self.index = index + if mnemonic.startswith("::"): + self.private_key = None + self.address = None + return + self.set_secret(mnemonic) + + def set_secret(self, mnemonic): Account.enable_unaudited_hdwallet_features() - account = Account.from_mnemonic(mnemonic, account_path=f"m/44'/60'/0'/0/{index}") + account = Account.from_mnemonic(mnemonic, account_path=f"m/44'/60'/0'/0/{self.index}") self.private_key = account.key.hex() self.address = account.address diff --git a/snet/cli/test/functional_tests/func_tests.py b/snet/cli/test/functional_tests/func_tests.py new file mode 100644 index 00000000..30a35761 --- /dev/null +++ b/snet/cli/test/functional_tests/func_tests.py @@ -0,0 +1,133 @@ +import warnings +import argcomplete +import unittest +import unittest.mock as mock +import shutil +import os + +from snet.cli.commands.commands import BlockchainCommand + +with warnings.catch_warnings(): + # Suppress the eth-typing package`s warnings related to some new networks + warnings.filterwarnings("ignore", "Network .* does not have a valid ChainId. eth-typing should be " + "updated with the latest networks.", UserWarning) + from snet.cli import arguments + +from snet.cli.config import Config + + +class StringOutput: + def __init__(self): + self.text = "" + + def write(self, text): + self.text += text + + +def execute(args_list, parser, conf): + try: + argv = args_list + try: + args = parser.parse_args(argv) + except TypeError: + args = parser.parse_args(argv + ["-h"]) + f = StringOutput() + getattr(args.cmd(conf, args, out_f = f), args.fn)() + return f.text + except Exception as e: + raise + +class BaseTest(unittest.TestCase): + def setUp(self): + self.conf = Config() + self.parser = arguments.get_root_parser(self.conf) + argcomplete.autocomplete(self.parser) + + +class TestCommands(BaseTest): + def test_balance_output(self): + result = execute(["account", "balance"], self.parser, self.conf) + assert len(result.split("\n")) >= 4 + + def test_balance_address(self): + result = execute(["account", "balance"], self.parser, self.conf) + assert result.split("\n")[0].split()[1] == "0xe5D1fA424DE4689F9d2687353b75D7a8987900fD" + +class TestDepositWithdraw(BaseTest): + def setUp(self): + super().setUp() + self.balance_1: int + self.balance_2: int + self.amount = 0.1 + + def test_deposit(self): + result = execute(["account", "balance"], self.parser, self.conf) + self.balance_1 = float(result.split("\n")[3].split()[1]) + execute(["account", "deposit", f"{self.amount}", "-y", "-q"], self.parser, self.conf) + result = execute(["account", "balance"], self.parser, self.conf) + self.balance_2 = float(result.split("\n")[3].split()[1]) + assert self.balance_2 == self.balance_1 + self.amount + + def test_withdraw(self): + result = execute(["account", "balance"], self.parser, self.conf) + self.balance_1 = float(result.split("\n")[3].split()[1]) + execute(["account", "withdraw", f"{self.amount}", "-y", "-q"], self.parser, self.conf) + result = execute(["account", "balance"], self.parser, self.conf) + self.balance_2 = float(result.split("\n")[3].split()[1]) + assert self.balance_2 == self.balance_1 - self.amount + + +class TestGenerateLibrary(BaseTest): + def setUp(self): + super().setUp() + self.path = './temp_files' + self.org_id = '26072b8b6a0e448180f8c0e702ab6d2f' + self.service_id = 'Exampleservice' + + def test_generate(self): + execute(["sdk", "generate-client-library", self.org_id, self.service_id, self.path], self.parser, self.conf) + assert os.path.exists(f'{self.path}/{self.org_id}/{self.service_id}/python/') + + def tearDown(self): + shutil.rmtree(self.path) + + +class TestEncryptionKey(BaseTest): + def setUp(self): + super().setUp() + self.key = "1234567890123456789012345678901234567890123456789012345678901234" + self.password = "some_pass" + self.name = "some_name" + self.default_name = "default_name" + result = execute(["identity", "list"], self.parser, self.conf) + if self.default_name not in result: + execute(["identity", "create", self.default_name, "key", "--private-key", self.key, "-de"], + self.parser, + self.conf) + + def test_1_create_identity_with_encryption_key(self): + with mock.patch('getpass.getpass', return_value=self.password): + execute(["identity", "create", self.name, "key", "--private-key", self.key], + self.parser, + self.conf) + result = execute(["identity", "list"], self.parser, self.conf) + assert self.name in result + + def test_2_get_encryption_key(self): + with mock.patch('getpass.getpass', return_value=self.password): + execute(["identity", self.name], self.parser, self.conf) + cmd = BlockchainCommand(self.conf, self.parser.parse_args(['session'])) + enc_key = cmd.config.get_session_field("private_key") + res_key = cmd._get_decrypted_secret(enc_key) + assert res_key == self.key + + def test_3_delete_identity(self): + with mock.patch('getpass.getpass', return_value=self.password): + execute(["identity", self.default_name], self.parser, self.conf) + execute(["identity", "delete", self.name], self.parser, self.conf) + result = execute(["identity", "list"], self.parser, self.conf) + assert self.name not in result + + +if __name__ == "__main__": + unittest.main() diff --git a/snet/cli/utils/config.py b/snet/cli/utils/config.py index 4543ebfa..2c3f0ad0 100644 --- a/snet/cli/utils/config.py +++ b/snet/cli/utils/config.py @@ -1,5 +1,12 @@ from snet.contracts import get_contract_def +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from base64 import urlsafe_b64encode, urlsafe_b64decode +import os + def get_contract_address(cmd, contract_name, error_message=None): """ @@ -60,3 +67,34 @@ def get_field_from_args_or_session(config, args, field_name): return rez raise Exception("Fail to get default_%s from config, should specify %s via --%s parameter" % ( field_name, field_name, field_name.replace("_", "-"))) + + +def encrypt_secret(secret, password): + salt = os.urandom(16) + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + backend=default_backend() + ) + key = urlsafe_b64encode(kdf.derive(password.encode())) + cipher_suite = Fernet(key) + encrypted_secret = cipher_suite.encrypt(secret.encode()) + salt + return '::' + str(urlsafe_b64encode(encrypted_secret))[2:-1] + +def decrypt_secret(secret, password): + secret = urlsafe_b64decode(secret[2:]) + salt = secret[-16:] + secret = secret[:-16] + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + backend=default_backend() + ) + key = urlsafe_b64encode(kdf.derive(password.encode())) + cipher_suite = Fernet(key) + decrypted_secret = cipher_suite.decrypt(secret).decode() + return decrypted_secret diff --git a/version.py b/version.py index 55e47090..3d67cd6b 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -__version__ = "2.3.0" +__version__ = "2.4.0"