Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use a class for the chia cli context data #18919

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions chia/_tests/cmds/test_click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Uint64ParamType,
)
from chia.cmds.units import units
from chia.cmds.util import ChiaCliContext
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.bech32m import encode_puzzle_hash
from chia.util.ints import uint64
Expand Down Expand Up @@ -114,7 +115,9 @@ def test_click_amount_type() -> None:


def test_click_address_type() -> None:
context = cast(Context, FakeContext(obj={"expected_prefix": "xch"})) # this makes us not have to use a config file
context = cast(
Context, FakeContext(obj=ChiaCliContext(expected_prefix="xch").to_click())
) # this makes us not have to use a config file
std_cli_address = CliAddress(burn_ph, burn_address, AddressType.XCH)
nft_cli_address = CliAddress(burn_ph, burn_nft_addr, AddressType.DID)
# Test CliAddress (Generally is not used)
Expand Down Expand Up @@ -149,7 +152,7 @@ def test_click_address_type() -> None:

def test_click_address_type_config(root_path_populated_with_config: Path) -> None:
# set a root path in context.
context = cast(Context, FakeContext(obj={"root_path": root_path_populated_with_config}))
context = cast(Context, FakeContext(obj=ChiaCliContext(root_path=root_path_populated_with_config).to_click()))
# run test that should pass
assert AddressParamType().convert(burn_address, None, context) == CliAddress(burn_ph, burn_address, AddressType.XCH)
assert context.obj["expected_prefix"] == "xch" # validate that the prefix was set correctly
Expand Down
8 changes: 5 additions & 3 deletions chia/_tests/cmds/test_cmd_framework.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import pathlib
import textwrap
from collections.abc import Sequence
from dataclasses import asdict
Expand All @@ -13,6 +14,7 @@
from chia._tests.environments.wallet import WalletTestFramework
from chia._tests.wallet.conftest import * # noqa
from chia.cmds.cmd_classes import ChiaCommand, Context, NeedsWalletRPC, chia_command, option
from chia.cmds.util import ChiaCliContext
from chia.types.blockchain_format.sized_bytes import bytes32


Expand Down Expand Up @@ -143,14 +145,14 @@ def test_context_requirement() -> None:
@click.group()
@click.pass_context
def cmd(ctx: click.Context) -> None:
ctx.obj = {"foo": "bar"}
ctx.obj = ChiaCliContext(root_path=pathlib.Path("foo", "bar")).to_click()

@chia_command(cmd, "temp_cmd", "blah")
class TempCMD:
context: Context

def run(self) -> None:
assert self.context["foo"] == "bar"
assert self.context.root_path == pathlib.Path("foo", "bar")

runner = CliRunner()
result = runner.invoke(
Expand Down Expand Up @@ -385,7 +387,7 @@ def run(self) -> None:

expected_command = TempCMD(
rpc_info=NeedsWalletRPC(
context={"root_path": wallet_environments.environments[0].node.root_path},
context=ChiaCliContext(root_path=wallet_environments.environments[0].node.root_path),
wallet_rpc_port=port,
fingerprint=fingerprint,
),
Expand Down
11 changes: 6 additions & 5 deletions chia/cmds/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
validate_beta_path,
validate_metrics_log_interval,
)
from chia.cmds.util import ChiaCliContext
from chia.util.beta_metrics import metrics_log_interval_default
from chia.util.config import lock_and_load_config, save_config

Expand All @@ -37,7 +38,7 @@ def beta_cmd() -> None:
@click.option("-i", "--interval", help="System metrics will be logged based on this interval", type=int, required=False)
@click.pass_context
def configure(ctx: click.Context, path: Optional[str], interval: Optional[int]) -> None:
root_path = ctx.obj["root_path"]
root_path = ChiaCliContext.from_click(ctx).root_path
with lock_and_load_config(root_path, "config.yaml") as config:
if "beta" not in config:
raise click.ClickException("beta test mode is not enabled, enable it first with `chia beta enable`")
Expand Down Expand Up @@ -79,7 +80,7 @@ def configure(ctx: click.Context, path: Optional[str], interval: Optional[int])
@click.option("-p", "--path", help="The beta mode root path", type=str, required=False)
@click.pass_context
def enable_cmd(ctx: click.Context, force: bool, path: Optional[str]) -> None:
root_path = ctx.obj["root_path"]
root_path = ChiaCliContext.from_click(ctx).root_path
with lock_and_load_config(root_path, "config.yaml") as config:
if config.get("beta", {}).get("enabled", False):
raise click.ClickException("beta test mode is already enabled")
Expand Down Expand Up @@ -107,7 +108,7 @@ def enable_cmd(ctx: click.Context, force: bool, path: Optional[str]) -> None:
@beta_cmd.command("disable", help="Disable beta test mode")
@click.pass_context
def disable_cmd(ctx: click.Context) -> None:
root_path = ctx.obj["root_path"]
root_path = ChiaCliContext.from_click(ctx).root_path
with lock_and_load_config(root_path, "config.yaml") as config:
if not config.get("beta", {}).get("enabled", False):
raise click.ClickException("beta test mode is not enabled")
Expand All @@ -121,7 +122,7 @@ def disable_cmd(ctx: click.Context) -> None:
@beta_cmd.command("prepare_submission", help="Prepare the collected log data for submission")
@click.pass_context
def prepare_submission_cmd(ctx: click.Context) -> None:
with lock_and_load_config(ctx.obj["root_path"], "config.yaml") as config:
with lock_and_load_config(ChiaCliContext.from_click(ctx).root_path, "config.yaml") as config:
beta_root_path = config.get("beta", {}).get("path", None)
if beta_root_path is None:
raise click.ClickException("beta test mode not enabled. Run `chia beta enable` first.")
Expand Down Expand Up @@ -173,7 +174,7 @@ def add_files(paths: list[Path]) -> int:
@beta_cmd.command("status", help="Show the current beta configuration")
@click.pass_context
def status(ctx: click.Context) -> None:
with lock_and_load_config(ctx.obj["root_path"], "config.yaml") as config:
with lock_and_load_config(ChiaCliContext.from_click(ctx).root_path, "config.yaml") as config:
beta_config = config.get("beta")
if beta_config is None:
raise click.ClickException("beta test mode is not enabled, enable it first with `chia beta enable`")
Expand Down
6 changes: 3 additions & 3 deletions chia/cmds/check_wallet_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,9 @@ async def scan(self, db_path: Path) -> int:
return len(errors)


async def scan(root_path: str, db_path: Optional[str] = None, *, verbose: bool = False) -> None:
async def scan(root_path: Path, db_path: Optional[str] = None, *, verbose: bool = False) -> None:
if db_path is None:
wallet_db_path = Path(root_path) / "wallet" / "db"
wallet_db_path = root_path / "wallet" / "db"
wallet_db_paths = list(wallet_db_path.glob("blockchain_wallet_*.sqlite"))
else:
wallet_db_paths = [Path(db_path)]
Expand All @@ -417,4 +417,4 @@ async def scan(root_path: str, db_path: Optional[str] = None, *, verbose: bool =

if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(scan("", sys.argv[1]))
loop.run_until_complete(scan(Path(""), sys.argv[1]))
5 changes: 3 additions & 2 deletions chia/cmds/chia.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from chia.cmds.show import show_cmd
from chia.cmds.start import start_cmd
from chia.cmds.stop import stop_cmd
from chia.cmds.util import ChiaCliContext
from chia.cmds.wallet import wallet_cmd
from chia.util.default_root import DEFAULT_KEYS_ROOT_PATH, DEFAULT_ROOT_PATH
from chia.util.errors import KeychainCurrentPassphraseIsInvalid
Expand Down Expand Up @@ -55,7 +56,7 @@ def cli(
from pathlib import Path

ctx.ensure_object(dict)
ctx.obj["root_path"] = Path(root_path)
ctx.obj.update(ChiaCliContext(root_path=Path(root_path)).to_click())

# keys_root_path and passphrase_file will be None if the passphrase options have been
# scrubbed from the CLI options
Expand Down Expand Up @@ -107,7 +108,7 @@ def run_daemon_cmd(ctx: click.Context, wait_for_unlock: bool) -> None:

wait_for_unlock = wait_for_unlock and Keychain.is_keyring_locked()

asyncio.run(async_run_daemon(ctx.obj["root_path"], wait_for_unlock=wait_for_unlock))
asyncio.run(async_run_daemon(ChiaCliContext.from_click(ctx).root_path, wait_for_unlock=wait_for_unlock))


cli.add_command(keys_cmd)
Expand Down
9 changes: 5 additions & 4 deletions chia/cmds/cmd_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing_extensions import dataclass_transform

from chia.cmds.cmds_util import get_wallet_client
from chia.cmds.util import ChiaCliContext
from chia.rpc.wallet_rpc_client import WalletRpcClient
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
Expand Down Expand Up @@ -118,7 +119,7 @@ def apply_decorators(self, cmd: SyncCmd) -> SyncCmd:

def strip_click_context(func: SyncCmd) -> SyncCmd:
def _inner(ctx: click.Context, **kwargs: Any) -> None:
context: dict[str, Any] = ctx.obj if ctx.obj is not None else {}
context: Context = ChiaCliContext.from_click(ctx)
func(context=context, **kwargs)

return _inner
Expand Down Expand Up @@ -256,7 +257,7 @@ def command_helper(cls: type[Any]) -> type[Any]:
return new_cls


Context = dict[str, Any]
Context = ChiaCliContext


@dataclass(frozen=True)
Expand All @@ -268,7 +269,7 @@ class WalletClientInfo:

@command_helper
class NeedsWalletRPC:
context: Context = field(default_factory=dict)
context: Context = field(default_factory=ChiaCliContext)
client_info: Optional[WalletClientInfo] = None
wallet_rpc_port: Optional[int] = option(
"-wp",
Expand All @@ -294,7 +295,7 @@ async def wallet_rpc(self, **kwargs: Any) -> AsyncIterator[WalletClientInfo]:
yield self.client_info
else:
if "root_path" not in kwargs:
kwargs["root_path"] = self.context["root_path"]
kwargs["root_path"] = self.context.root_path
async with get_wallet_client(self.wallet_rpc_port, self.fingerprint, **kwargs) as (
wallet_client,
fp,
Expand Down
3 changes: 2 additions & 1 deletion chia/cmds/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import click
import yaml

from chia.cmds.util import ChiaCliContext
from chia.server.outbound_message import NodeType
from chia.util.config import (
initial_config_file,
Expand Down Expand Up @@ -313,7 +314,7 @@ def configure_cmd(
seeder_nameserver: str,
) -> None:
configure(
ctx.obj["root_path"],
ChiaCliContext.from_click(ctx).root_path,
set_farmer_peer,
set_node_introducer,
set_fullnode_port,
Expand Down
7 changes: 4 additions & 3 deletions chia/cmds/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from chia.cmds.db_backup_func import db_backup_func
from chia.cmds.db_upgrade_func import db_upgrade_func
from chia.cmds.db_validate_func import db_validate_func
from chia.cmds.util import ChiaCliContext


@click.group("db", help="Manage the blockchain database")
Expand Down Expand Up @@ -41,7 +42,7 @@ def db_upgrade_cmd(
) -> None:
try:
db_upgrade_func(
Path(ctx.obj["root_path"]),
ChiaCliContext.from_click(ctx).root_path,
None if in_db_path is None else Path(in_db_path),
None if out_db_path is None else Path(out_db_path),
no_update_config=no_update_config,
Expand All @@ -63,7 +64,7 @@ def db_upgrade_cmd(
def db_validate_cmd(ctx: click.Context, in_db_path: Optional[str], validate_blocks: bool) -> None:
try:
db_validate_func(
Path(ctx.obj["root_path"]),
ChiaCliContext.from_click(ctx).root_path,
None if in_db_path is None else Path(in_db_path),
validate_blocks=validate_blocks,
)
Expand All @@ -78,7 +79,7 @@ def db_validate_cmd(ctx: click.Context, in_db_path: Optional[str], validate_bloc
def db_backup_cmd(ctx: click.Context, db_backup_file: Optional[str], no_indexes: bool) -> None:
try:
db_backup_func(
Path(ctx.obj["root_path"]),
ChiaCliContext.from_click(ctx).root_path,
None if db_backup_file is None else Path(db_backup_file),
no_indexes=no_indexes,
)
Expand Down
4 changes: 3 additions & 1 deletion chia/cmds/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import click

from chia.cmds.util import ChiaCliContext


@click.command("init", help="Create or migrate the configuration")
@click.option(
Expand Down Expand Up @@ -55,7 +57,7 @@ def init_cmd(

init(
Path(create_certs) if create_certs is not None else None,
ctx.obj["root_path"],
ChiaCliContext.from_click(ctx).root_path,
fix_ssl_permissions,
testnet,
v1_db,
Expand Down
30 changes: 22 additions & 8 deletions chia/cmds/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from chia_rs import PrivateKey

from chia.cmds import options
from chia.cmds.util import ChiaCliContext


@click.group("keys", help="Manage your keys")
@click.pass_context
def keys_cmd(ctx: click.Context) -> None:
"""Create, delete, view and use your key pairs"""
from pathlib import Path

root_path: Path = ctx.obj["root_path"]
root_path = ChiaCliContext.from_click(ctx).root_path
if not root_path.is_dir():
raise RuntimeError("Please initialize (or migrate) your config directory with chia init")

Expand All @@ -34,7 +34,7 @@ def generate_cmd(ctx: click.Context, label: Optional[str]) -> None:
from .keys_funcs import generate_and_add

generate_and_add(label)
check_keys(ctx.obj["root_path"])
check_keys(ChiaCliContext.from_click(ctx).root_path)


@keys_cmd.command("show", help="Displays all the keys in keychain or the key with the given fingerprint")
Expand Down Expand Up @@ -77,7 +77,14 @@ def show_cmd(
) -> None:
from .keys_funcs import show_keys

show_keys(ctx.obj["root_path"], show_mnemonic_seed, non_observer_derivation, json, fingerprint, bech32m_prefix)
show_keys(
ChiaCliContext.from_click(ctx).root_path,
show_mnemonic_seed,
non_observer_derivation,
json,
fingerprint,
bech32m_prefix,
)


@keys_cmd.command("add", help="Add a private key by mnemonic or public key as hex")
Expand Down Expand Up @@ -109,7 +116,7 @@ def add_cmd(ctx: click.Context, filename: str, label: Optional[str]) -> None:
mnemonic_or_pk = Path(filename).read_text().rstrip()

query_and_add_key_info(mnemonic_or_pk, label)
check_keys(ctx.obj["root_path"])
check_keys(ChiaCliContext.from_click(ctx).root_path)


@keys_cmd.group("label", help="Manage your key labels")
Expand Down Expand Up @@ -155,7 +162,7 @@ def delete_cmd(ctx: click.Context, fingerprint: int) -> None:
from .keys_funcs import delete

delete(fingerprint)
check_keys(ctx.obj["root_path"])
check_keys(ChiaCliContext.from_click(ctx).root_path)


@keys_cmd.command("delete_all", help="Delete all private keys in keychain")
Expand Down Expand Up @@ -343,7 +350,7 @@ def search_cmd(
print("Could not resolve private key from fingerprint/mnemonic file")

found: bool = search_derive(
ctx.obj["root_path"],
ChiaCliContext.from_click(ctx).root_path,
fingerprint,
search_terms,
limit,
Expand Down Expand Up @@ -419,7 +426,14 @@ def wallet_address_cmd(
return

derive_wallet_address(
ctx.obj["root_path"], fingerprint, index, count, prefix, non_observer_derivation, show_hd_path, sk
ChiaCliContext.from_click(ctx).root_path,
fingerprint,
index,
count,
prefix,
non_observer_derivation,
show_hd_path,
sk,
)


Expand Down
7 changes: 5 additions & 2 deletions chia/cmds/param_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import click

from chia.cmds.units import units
from chia.cmds.util import ChiaCliContext
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.bech32m import bech32_decode, decode_puzzle_hash
from chia.util.config import load_config, selected_network_address_prefix
Expand Down Expand Up @@ -173,9 +174,11 @@ def convert(self, value: Any, param: Optional[click.Parameter], ctx: Optional[cl
hrp, _b32data = bech32_decode(value)
if hrp in {"xch", "txch"}: # I hate having to load the config here
addr_type: AddressType = AddressType.XCH
expected_prefix = ctx.obj.get("expected_prefix") if ctx else None # attempt to get cached prefix
expected_prefix = (
ChiaCliContext.from_click(ctx).expected_prefix if ctx else None
) # attempt to get cached prefix
if expected_prefix is None:
root_path = ctx.obj["root_path"] if ctx is not None else DEFAULT_ROOT_PATH
root_path = ChiaCliContext.from_click(ctx).root_path if ctx is not None else DEFAULT_ROOT_PATH
config = load_config(root_path, "config.yaml")
expected_prefix = selected_network_address_prefix(config)

Expand Down
Loading
Loading